From dba931fa7fc276fbecfdaff44cd9bc32d2dca98a Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Thu, 19 Sep 2024 11:11:03 -0700 Subject: [PATCH 01/35] [CHORE] Change TPC-H q4 and q22 answers to use new join types (#2756) Also adds some functionality to the benchmarking code for manual runs Blocked by #2743 for compatibility with new executor --- benchmarking/tpch/__main__.py | 26 ++++++++++++++++++-------- benchmarking/tpch/answers.py | 11 ++++------- benchmarking/tpch/ray_job_runner.py | 2 +- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/benchmarking/tpch/__main__.py b/benchmarking/tpch/__main__.py index 6a7d24b290..33cc847c9c 100644 --- a/benchmarking/tpch/__main__.py +++ b/benchmarking/tpch/__main__.py @@ -123,7 +123,7 @@ def _get_df(table_name: str) -> DataFrame: def run_all_benchmarks( parquet_folder: str, - skip_questions: set[int], + questions: list[int], csv_output_location: str | None, ray_job_dashboard_url: str | None = None, requirements: str | None = None, @@ -133,11 +133,7 @@ def run_all_benchmarks( daft_context = get_context() metrics_builder = MetricsBuilder(daft_context.runner_config.name) - for i in range(1, 23): - if i in skip_questions: - logger.warning("Skipping TPC-H q%s", i) - continue - + for i in questions: # Run as a Ray Job if dashboard URL is provided if ray_job_dashboard_url is not None: from benchmarking.tpch import ray_job_runner @@ -202,7 +198,10 @@ def get_ray_runtime_env(requirements: str | None) -> dict: runtime_env = { "py_modules": [daft], "eager_install": True, - "env_vars": {"DAFT_PROGRESS_BAR": "0"}, + "env_vars": { + "DAFT_PROGRESS_BAR": "0", + "DAFT_RUNNER": "ray", + }, } if requirements: runtime_env.update({"pip": requirements}) @@ -266,6 +265,7 @@ def warm_up_function(): parser.add_argument( "--num_parts", default=None, help="Number of parts to generate (defaults to 1 part per GB)", type=int ) + parser.add_argument("--questions", type=str, default=None, help="Comma-separated list of questions to run") parser.add_argument("--skip_questions", type=str, default=None, help="Comma-separated list of questions to skip") parser.add_argument("--output_csv", default=None, type=str, help="Location to output CSV file") parser.add_argument( @@ -310,9 +310,19 @@ def warm_up_function(): else: warmup_environment(args.requirements, parquet_folder) + if args.skip_questions is not None: + if args.questions is not None: + raise ValueError("Cannot specify both --questions and --skip_questions") + skip_questions = {int(s) for s in args.skip_questions.split(",")} + questions = [q for q in range(1, MetricsBuilder.NUM_TPCH_QUESTIONS + 1) if q not in skip_questions] + elif args.questions is not None: + questions = sorted(set(int(s) for s in args.questions.split(","))) + else: + questions = list(range(1, MetricsBuilder.NUM_TPCH_QUESTIONS + 1)) + run_all_benchmarks( parquet_folder, - skip_questions={int(s) for s in args.skip_questions.split(",")} if args.skip_questions is not None else set(), + questions=questions, csv_output_location=args.output_csv, ray_job_dashboard_url=args.ray_job_dashboard_url, requirements=args.requirements, diff --git a/benchmarking/tpch/answers.py b/benchmarking/tpch/answers.py index 90e8b212c7..3c7577f665 100644 --- a/benchmarking/tpch/answers.py +++ b/benchmarking/tpch/answers.py @@ -106,12 +106,12 @@ def q4(get_df: GetDFFunc) -> DataFrame: (col("O_ORDERDATE") >= datetime.date(1993, 7, 1)) & (col("O_ORDERDATE") < datetime.date(1993, 10, 1)) ) - lineitems = lineitems.where(col("L_COMMITDATE") < 
col("L_RECEIPTDATE")).select(col("L_ORDERKEY")).distinct() + lineitems = lineitems.where(col("L_COMMITDATE") < col("L_RECEIPTDATE")) daft_df = ( - lineitems.join(orders, left_on=col("L_ORDERKEY"), right_on=col("O_ORDERKEY")) + orders.join(lineitems, left_on=col("O_ORDERKEY"), right_on=col("L_ORDERKEY"), how="semi") .groupby(col("O_ORDERPRIORITY")) - .agg(col("L_ORDERKEY").count().alias("order_count")) + .agg(col("O_ORDERKEY").count().alias("order_count")) .sort(col("O_ORDERPRIORITY")) ) return daft_df @@ -660,11 +660,8 @@ def q22(get_df: GetDFFunc) -> DataFrame: res_1.where(col("C_ACCTBAL") > 0).agg(col("C_ACCTBAL").mean().alias("avg_acctbal")).with_column("lit", lit(1)) ) - res_3 = orders.select("O_CUSTKEY") - daft_df = ( - res_1.join(res_3, left_on="C_CUSTKEY", right_on="O_CUSTKEY", how="left") - .where(col("O_CUSTKEY").is_null()) + res_1.join(orders, left_on="C_CUSTKEY", right_on="O_CUSTKEY", how="anti") .with_column("lit", lit(1)) .join(res_2, on="lit") .where(col("C_ACCTBAL") > col("avg_acctbal")) diff --git a/benchmarking/tpch/ray_job_runner.py b/benchmarking/tpch/ray_job_runner.py index e6163e8fe1..42fcfc96cf 100644 --- a/benchmarking/tpch/ray_job_runner.py +++ b/benchmarking/tpch/ray_job_runner.py @@ -57,7 +57,7 @@ def ray_job_params( ) -> dict: return dict( submission_id=f"tpch-q{tpch_qnum}-{str(uuid.uuid4())[:4]}", - entrypoint=f"python {str(entrypoint.relative_to(working_dir))} --parquet-folder {parquet_folder_path} --question-number {tpch_qnum}", + entrypoint=f"python3 {str(entrypoint.relative_to(working_dir))} --parquet-folder {parquet_folder_path} --question-number {tpch_qnum}", runtime_env={ "working_dir": str(working_dir), **runtime_env, From 78a92a2edefd4b4080dbdb1597c988fd61b66243 Mon Sep 17 00:00:00 2001 From: Desmond Cheong Date: Thu, 19 Sep 2024 13:53:48 -0700 Subject: [PATCH 02/35] [PERF] Lazily import heavy modules to speed up import times (#2826) Introduce lazy imports for heavy modules that are not needed as top-level imports. For example, `ray` does not need to be a top level import (it should only be imported when using the ray runner or when specific ray data extension types needed. Another example would be `UnityCatalogTable`, which is a relatively heavy import despite only being needed when using delta lake. Modules to import lazily were determined by the proportion of import time as shown by `importtime-output-wrapper -c 'import daft' --format waterfall --depth 25`. The list of newly lazily imported modules are: - `daft.unity_catalog` - `fsspec` - `numpy` - `pandas` - `PIL.Image` - `pyarrow` - `pyarrow.csv` - `pyarrow.dataset` - `pyarrow.fs` - `pyarrow.json` - `pyarrow.parquet` - `ray` - `ray.data.extensions` - `xml.etree.ElementTree` Uses https://github.com/Eventual-Inc/Daft/pull/2836 in order to defer the import of `pyarrow`. Additionally, we move all type-checking-only module imports into type checking blocks. With these changes, import times go from roughly 0.6-0.7s to ~0.045s (~13-15x faster). 
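For illustration only (not part of this diff): a minimal sketch of how call sites are expected to consume the lazy wrappers added in `daft/dependencies.py`, assuming this patch is applied. The helper `maybe_from_pandas` is hypothetical and only demonstrates the pattern already used in `daft/series.py`.

```python
# Hypothetical usage sketch of the LazyImport wrappers (assumes this patch is applied).
# `pd` below is a daft.lazy_import.LazyImport around pandas, so this import line does
# not import pandas itself.
from daft.dependencies import pd


def maybe_from_pandas(obj):
    # module_available() attempts the real `import pandas` at most once and caches it;
    # if pandas is missing it returns False and the isinstance check is skipped.
    # Attribute access such as pd.DataFrame forwards to the real module on demand.
    if pd.module_available() and isinstance(obj, pd.DataFrame):
        return obj.to_dict("records")
    return obj
```

The first successful import is cached on the wrapper (see `LazyImport._load_module` below), so only the first attribute access pays the import cost.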
--------- Co-authored-by: Sammy Sidhu --- daft/.ruff.toml | 13 ++++ daft/arrow_utils.py | 30 +++++---- daft/daft/__init__.pyi | 14 ++-- daft/dataframe/preview.py | 4 +- daft/datatype.py | 53 +++++++++------ daft/delta_lake/delta_lake_scan.py | 5 +- daft/dependencies.py | 32 ++++++++++ daft/execution/execution_step.py | 11 ++-- daft/execution/native_executor.py | 10 +-- daft/execution/physical_plan.py | 9 ++- daft/execution/rust_physical_plan_shim.py | 3 +- daft/expressions/expressions.py | 6 +- daft/filesystem.py | 36 +++++------ daft/hudi/hudi_scan.py | 5 +- daft/hudi/pyhudi/filegroup.py | 7 +- daft/hudi/pyhudi/table.py | 4 +- daft/hudi/pyhudi/timeline.py | 5 +- daft/hudi/pyhudi/utils.py | 5 +- daft/iceberg/iceberg_scan.py | 16 +++-- .../schema_field_id_mapping_visitor.py | 6 +- daft/internal/gpu.py | 3 +- daft/io/_deltalake.py | 10 ++- daft/io/common.py | 3 +- daft/io/object_store_options.py | 4 +- daft/io/scan.py | 8 ++- daft/lazy_import.py | 62 ++++++++++++++++++ daft/logical/builder.py | 5 +- daft/logical/map_partition_ops.py | 5 +- .../plan_scheduler/physical_plan_scheduler.py | 14 ++-- daft/runners/partitioning.py | 10 +-- daft/runners/progress_bar.py | 5 +- daft/runners/pyrunner.py | 12 ++-- daft/runners/ray_runner.py | 25 +++----- daft/runners/runner.py | 14 ++-- daft/runners/runner_io.py | 4 +- daft/series.py | 25 ++------ daft/sql/sql_connection.py | 4 +- daft/sql/sql_scan.py | 13 ++-- daft/table/micropartition.py | 11 +--- daft/table/schema_inference.py | 7 +- daft/table/table.py | 17 +---- daft/table/table_io.py | 43 +++++++------ daft/udf.py | 23 ++----- daft/unity_catalog/__init__.py | 4 +- daft/utils.py | 8 ++- daft/viz/dataframe_display.py | 16 ++--- daft/viz/html_viz_hooks.py | 64 ++++++++----------- tests/dataframe/test_logical_type.py | 3 +- tests/series/test_embedding.py | 4 +- tests/series/test_image.py | 4 +- tests/series/test_tensor.py | 3 +- tests/table/table_io/test_csv.py | 5 +- 52 files changed, 401 insertions(+), 311 deletions(-) create mode 100644 daft/.ruff.toml create mode 100644 daft/dependencies.py create mode 100644 daft/lazy_import.py diff --git a/daft/.ruff.toml b/daft/.ruff.toml new file mode 100644 index 0000000000..1acbd4af4b --- /dev/null +++ b/daft/.ruff.toml @@ -0,0 +1,13 @@ +extend = "../.ruff.toml" + +[lint] +extend-select = [ + "TID253", # banned-module-level-imports, derived from flake8-tidy-imports + "TCH" # flake8-type-checking +] + +[lint.flake8-tidy-imports] +# Ban certain modules from being imported at module level, instead requiring +# that they're imported lazily (e.g., within a function definition, +# with daft.lazy_import.LazyImport, or with TYPE_CHECKING). +banned-module-level-imports = ["daft.unity_catalog", "fsspec", "numpy", "pandas", "PIL", "pyarrow", "ray", "xml"] diff --git a/daft/arrow_utils.py b/daft/arrow_utils.py index 8197c86804..1f211653f2 100644 --- a/daft/arrow_utils.py +++ b/daft/arrow_utils.py @@ -2,7 +2,7 @@ import sys -import pyarrow as pa +from daft.dependencies import pa def ensure_array(arr: pa.Array) -> pa.Array: @@ -34,13 +34,21 @@ class _FixEmptyStructArrays: Python layer before going through ffi into Rust. 
""" - EMPTY_STRUCT_TYPE = pa.struct([]) - SINGLE_FIELD_STRUCT_TYPE = pa.struct({"": pa.null()}) - SINGLE_FIELD_STRUCT_VALUE = {"": None} + @staticmethod + def get_empty_struct_type(): + return pa.struct([]) + + @staticmethod + def get_single_field_struct_type(): + return pa.struct({"": pa.null()}) + + @staticmethod + def get_single_field_struct_value(): + return {"": None} def ensure_table(table: pa.Table) -> pa.Table: empty_struct_fields = [ - (i, f) for (i, f) in enumerate(table.schema) if f.type == _FixEmptyStructArrays.EMPTY_STRUCT_TYPE + (i, f) for (i, f) in enumerate(table.schema) if f.type == _FixEmptyStructArrays.get_empty_struct_type() ] if not empty_struct_fields: return table @@ -49,19 +57,19 @@ def ensure_table(table: pa.Table) -> pa.Table: return table def ensure_chunked_array(arr: pa.ChunkedArray) -> pa.ChunkedArray: - if arr.type != _FixEmptyStructArrays.EMPTY_STRUCT_TYPE: + if arr.type != _FixEmptyStructArrays.get_empty_struct_type(): return arr return pa.chunked_array([_FixEmptyStructArrays.ensure_array(chunk) for chunk in arr.chunks]) def ensure_array(arr: pa.Array) -> pa.Array: """Recursively converts empty struct arrays to single-field struct arrays""" - if arr.type == _FixEmptyStructArrays.EMPTY_STRUCT_TYPE: + if arr.type == _FixEmptyStructArrays.get_empty_struct_type(): return pa.array( [ - _FixEmptyStructArrays.SINGLE_FIELD_STRUCT_VALUE if valid.as_py() else None + _FixEmptyStructArrays.get_single_field_struct_value() if valid.as_py() else None for valid in arr.is_valid() ], - type=_FixEmptyStructArrays.SINGLE_FIELD_STRUCT_TYPE, + type=_FixEmptyStructArrays.get_single_field_struct_type(), ) elif isinstance(arr, pa.StructArray): @@ -77,10 +85,10 @@ def ensure_array(arr: pa.Array) -> pa.Array: def remove_empty_struct_placeholders(arr: pa.Array): """Recursively removes the empty struct placeholders placed by _FixEmptyStructArrays.ensure_array""" - if arr.type == _FixEmptyStructArrays.SINGLE_FIELD_STRUCT_TYPE: + if arr.type == _FixEmptyStructArrays.get_single_field_struct_type(): return pa.array( [{} if valid.as_py() else None for valid in arr.is_valid()], - type=_FixEmptyStructArrays.EMPTY_STRUCT_TYPE, + type=_FixEmptyStructArrays.get_empty_struct_type(), ) elif isinstance(arr, pa.StructArray): diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index ff090f642e..98f91fadb7 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -3,8 +3,6 @@ import datetime from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Iterator -import pyarrow - from daft.dataframe.display import MermaidOptions from daft.execution import physical_plan from daft.io.scan import ScanOperator @@ -994,7 +992,7 @@ class PyDataType: def tensor(dtype: PyDataType, shape: tuple[int, ...] | None = None) -> PyDataType: ... @staticmethod def python() -> PyDataType: ... - def to_arrow(self, cast_tensor_type_for_ray: builtins.bool | None = None) -> pyarrow.DataType: ... + def to_arrow(self, cast_tensor_type_for_ray: builtins.bool | None = None) -> pa.DataType: ... def is_numeric(self) -> builtins.bool: ... def is_image(self) -> builtins.bool: ... def is_fixed_shape_image(self) -> builtins.bool: ... @@ -1271,11 +1269,11 @@ class PyCatalog: class PySeries: @staticmethod - def from_arrow(name: str, pyarrow_array: pyarrow.Array) -> PySeries: ... + def from_arrow(name: str, pyarrow_array: pa.Array) -> PySeries: ... @staticmethod def from_pylist(name: str, pylist: list[Any], pyobj: str) -> PySeries: ... def to_pylist(self) -> list[Any]: ... 
- def to_arrow(self) -> pyarrow.Array: ... + def to_arrow(self) -> pa.Array: ... def __abs__(self) -> PySeries: ... def __add__(self, other: PySeries) -> PySeries: ... def __sub__(self, other: PySeries) -> PySeries: ... @@ -1456,10 +1454,10 @@ class PyTable: def concat(tables: list[PyTable]) -> PyTable: ... def slice(self, start: int, end: int) -> PyTable: ... @staticmethod - def from_arrow_record_batches(record_batches: list[pyarrow.RecordBatch], schema: PySchema) -> PyTable: ... + def from_arrow_record_batches(record_batches: list[pa.RecordBatch], schema: PySchema) -> PyTable: ... @staticmethod def from_pylist_series(dict: dict[str, PySeries]) -> PyTable: ... - def to_arrow_record_batch(self) -> pyarrow.RecordBatch: ... + def to_arrow_record_batch(self) -> pa.RecordBatch: ... @staticmethod def empty(schema: PySchema | None = None) -> PyTable: ... @@ -1476,7 +1474,7 @@ class PyMicroPartition: @staticmethod def from_tables(tables: list[PyTable]) -> PyMicroPartition: ... @staticmethod - def from_arrow_record_batches(record_batches: list[pyarrow.RecordBatch], schema: PySchema) -> PyMicroPartition: ... + def from_arrow_record_batches(record_batches: list[pa.RecordBatch], schema: PySchema) -> PyMicroPartition: ... @staticmethod def concat(tables: list[PyMicroPartition]) -> PyMicroPartition: ... def slice(self, start: int, end: int) -> PyMicroPartition: ... diff --git a/daft/dataframe/preview.py b/daft/dataframe/preview.py index ec8129c066..c6e04e0045 100644 --- a/daft/dataframe/preview.py +++ b/daft/dataframe/preview.py @@ -1,8 +1,10 @@ from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING -from daft.table import MicroPartition +if TYPE_CHECKING: + from daft.table import MicroPartition @dataclass(frozen=True) diff --git a/daft/datatype.py b/daft/datatype.py index ac6b6dffc7..6d2eb1fe6b 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -1,14 +1,14 @@ from __future__ import annotations -import builtins from typing import TYPE_CHECKING -import pyarrow as pa - from daft.context import get_context from daft.daft import ImageMode, PyDataType, PyTimeUnit +from daft.dependencies import pa if TYPE_CHECKING: + import builtins + import numpy as np @@ -501,25 +501,40 @@ def __hash__(self) -> int: return self._dtype.__hash__() -class DaftExtension(pa.ExtensionType): - def __init__(self, dtype, metadata=b""): - # attributes need to be set first before calling - # super init (as that calls serialize) - self._metadata = metadata - super().__init__(dtype, "daft.super_extension") +_EXT_TYPE_REGISTERED = False +_STATIC_DAFT_EXTENSION = None - def __reduce__(self): - return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__()) - def __arrow_ext_serialize__(self): - return self._metadata +def _ensure_registered_super_ext_type(): + global _EXT_TYPE_REGISTERED + global _STATIC_DAFT_EXTENSION + if not _EXT_TYPE_REGISTERED: - @classmethod - def __arrow_ext_deserialize__(cls, storage_type, serialized): - return cls(storage_type, serialized) + class DaftExtension(pa.ExtensionType): + def __init__(self, dtype, metadata=b""): + # attributes need to be set first before calling + # super init (as that calls serialize) + self._metadata = metadata + super().__init__(dtype, "daft.super_extension") + + def __reduce__(self): + return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__()) + + def __arrow_ext_serialize__(self): + return self._metadata + + @classmethod + def __arrow_ext_deserialize__(cls, 
storage_type, serialized): + return cls(storage_type, serialized) + + _STATIC_DAFT_EXTENSION = DaftExtension + pa.register_extension_type(DaftExtension(pa.null())) + import atexit + atexit.register(lambda: pa.unregister_extension_type("daft.super_extension")) + _EXT_TYPE_REGISTERED = True -pa.register_extension_type(DaftExtension(pa.null())) -import atexit -atexit.register(lambda: pa.unregister_extension_type("daft.super_extension")) +def get_super_ext_type(): + _ensure_registered_super_ext_type() + return _STATIC_DAFT_EXTENSION diff --git a/daft/delta_lake/delta_lake_scan.py b/daft/delta_lake/delta_lake_scan.py index 0a52b52d30..eb6973f24d 100644 --- a/daft/delta_lake/delta_lake_scan.py +++ b/daft/delta_lake/delta_lake_scan.py @@ -2,7 +2,7 @@ import logging import os -from collections.abc import Iterator +from typing import TYPE_CHECKING from deltalake.table import DeltaTable @@ -20,6 +20,9 @@ from daft.io.scan import PartitionField, ScanOperator from daft.logical.schema import Schema +if TYPE_CHECKING: + from collections.abc import Iterator + logger = logging.getLogger(__name__) diff --git a/daft/dependencies.py b/daft/dependencies.py new file mode 100644 index 0000000000..e4c692b4bf --- /dev/null +++ b/daft/dependencies.py @@ -0,0 +1,32 @@ +from typing import TYPE_CHECKING + +from daft.lazy_import import LazyImport + +if TYPE_CHECKING: + import xml.etree.ElementTree as ET + + import fsspec + import numpy as np + import pandas as pd + import PIL.Image as pil_image + import pyarrow as pa + import pyarrow.csv as pacsv + import pyarrow.dataset as pads + import pyarrow.fs as pafs + import pyarrow.json as pajson + import pyarrow.parquet as pq +else: + ET = LazyImport("xml.etree.ElementTree") + + fsspec = LazyImport("fsspec") + np = LazyImport("numpy") + pd = LazyImport("pandas") + pil_image = LazyImport("PIL.Image") + pa = LazyImport("pyarrow") + pacsv = LazyImport("pyarrow.csv") + pads = LazyImport("pyarrow.dataset") + pafs = LazyImport("pyarrow.fs") + pajson = LazyImport("pyarrow.json") + pq = LazyImport("pyarrow.parquet") + +unity_catalog = LazyImport("daft.unity_catalog") diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 9c1241673e..95b57e9d10 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -1,15 +1,12 @@ from __future__ import annotations import itertools -import pathlib from dataclasses import dataclass, field from typing import TYPE_CHECKING, Generic, Protocol from daft.context import get_context -from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest, ScanTask +from daft.daft import ResourceRequest from daft.expressions import Expression, ExpressionsProjection, col -from daft.logical.map_partition_ops import MapPartitionOp -from daft.logical.schema import Schema from daft.runners.partitioning import ( Boundaries, MaterializedResult, @@ -20,9 +17,15 @@ from daft.table import MicroPartition, table_io if TYPE_CHECKING: + import pathlib + from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties + from daft.daft import FileFormat, IOConfig, JoinType, ScanTask + from daft.logical.map_partition_ops import MapPartitionOp + from daft.logical.schema import Schema + ID_GEN = itertools.count() diff --git a/daft/execution/native_executor.py b/daft/execution/native_executor.py index 2fe5cf9eba..6aa95a09f8 100644 --- a/daft/execution/native_executor.py +++ b/daft/execution/native_executor.py @@ -6,14 +6,14 @@ NativeExecutor as _NativeExecutor, 
) from daft.daft import PyDaftExecutionConfig -from daft.logical.builder import LogicalPlanBuilder -from daft.runners.partitioning import ( - MaterializedResult, - PartitionT, -) from daft.table import MicroPartition if TYPE_CHECKING: + from daft.logical.builder import LogicalPlanBuilder + from daft.runners.partitioning import ( + MaterializedResult, + PartitionT, + ) from daft.runners.pyrunner import PyMaterializedResult diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 16797e3b2b..34da186731 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -17,7 +17,6 @@ import itertools import logging import math -import pathlib from collections import deque from typing import ( TYPE_CHECKING, @@ -30,7 +29,7 @@ ) from daft.context import get_context -from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest +from daft.daft import ResourceRequest from daft.execution import execution_step from daft.execution.execution_step import ( Instruction, @@ -41,7 +40,6 @@ SingleOutputPartitionTask, ) from daft.expressions import ExpressionsProjection -from daft.logical.schema import Schema from daft.runners.partitioning import ( MaterializedResult, PartitionT, @@ -53,9 +51,14 @@ T = TypeVar("T") if TYPE_CHECKING: + import pathlib + from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties + from daft.daft import FileFormat, IOConfig, JoinType + from daft.logical.schema import Schema + # A PhysicalPlan that is still being built - may yield both PartitionTaskBuilders and PartitionTasks. InProgressPhysicalPlan = Iterator[Union[None, PartitionTask[PartitionT], PartitionTaskBuilder[PartitionT]]] diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 833887893c..3c85ad4149 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -17,12 +17,13 @@ from daft.logical.map_partition_ops import MapPartitionOp from daft.logical.schema import Schema from daft.runners.partitioning import PartitionT -from daft.table import MicroPartition if TYPE_CHECKING: from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties + from daft.table import MicroPartition + def scan_with_tasks( scan_tasks: list[ScanTask], diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index f322d3b4e7..ede945f586 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1,6 +1,5 @@ from __future__ import annotations -import builtins import math import os from datetime import date, datetime, time @@ -16,8 +15,6 @@ overload, ) -import pyarrow as pa - import daft.daft as native from daft import context from daft.daft import CountMode, ImageFormat, ImageMode, ResourceRequest, bind_stateful_udfs @@ -38,11 +35,14 @@ from daft.daft import url_download as _url_download from daft.daft import utf8_count_matches as _utf8_count_matches from daft.datatype import DataType, TimeUnit +from daft.dependencies import pa from daft.expressions.testing import expr_structurally_equal from daft.logical.schema import Field, Schema from daft.series import Series, item_to_series if TYPE_CHECKING: + import builtins + from daft.io import IOConfig from daft.udf import PartialStatefulUDF, PartialStatelessUDF # This allows Sphinx to correctly work against our "namespaced" accessor functions by overriding @property to diff --git 
a/daft/filesystem.py b/daft/filesystem.py index 06a8ee60e9..93ed1971fd 100644 --- a/daft/filesystem.py +++ b/daft/filesystem.py @@ -8,20 +8,16 @@ import urllib.parse from typing import Any, Literal -import fsspec -from fsspec.registry import get_filesystem_class -from pyarrow.fs import FileSystem, LocalFileSystem, S3FileSystem -from pyarrow.fs import _resolve_filesystem_and_path as pafs_resolve_filesystem_and_path - from daft.daft import FileFormat, FileInfos, IOConfig, io_glob +from daft.dependencies import fsspec, pafs from daft.table import MicroPartition logger = logging.getLogger(__name__) -_CACHED_FSES: dict[tuple[str, IOConfig | None], FileSystem] = {} +_CACHED_FSES: dict[tuple[str, IOConfig | None], pafs.FileSystem] = {} -def _get_fs_from_cache(protocol: str, io_config: IOConfig | None) -> FileSystem | None: +def _get_fs_from_cache(protocol: str, io_config: IOConfig | None) -> pafs.FileSystem | None: """ Get an instantiated pyarrow filesystem from the cache based on the URI protocol. @@ -32,7 +28,7 @@ def _get_fs_from_cache(protocol: str, io_config: IOConfig | None) -> FileSystem return _CACHED_FSES.get((protocol, io_config)) -def _put_fs_in_cache(protocol: str, fs: FileSystem, io_config: IOConfig | None) -> None: +def _put_fs_in_cache(protocol: str, fs: pafs.FileSystem, io_config: IOConfig | None) -> None: """Put pyarrow filesystem in cache under provided protocol.""" global _CACHED_FSES @@ -115,7 +111,7 @@ def canonicalize_protocol(protocol: str) -> str: def _resolve_paths_and_filesystem( paths: str | pathlib.Path | list[str], io_config: IOConfig | None = None, -) -> tuple[list[str], FileSystem]: +) -> tuple[list[str], pafs.FileSystem]: """ Resolves and normalizes all provided paths, infers a filesystem from the paths, and ensures that all paths use the same filesystem. @@ -166,7 +162,7 @@ def _resolve_paths_and_filesystem( # filesystem should be a non-None pyarrow FileSystem at this point, either # user-provided, taken from the cache, or inferred from the first path. - assert resolved_filesystem is not None and isinstance(resolved_filesystem, FileSystem) + assert resolved_filesystem is not None and isinstance(resolved_filesystem, pafs.FileSystem) # Resolve all other paths and validate with the user-provided/cached/inferred filesystem. 
resolved_paths = [resolved_path] @@ -177,7 +173,7 @@ def _resolve_paths_and_filesystem( return resolved_paths, resolved_filesystem -def _validate_filesystem(path: str, fs: FileSystem, io_config: IOConfig | None) -> str: +def _validate_filesystem(path: str, fs: pafs.FileSystem, io_config: IOConfig | None) -> str: resolved_path, inferred_fs = _infer_filesystem(path, io_config) if not isinstance(fs, type(inferred_fs)): raise RuntimeError( @@ -189,7 +185,7 @@ def _validate_filesystem(path: str, fs: FileSystem, io_config: IOConfig | None) def _infer_filesystem( path: str, io_config: IOConfig | None, -) -> tuple[str, FileSystem]: +) -> tuple[str, pafs.FileSystem]: """ Resolves and normalizes the provided path, infers a filesystem from the path, and ensures that the inferred filesystem is compatible with the passed @@ -229,7 +225,7 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None): except ImportError: pass # Config does not exist in pyarrow 7.0.0 - resolved_filesystem = S3FileSystem(**translated_kwargs) + resolved_filesystem = pafs.S3FileSystem(**translated_kwargs) resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(path)) return resolved_path, resolved_filesystem @@ -237,7 +233,7 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None): # Local ### elif protocol == "file": - resolved_filesystem = LocalFileSystem() + resolved_filesystem = pafs.LocalFileSystem() resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(path)) return resolved_path, resolved_filesystem @@ -267,9 +263,9 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None): # HTTP: Use FSSpec as a fallback ### elif protocol in {"http", "https"}: - fsspec_fs_cls = get_filesystem_class(protocol) + fsspec_fs_cls = fsspec.get_filesystem_class(protocol) fsspec_fs = fsspec_fs_cls() - resolved_filesystem, resolved_path = pafs_resolve_filesystem_and_path(path, fsspec_fs) + resolved_filesystem, resolved_path = pafs._resolve_filesystem_and_path(path, fsspec_fs) resolved_path = resolved_filesystem.normalize_path(resolved_path) return resolved_path, resolved_filesystem @@ -277,7 +273,7 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None): # Azure: Use FSSpec as a fallback ### elif protocol in {"az", "abfs", "abfss"}: - fsspec_fs_cls = get_filesystem_class(protocol) + fsspec_fs_cls = fsspec.get_filesystem_class(protocol) if io_config is not None: # TODO: look into support for other AzureConfig parameters @@ -292,7 +288,7 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None): ) else: fsspec_fs = fsspec_fs_cls() - resolved_filesystem, resolved_path = pafs_resolve_filesystem_and_path(path, fsspec_fs) + resolved_filesystem, resolved_path = pafs._resolve_filesystem_and_path(path, fsspec_fs) resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(resolved_path)) return resolved_path, resolved_filesystem @@ -348,12 +344,12 @@ def glob_path_with_stats( ### -def join_path(fs: FileSystem, base_path: str, *sub_paths: str) -> str: +def join_path(fs: pafs.FileSystem, base_path: str, *sub_paths: str) -> str: """ Join a base path with sub-paths using the appropriate path separator for the given filesystem. 
""" - if isinstance(fs, LocalFileSystem): + if isinstance(fs, pafs.LocalFileSystem): return os.path.join(base_path, *sub_paths) else: return f"{base_path.rstrip('/')}/{'/'.join(sub_paths)}" diff --git a/daft/hudi/hudi_scan.py b/daft/hudi/hudi_scan.py index 02821df1ae..3d87f9716a 100644 --- a/daft/hudi/hudi_scan.py +++ b/daft/hudi/hudi_scan.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from collections.abc import Iterator +from typing import TYPE_CHECKING import daft from daft.daft import ( @@ -16,6 +16,9 @@ from daft.io.scan import PartitionField, ScanOperator from daft.logical.schema import Schema +if TYPE_CHECKING: + from collections.abc import Iterator + logger = logging.getLogger(__name__) diff --git a/daft/hudi/pyhudi/filegroup.py b/daft/hudi/pyhudi/filegroup.py index aa347e4681..2e3552bb91 100644 --- a/daft/hudi/pyhudi/filegroup.py +++ b/daft/hudi/pyhudi/filegroup.py @@ -1,11 +1,14 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import TYPE_CHECKING -import pyarrow as pa from sortedcontainers import SortedDict -from daft.hudi.pyhudi.utils import FsFileMetadata +from daft.dependencies import pa + +if TYPE_CHECKING: + from daft.hudi.pyhudi.utils import FsFileMetadata @dataclass(init=False) diff --git a/daft/hudi/pyhudi/table.py b/daft/hudi/pyhudi/table.py index a37d77d2d7..7597072f9a 100644 --- a/daft/hudi/pyhudi/table.py +++ b/daft/hudi/pyhudi/table.py @@ -3,9 +3,7 @@ from collections import defaultdict from dataclasses import dataclass -import pyarrow as pa -import pyarrow.fs as pafs - +from daft.dependencies import pa, pafs from daft.filesystem import join_path from daft.hudi.pyhudi.filegroup import BaseFile, FileGroup, FileSlice from daft.hudi.pyhudi.timeline import Timeline diff --git a/daft/hudi/pyhudi/timeline.py b/daft/hudi/pyhudi/timeline.py index f16a3dadc7..cf7714291d 100644 --- a/daft/hudi/pyhudi/timeline.py +++ b/daft/hudi/pyhudi/timeline.py @@ -5,10 +5,7 @@ from dataclasses import dataclass from enum import Enum -import pyarrow as pa -import pyarrow.fs as pafs -import pyarrow.parquet as pq - +from daft.dependencies import pa, pafs, pq from daft.filesystem import join_path diff --git a/daft/hudi/pyhudi/utils.py b/daft/hudi/pyhudi/utils.py index 9d53938e54..94da0ec371 100644 --- a/daft/hudi/pyhudi/utils.py +++ b/daft/hudi/pyhudi/utils.py @@ -3,10 +3,7 @@ import os from dataclasses import dataclass -import pyarrow as pa -import pyarrow.fs as pafs -import pyarrow.parquet as pq - +from daft.dependencies import pa, pafs, pq from daft.filesystem import join_path diff --git a/daft/iceberg/iceberg_scan.py b/daft/iceberg/iceberg_scan.py index 58241ca217..ee60eda888 100644 --- a/daft/iceberg/iceberg_scan.py +++ b/daft/iceberg/iceberg_scan.py @@ -2,16 +2,11 @@ import logging import warnings -from collections.abc import Iterator +from typing import TYPE_CHECKING -import pyarrow as pa from pyiceberg.io.pyarrow import schema_to_pyarrow -from pyiceberg.partitioning import PartitionField as IcebergPartitionField -from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.schema import visit -from pyiceberg.table import Table -from pyiceberg.typedef import Record import daft from daft.daft import ( @@ -23,10 +18,19 @@ StorageConfig, ) from daft.datatype import DataType +from daft.dependencies import pa from daft.iceberg.schema_field_id_mapping_visitor import SchemaFieldIdMappingVisitor from daft.io.scan import PartitionField, ScanOperator, 
make_partition_field from daft.logical.schema import Field, Schema +if TYPE_CHECKING: + from collections.abc import Iterator + + from pyiceberg.partitioning import PartitionField as IcebergPartitionField + from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec + from pyiceberg.table import Table + from pyiceberg.typedef import Record + logger = logging.getLogger(__name__) diff --git a/daft/iceberg/schema_field_id_mapping_visitor.py b/daft/iceberg/schema_field_id_mapping_visitor.py index be631b49fc..d48b85510b 100644 --- a/daft/iceberg/schema_field_id_mapping_visitor.py +++ b/daft/iceberg/schema_field_id_mapping_visitor.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import Dict +from typing import TYPE_CHECKING, Dict from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.schema import Schema, SchemaVisitor -from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType from daft import DataType from daft.daft import PyField +if TYPE_CHECKING: + from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType + FieldIdMapping = Dict[int, PyField] diff --git a/daft/internal/gpu.py b/daft/internal/gpu.py index 01463b4958..fb0a25565a 100644 --- a/daft/internal/gpu.py +++ b/daft/internal/gpu.py @@ -1,7 +1,8 @@ from __future__ import annotations import subprocess -import xml.etree.ElementTree as ET + +from daft.dependencies import ET def cuda_device_count(): diff --git a/daft/io/_deltalake.py b/daft/io/_deltalake.py index 765f905cdd..c4530bcd98 100644 --- a/daft/io/_deltalake.py +++ b/daft/io/_deltalake.py @@ -1,19 +1,17 @@ # isort: dont-add-import: from __future__ import annotations -from typing import Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from daft import context from daft.api_annotations import PublicAPI from daft.daft import IOConfig, NativeStorageConfig, ScanOperatorHandle, StorageConfig from daft.dataframe import DataFrame +from daft.dependencies import unity_catalog from daft.io.catalog import DataCatalogTable from daft.logical.builder import LogicalPlanBuilder -_UNITY_CATALOG_AVAILABLE = True -try: +if TYPE_CHECKING: from daft.unity_catalog import UnityCatalogTable -except ImportError: - _UNITY_CATALOG_AVAILABLE = False @PublicAPI @@ -60,7 +58,7 @@ def read_deltalake( table_uri = table elif isinstance(table, DataCatalogTable): table_uri = table.table_uri(io_config) - elif _UNITY_CATALOG_AVAILABLE and isinstance(table, UnityCatalogTable): + elif unity_catalog.module_available() and isinstance(table, unity_catalog.UnityCatalogTable): table_uri = table.table_uri # Override the storage_config with the one provided by Unity catalog diff --git a/daft/io/common.py b/daft/io/common.py index 363f4812ec..d4b34291c4 100644 --- a/daft/io/common.py +++ b/daft/io/common.py @@ -3,12 +3,11 @@ from typing import TYPE_CHECKING from daft.daft import FileFormatConfig, ScanOperatorHandle, StorageConfig -from daft.datatype import DataType from daft.logical.builder import LogicalPlanBuilder from daft.logical.schema import Schema if TYPE_CHECKING: - pass + from daft.datatype import DataType def _get_schema_from_dict(fields: dict[str, DataType]) -> Schema: diff --git a/daft/io/object_store_options.py b/daft/io/object_store_options.py index 46e4fe3b13..65855455b7 100644 --- a/daft/io/object_store_options.py +++ b/daft/io/object_store_options.py @@ -1,8 +1,10 @@ from __future__ import annotations +from typing import TYPE_CHECKING from urllib.parse import urlparse -from 
daft.daft import AzureConfig, GCSConfig, IOConfig, S3Config +if TYPE_CHECKING: + from daft.daft import AzureConfig, GCSConfig, IOConfig, S3Config def io_config_to_storage_options(io_config: IOConfig, table_uri: str) -> dict[str, str] | None: diff --git a/daft/io/scan.py b/daft/io/scan.py index 3cb074d25d..132bd63d64 100644 --- a/daft/io/scan.py +++ b/daft/io/scan.py @@ -1,10 +1,14 @@ from __future__ import annotations import abc -from collections.abc import Iterator +from typing import TYPE_CHECKING from daft.daft import PartitionField, PartitionTransform, Pushdowns, ScanTask -from daft.logical.schema import Field, Schema + +if TYPE_CHECKING: + from collections.abc import Iterator + + from daft.logical.schema import Field, Schema def make_partition_field( diff --git a/daft/lazy_import.py b/daft/lazy_import.py new file mode 100644 index 0000000000..cfeac7ac48 --- /dev/null +++ b/daft/lazy_import.py @@ -0,0 +1,62 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Borrowed and modified from [`skypilot`](https://github.com/skypilot-org/skypilot/blob/master/sky/adaptors/common.py). + +import importlib +from typing import Any + + +class LazyImport: + """Lazy importer + There are certain large imports (e.g. Ray, daft.unity_catalog.UnityCatalogTable, etc.) that + do not need to be top-level imports. For example, Ray should only be imported when the ray + runner is used, or specific ray data extension types are needed. We can lazily import these + modules as needed. + """ + + def __init__(self, module_name: str): + self._module_name = module_name + self._module = None + + def module_available(self): + return self._load_module() is not None + + def _load_module(self): + if self._module is None: + try: + self._module = importlib.import_module(self._module_name) + except ImportError: + pass + return self._module + + def __getattr__(self, name: str) -> Any: + # Given a lazy module and an attribute to get, we have the following possibilities: + # 1. The attribute is the lazy object's attribute. + # 2. The attribute is an attribute of the module. + # 3. The module does not exist. + # 4. The attribute is a submodule. + # 5. The attribute does not exist. + try: + if name in self.__dict__: + return self.__dict__[name] + return getattr(self._load_module(), name) + except AttributeError as e: + if self._module is None: + raise e + # Dynamically create a new LazyImport instance for the submodule. 
+ submodule_name = f"{self._module_name}.{name}" + lazy_submodule = LazyImport(submodule_name) + if lazy_submodule.module_available(): + setattr(self, name, lazy_submodule) + return lazy_submodule + raise e diff --git a/daft/logical/builder.py b/daft/logical/builder.py index b2717df2f6..218e521d6d 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -1,7 +1,6 @@ from __future__ import annotations import functools -import pathlib from typing import TYPE_CHECKING, Callable from daft.context import get_context @@ -17,15 +16,17 @@ from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder from daft.expressions import Expression, col from daft.logical.schema import Schema -from daft.runners.partitioning import PartitionCacheEntry if TYPE_CHECKING: + import pathlib + from pyiceberg.table import Table as IcebergTable from daft.plan_scheduler.physical_plan_scheduler import ( AdaptivePhysicalPlanScheduler, PhysicalPlanScheduler, ) + from daft.runners.partitioning import PartitionCacheEntry def _apply_daft_planning_config_to_initializer(classmethod_func: Callable[..., LogicalPlanBuilder]): diff --git a/daft/logical/map_partition_ops.py b/daft/logical/map_partition_ops.py index 46819a27bc..b77ca97352 100644 --- a/daft/logical/map_partition_ops.py +++ b/daft/logical/map_partition_ops.py @@ -1,10 +1,13 @@ from __future__ import annotations from abc import abstractmethod +from typing import TYPE_CHECKING from daft.expressions import ExpressionsProjection from daft.logical.schema import Schema -from daft.table import MicroPartition + +if TYPE_CHECKING: + from daft.table import MicroPartition class MapPartitionOp: diff --git a/daft/plan_scheduler/physical_plan_scheduler.py b/daft/plan_scheduler/physical_plan_scheduler.py index 24470d63f1..43bab81dbe 100644 --- a/daft/plan_scheduler/physical_plan_scheduler.py +++ b/daft/plan_scheduler/physical_plan_scheduler.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from daft.daft import ( AdaptivePhysicalPlanScheduler as _AdaptivePhysicalPlanScheduler, ) @@ -8,11 +10,13 @@ PyDaftExecutionConfig, ) from daft.execution import physical_plan -from daft.logical.builder import LogicalPlanBuilder -from daft.runners.partitioning import ( - PartitionCacheEntry, - PartitionT, -) + +if TYPE_CHECKING: + from daft.logical.builder import LogicalPlanBuilder + from daft.runners.partitioning import ( + PartitionCacheEntry, + PartitionT, + ) class PhysicalPlanScheduler: diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index 6f4269d27c..74903d04e6 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -7,15 +7,15 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from uuid import uuid4 -import pyarrow as pa - from daft.datatype import TimeUnit -from daft.expressions.expressions import Expression -from daft.logical.schema import Schema -from daft.table import MicroPartition if TYPE_CHECKING: import pandas as pd + import pyarrow as pa + + from daft.expressions.expressions import Expression + from daft.logical.schema import Schema + from daft.table import MicroPartition PartID = int diff --git a/daft/runners/progress_bar.py b/daft/runners/progress_bar.py index 036138ff4c..e5f7172317 100644 --- a/daft/runners/progress_bar.py +++ b/daft/runners/progress_bar.py @@ -2,9 +2,10 @@ import os import time -from typing import Any +from typing import TYPE_CHECKING, Any -from daft.execution.execution_step import PartitionTask +if TYPE_CHECKING: + from daft.execution.execution_step 
import PartitionTask class ProgressBar: diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index c2f35e648e..e80acb03cb 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -6,17 +6,14 @@ import uuid from concurrent import futures from dataclasses import dataclass -from typing import Iterator +from typing import TYPE_CHECKING, Iterator from daft.context import get_context from daft.daft import FileFormatConfig, FileInfos, IOConfig, ResourceRequest, SystemInfo -from daft.execution import physical_plan -from daft.execution.execution_step import Instruction, PartitionTask from daft.execution.native_executor import NativeExecutor from daft.expressions import ExpressionsProjection from daft.filesystem import glob_path_with_stats from daft.internal.gpu import cuda_device_count -from daft.logical.builder import LogicalPlanBuilder from daft.runners import runner_io from daft.runners.partitioning import ( MaterializedResult, @@ -30,7 +27,12 @@ from daft.runners.progress_bar import ProgressBar from daft.runners.runner import Runner from daft.table import MicroPartition -from daft.udf import UserProvidedPythonFunction + +if TYPE_CHECKING: + from daft.execution import physical_plan + from daft.execution.execution_step import Instruction, PartitionTask + from daft.logical.builder import LogicalPlanBuilder + from daft.udf import UserProvidedPythonFunction logger = logging.getLogger(__name__) diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 6f6611c067..84458815cb 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -9,14 +9,15 @@ from queue import Full, Queue from typing import TYPE_CHECKING, Any, Generator, Iterable, Iterator -import pyarrow as pa +# The ray runner is not a top-level module, so we don't need to lazily import pyarrow to minimize +# import times. If this changes, we first need to make the daft.lazy_import.LazyImport class +# serializable before importing pa from daft.dependencies. 
+import pyarrow as pa # noqa: TID253 from daft.arrow_utils import ensure_array from daft.context import execution_config_ctx, get_context from daft.daft import PyTable as _PyTable -from daft.expressions import ExpressionsProjection -from daft.logical.builder import LogicalPlanBuilder -from daft.plan_scheduler import PhysicalPlanScheduler +from daft.dependencies import np from daft.runners.progress_bar import ProgressBar from daft.series import Series, item_to_series from daft.table import Table @@ -68,6 +69,10 @@ from ray.data.block import Block as RayDatasetBlock from ray.data.dataset import Dataset as RayDataset + from daft.expressions import ExpressionsProjection + from daft.logical.builder import LogicalPlanBuilder + from daft.plan_scheduler import PhysicalPlanScheduler + _RAY_FROM_ARROW_REFS_AVAILABLE = True try: from ray.data import from_arrow_refs @@ -108,18 +113,6 @@ except ImportError: _RAY_DATA_EXTENSIONS_AVAILABLE = False -_NUMPY_AVAILABLE = True -try: - import numpy as np -except ImportError: - _NUMPY_AVAILABLE = False - -_PANDAS_AVAILABLE = True -try: - import pandas as pd -except ImportError: - _PANDAS_AVAILABLE = False - @ray.remote def _glob_path_into_file_infos( diff --git a/daft/runners/runner.py b/daft/runners/runner.py index b99d4e352b..c1dd30f64e 100644 --- a/daft/runners/runner.py +++ b/daft/runners/runner.py @@ -2,11 +2,8 @@ import contextlib from abc import abstractmethod -from typing import Generic, Iterator +from typing import TYPE_CHECKING, Generic, Iterator -from daft.daft import ResourceRequest -from daft.expressions import ExpressionsProjection -from daft.logical.builder import LogicalPlanBuilder from daft.runners.partitioning import ( MaterializedResult, PartitionCacheEntry, @@ -14,8 +11,13 @@ PartitionSetCache, PartitionT, ) -from daft.runners.runner_io import RunnerIO -from daft.table import MicroPartition + +if TYPE_CHECKING: + from daft.daft import ResourceRequest + from daft.expressions import ExpressionsProjection + from daft.logical.builder import LogicalPlanBuilder + from daft.runners.runner_io import RunnerIO + from daft.table import MicroPartition class Runner(Generic[PartitionT]): diff --git a/daft/runners/runner_io.py b/daft/runners/runner_io.py index 980a855875..416a997a98 100644 --- a/daft/runners/runner_io.py +++ b/daft/runners/runner_io.py @@ -3,10 +3,8 @@ from abc import abstractmethod from typing import TYPE_CHECKING -from daft.daft import FileFormatConfig, FileInfos, IOConfig - if TYPE_CHECKING: - pass + from daft.daft import FileFormatConfig, FileInfos, IOConfig class RunnerIO: diff --git a/daft/series.py b/daft/series.py index a670d319f4..440d570b4a 100644 --- a/daft/series.py +++ b/daft/series.py @@ -2,27 +2,12 @@ from typing import Any, Literal, TypeVar -import pyarrow as pa - from daft.arrow_utils import ensure_array, ensure_chunked_array from daft.daft import CountMode, ImageFormat, ImageMode, PySeries, image -from daft.datatype import DataType +from daft.datatype import DataType, _ensure_registered_super_ext_type +from daft.dependencies import np, pa, pd from daft.utils import pyarrow_supports_fixed_shape_tensor -_NUMPY_AVAILABLE = True -try: - import numpy as np -except ImportError: - _NUMPY_AVAILABLE = False - -_PANDAS_AVAILABLE = True -try: - import pandas as pd -except ImportError: - _PANDAS_AVAILABLE = False - -ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) - class Series: """ @@ -49,6 +34,8 @@ def from_arrow(array: pa.Array | pa.ChunkedArray, name: str = "arrow_series") -> array: The pyarrow 
(chunked) array whose data we wish to put in the Series. name: The name associated with the Series; this is usually the column name. """ + + _ensure_registered_super_ext_type() if DataType.from_arrow_type(array.type) == DataType.python(): # If the Arrow type is not natively supported, go through the Python list path. return Series.from_pylist(array.to_pylist(), name=name, pyobj="force") @@ -640,13 +627,13 @@ def _debug_bincode_deserialize(cls, b: bytes) -> Series: def item_to_series(name: str, item: Any) -> Series: if isinstance(item, list): series = Series.from_pylist(item, name) - elif _NUMPY_AVAILABLE and isinstance(item, np.ndarray): + elif np.module_available() and isinstance(item, np.ndarray): series = Series.from_numpy(item, name) elif isinstance(item, Series): series = item elif isinstance(item, (pa.Array, pa.ChunkedArray)): series = Series.from_arrow(item, name) - elif _PANDAS_AVAILABLE and isinstance(item, pd.Series): + elif pd.module_available() and isinstance(item, pd.Series): series = Series.from_pandas(item, name) else: raise ValueError(f"Creating a Series from data of type {type(item)} not implemented") diff --git a/daft/sql/sql_connection.py b/daft/sql/sql_connection.py index 6e91fda4ad..6335951a82 100644 --- a/daft/sql/sql_connection.py +++ b/daft/sql/sql_connection.py @@ -4,13 +4,13 @@ from typing import TYPE_CHECKING, Callable from urllib.parse import urlparse -import pyarrow as pa - +from daft.dependencies import pa from daft.logical.schema import Schema if TYPE_CHECKING: from sqlalchemy.engine import Connection + logger = logging.getLogger(__name__) diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 67b5629cf2..40035c12f7 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -3,9 +3,8 @@ import logging import math import warnings -from collections.abc import Iterator from enum import Enum, auto -from typing import Any +from typing import TYPE_CHECKING, Any from daft.context import get_context from daft.daft import ( @@ -16,14 +15,18 @@ ScanTask, StorageConfig, ) -from daft.datatype import DataType from daft.expressions.expressions import lit from daft.io.common import _get_schema_from_dict from daft.io.scan import PartitionField, ScanOperator -from daft.logical.schema import Schema -from daft.sql.sql_connection import SQLConnection from daft.table import Table +if TYPE_CHECKING: + from collections.abc import Iterator + + from daft.datatype import DataType + from daft.logical.schema import Schema + from daft.sql.sql_connection import SQLConnection + logger = logging.getLogger(__name__) diff --git a/daft/table/micropartition.py b/daft/table/micropartition.py index 0646c8c7ef..81b2b8b2ac 100644 --- a/daft/table/micropartition.py +++ b/daft/table/micropartition.py @@ -3,8 +3,6 @@ import logging from typing import TYPE_CHECKING, Any -import pyarrow as pa - from daft.daft import ( CsvConvertOptions, CsvParseOptions, @@ -26,14 +24,7 @@ if TYPE_CHECKING: import pandas as pd - - -_PANDAS_AVAILABLE = True -try: - import pandas as pd -except ImportError: - _PANDAS_AVAILABLE = False - + import pyarrow as pa logger = logging.getLogger(__name__) diff --git a/daft/table/schema_inference.py b/daft/table/schema_inference.py index 2c4ba7212c..66b76bbbec 100644 --- a/daft/table/schema_inference.py +++ b/daft/table/schema_inference.py @@ -2,10 +2,6 @@ import pathlib -import pyarrow.csv as pacsv -import pyarrow.json as pajson -import pyarrow.parquet as papq - from daft.daft import ( CsvParseOptions, JsonParseOptions, @@ -14,6 +10,7 @@ StorageConfig, ) from 
daft.datatype import DataType +from daft.dependencies import pacsv, pajson, pq from daft.filesystem import _resolve_paths_and_filesystem from daft.logical.schema import Schema from daft.runners.partitioning import TableParseCSVOptions @@ -137,7 +134,7 @@ def from_parquet( path = paths[0] f = fs.open_input_file(path) - pqf = papq.ParquetFile(f) + pqf = pq.ParquetFile(f) arrow_schema = pqf.metadata.schema.to_arrow_schema() return Schema._from_field_name_and_types([(f.name, DataType.from_arrow_type(f.type)) for f in arrow_schema]) diff --git a/daft/table/table.py b/daft/table/table.py index 81b845ba5b..707ea6ec98 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -3,8 +3,6 @@ import logging from typing import TYPE_CHECKING, Any -import pyarrow as pa - from daft.arrow_utils import ensure_table from daft.daft import ( CsvConvertOptions, @@ -25,23 +23,14 @@ from daft.daft import read_parquet_into_pyarrow_bulk as _read_parquet_into_pyarrow_bulk from daft.daft import read_parquet_statistics as _read_parquet_statistics from daft.datatype import DataType, TimeUnit +from daft.dependencies import pa, pd from daft.expressions import Expression, ExpressionsProjection from daft.logical.schema import Schema from daft.series import Series, item_to_series -_PANDAS_AVAILABLE = True -try: - import pandas as pd -except ImportError: - _PANDAS_AVAILABLE = False - if TYPE_CHECKING: - import pandas as pd - import pyarrow as pa - from daft.io import IOConfig - logger = logging.getLogger(__name__) @@ -124,7 +113,7 @@ def from_arrow_record_batches(rbs: list[pa.RecordBatch], arrow_schema: pa.Schema @staticmethod def from_pandas(pd_df: pd.DataFrame) -> Table: - if not _PANDAS_AVAILABLE: + if not pd.module_available(): raise ImportError("Unable to import Pandas - please ensure that it is installed.") assert isinstance(pd_df, pd.DataFrame) try: @@ -190,7 +179,7 @@ def to_pandas( ) -> pd.DataFrame: from packaging.version import parse - if not _PANDAS_AVAILABLE: + if not pd.module_available(): raise ImportError("Unable to import Pandas - please ensure that it is installed.") python_fields = set() diff --git a/daft/table/table_io.py b/daft/table/table_io.py index d4bffee3b4..97c274a1f3 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -5,17 +5,10 @@ import pathlib import random import time -from collections.abc import Callable, Generator from functools import partial from typing import IO, TYPE_CHECKING, Any, Union from uuid import uuid4 -import pyarrow as pa -from pyarrow import csv as pacsv -from pyarrow import dataset as pads -from pyarrow import json as pajson -from pyarrow import parquet as papq - from daft.context import get_context from daft.daft import ( CsvConvertOptions, @@ -31,8 +24,8 @@ StorageConfig, ) from daft.datatype import DataType +from daft.dependencies import pa, pacsv, pads, pajson, pq from daft.expressions import ExpressionsProjection -from daft.expressions.expressions import Expression from daft.filesystem import ( _resolve_paths_and_filesystem, canonicalize_protocol, @@ -45,15 +38,19 @@ TableReadOptions, ) from daft.series import Series -from daft.sql.sql_connection import SQLConnection from daft.table import MicroPartition FileInput = Union[pathlib.Path, str, IO[bytes]] if TYPE_CHECKING: + from collections.abc import Callable, Generator + from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties + from daft.expressions.expressions import Expression + from daft.sql.sql_connection import SQLConnection + 
@contextlib.contextmanager def _open_stream( @@ -194,7 +191,7 @@ def read_parquet( # If no rows required, we manually construct an empty table with the right schema if read_options.num_rows == 0: - pqf = papq.ParquetFile( + pqf = pq.ParquetFile( f, coerce_int96_timestamp_unit=str(parquet_options.coerce_int96_timestamp_unit), ) @@ -204,7 +201,7 @@ def read_parquet( schema=arrow_schema, ) elif read_options.num_rows is not None: - pqf = papq.ParquetFile( + pqf = pq.ParquetFile( f, coerce_int96_timestamp_unit=str(parquet_options.coerce_int96_timestamp_unit), ) @@ -220,7 +217,7 @@ def read_parquet( # Need to truncate the table to the row limit. table = table.slice(length=read_options.num_rows) else: - table = papq.read_table( + table = pq.read_table( f, columns=read_options.column_names, coerce_int96_timestamp_unit=str(parquet_options.coerce_int96_timestamp_unit), @@ -337,9 +334,11 @@ def read_csv( io_config = config.io_config with _open_stream(file, io_config) as f: - from daft.utils import ARROW_VERSION + from daft.utils import get_arrow_version - if csv_options.comment is not None and ARROW_VERSION < (7, 0, 0): + arrow_version = get_arrow_version() + + if csv_options.comment is not None and arrow_version < (7, 0, 0): raise ValueError( "pyarrow < 7.0.0 doesn't support handling comments in CSVs, please upgrade pyarrow to 7.0.0+." ) @@ -350,7 +349,7 @@ def read_csv( escape_char=csv_options.escape_char, ) - if ARROW_VERSION >= (7, 0, 0): + if arrow_version >= (7, 0, 0): parse_options.invalid_row_handler = skip_comment(csv_options.comment) pacsv_stream = pacsv.open_csv( @@ -700,7 +699,7 @@ def write_deltalake( from daft.io._deltalake import large_dtypes_kwargs from daft.io.object_store_options import io_config_to_storage_options - from daft.utils import ARROW_VERSION + from daft.utils import get_arrow_version protocol = get_protocol_from_path(base_path) canonicalized_protocol = canonicalize_protocol(protocol) @@ -721,7 +720,7 @@ def file_visitor(written_file: Any) -> None: stats = get_file_stats_from_metadata(written_file.metadata) # PyArrow added support for written_file.size in 9.0.0 - if ARROW_VERSION >= (9, 0, 0): + if get_arrow_version() >= (9, 0, 0): size = written_file.size elif fs is not None: size = fs.get_file_info([path])[0].size @@ -852,14 +851,16 @@ def _write_tabular_arrow_table( ): kwargs = dict() - from daft.utils import ARROW_VERSION + from daft.utils import get_arrow_version + + arrow_version = get_arrow_version() - if ARROW_VERSION >= (7, 0, 0): + if arrow_version >= (7, 0, 0): kwargs["max_rows_per_file"] = rows_per_file kwargs["min_rows_per_group"] = rows_per_row_group kwargs["max_rows_per_group"] = rows_per_row_group - if ARROW_VERSION >= (8, 0, 0) and not create_dir: + if arrow_version >= (8, 0, 0) and not create_dir: kwargs["create_dir"] = False basename_template = _generate_basename_template(format.default_extname, version) @@ -906,7 +907,7 @@ def write_empty_tabular( def write_table(): if file_format == FileFormat.Parquet: - papq.write_table( + pq.write_table( table, file_path, compression=compression, diff --git a/daft/udf.py b/daft/udf.py index bfc0a4e570..d560ddd4f5 100644 --- a/daft/udf.py +++ b/daft/udf.py @@ -4,30 +4,15 @@ import functools import inspect from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Union +from typing import Any, Callable, Union from daft.context import get_context from daft.daft import PyDataType, ResourceRequest from daft.datatype import DataType +from daft.dependencies import np, pa from daft.expressions import 
Expression from daft.series import PySeries, Series -_NUMPY_AVAILABLE = True -try: - import numpy as np -except ImportError: - _NUMPY_AVAILABLE = False - -_PYARROW_AVAILABLE = True -try: - import pyarrow as pa -except ImportError: - _PYARROW_AVAILABLE = False - -if TYPE_CHECKING: - import numpy as np - import pyarrow as pa - UserProvidedPythonFunction = Callable[..., Union[Series, "np.ndarray", list]] @@ -169,10 +154,10 @@ def get_args_for_slice(start: int, end: int): return Series.from_pylist(result_list, name=name, pyobj="force")._series else: return Series.from_pylist(result_list, name=name, pyobj="allow").cast(return_dtype)._series - elif _NUMPY_AVAILABLE and isinstance(results[0], np.ndarray): + elif np.module_available() and isinstance(results[0], np.ndarray): result_np = np.concatenate(results) return Series.from_numpy(result_np, name=name).cast(return_dtype)._series - elif _PYARROW_AVAILABLE and isinstance(results[0], (pa.Array, pa.ChunkedArray)): + elif pa.module_available() and isinstance(results[0], (pa.Array, pa.ChunkedArray)): result_pa = pa.concat_arrays(results) return Series.from_arrow(result_pa, name=name).cast(return_dtype)._series else: diff --git a/daft/unity_catalog/__init__.py b/daft/unity_catalog/__init__.py index ce559ec33c..b77feac517 100644 --- a/daft/unity_catalog/__init__.py +++ b/daft/unity_catalog/__init__.py @@ -1,3 +1,5 @@ -from .unity_catalog import UnityCatalog, UnityCatalogTable +# We ban importing from daft.unity_catalog as a module level import because it is expensive despite +# not always being needed. Within the daft.unity_catalog module itself we ignore this restriction. +from .unity_catalog import UnityCatalog, UnityCatalogTable # noqa: TID253 __all__ = ["UnityCatalog", "UnityCatalogTable"] diff --git a/daft/utils.py b/daft/utils.py index 2a1ffd205f..6d7e10fb92 100644 --- a/daft/utils.py +++ b/daft/utils.py @@ -5,9 +5,11 @@ import statistics from typing import Any, Callable -import pyarrow as pa +from daft.dependencies import pa -ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + +def get_arrow_version(): + return tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) def in_notebook(): @@ -132,4 +134,4 @@ def pyarrow_supports_fixed_shape_tensor() -> bool: """Whether pyarrow supports the fixed_shape_tensor canonical extension type.""" from daft.context import get_context - return hasattr(pa, "fixed_shape_tensor") and (not get_context().is_ray_runner or ARROW_VERSION >= (13, 0, 0)) + return hasattr(pa, "fixed_shape_tensor") and (not get_context().is_ray_runner or get_arrow_version() >= (13, 0, 0)) diff --git a/daft/viz/dataframe_display.py b/daft/viz/dataframe_display.py index ba9c0c7165..2d8b459c3c 100644 --- a/daft/viz/dataframe_display.py +++ b/daft/viz/dataframe_display.py @@ -1,19 +1,11 @@ from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING -from daft.dataframe.preview import DataFramePreview -from daft.logical.schema import Schema - -HAS_PILLOW = False -try: - pass - - HAS_PILLOW = True -except ImportError: - pass -if HAS_PILLOW: - pass +if TYPE_CHECKING: + from daft.dataframe.preview import DataFramePreview + from daft.logical.schema import Schema @dataclass(frozen=True) diff --git a/daft/viz/html_viz_hooks.py b/daft/viz/html_viz_hooks.py index 0602cc6484..17418d30e2 100644 --- a/daft/viz/html_viz_hooks.py +++ b/daft/viz/html_viz_hooks.py @@ -2,16 +2,15 @@ import base64 import io -from typing import TYPE_CHECKING, Callable, TypeVar - -if 
TYPE_CHECKING: - import numpy as np - import PIL.Image +from typing import Callable, TypeVar +from daft.dependencies import np, pil_image HookClass = TypeVar("HookClass") _VIZ_HOOKS_REGISTRY = {} +_NUMPY_REGISTERED = False +_PILLOW_REGISTERED = False def register_viz_hook(klass: type[HookClass], hook: Callable[[object], str]): @@ -21,43 +20,30 @@ def register_viz_hook(klass: type[HookClass], hook: Callable[[object], str]): def get_viz_hook(val: object) -> Callable[[object], str] | None: - for klass in _VIZ_HOOKS_REGISTRY: - if isinstance(val, klass): - return _VIZ_HOOKS_REGISTRY[klass] - return None - - -### -# Default hooks, registered at import-time -### + global _NUMPY_REGISTERED + global _PILLOW_REGISTERED + if np.module_available() and not _NUMPY_REGISTERED: -HAS_PILLOW = True -try: - import PIL.Image -except ImportError: - HAS_PILLOW = False + def _viz_numpy(val: np.ndarray) -> str: + return f"<np.ndarray
shape={val.shape}
dtype={val.dtype}>" -HAS_NUMPY = True -try: - import numpy as np -except ImportError: - HAS_NUMPY = False + register_viz_hook(np.ndarray, _viz_numpy) + _NUMPY_REGISTERED = True -if HAS_PILLOW: + if pil_image.module_available() and not _PILLOW_REGISTERED: - def _viz_pil_image(val: PIL.Image.Image) -> str: - img = val.copy() - img.thumbnail((128, 128)) - bio = io.BytesIO() - img.save(bio, "JPEG") - base64_img = base64.b64encode(bio.getvalue()) - return f'{str(val)}' + def _viz_pil_image(val: pil_image.Image) -> str: + img = val.copy() + img.thumbnail((128, 128)) + bio = io.BytesIO() + img.save(bio, "JPEG") + base64_img = base64.b64encode(bio.getvalue()) + return f'{str(val)}' - register_viz_hook(PIL.Image.Image, _viz_pil_image) + register_viz_hook(pil_image.Image, _viz_pil_image) + _PILLOW_REGISTERED = True -if HAS_NUMPY: - - def _viz_numpy(val: np.ndarray) -> str: - return f"<np.ndarray
shape={val.shape}
dtype={val.dtype}>" - - register_viz_hook(np.ndarray, _viz_numpy) + for klass in _VIZ_HOOKS_REGISTRY: + if isinstance(val, klass): + return _VIZ_HOOKS_REGISTRY[klass] + return None diff --git a/tests/dataframe/test_logical_type.py b/tests/dataframe/test_logical_type.py index 946dd5906d..2d598739a4 100644 --- a/tests/dataframe/test_logical_type.py +++ b/tests/dataframe/test_logical_type.py @@ -8,10 +8,11 @@ import daft from daft import DataType, Series, col -from daft.datatype import DaftExtension +from daft.datatype import get_super_ext_type from daft.utils import pyarrow_supports_fixed_shape_tensor ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) +DaftExtension = get_super_ext_type() def test_embedding_type_df() -> None: diff --git a/tests/series/test_embedding.py b/tests/series/test_embedding.py index 3181fe7e78..c4fb36ddde 100644 --- a/tests/series/test_embedding.py +++ b/tests/series/test_embedding.py @@ -5,9 +5,11 @@ import numpy as np import pandas as pd -from daft.datatype import DaftExtension, DataType +from daft.datatype import DataType, get_super_ext_type from daft.series import Series +DaftExtension = get_super_ext_type() + def test_embedding_arrow_round_trip(): data = [[1, 2, 3], np.arange(3), ["1", "2", "3"], [1, "2", 3.0], pd.Series([1.1, 2, 3]), (1, 2, 3), None] diff --git a/tests/series/test_image.py b/tests/series/test_image.py index a86dc49818..bdd7c1f0ea 100644 --- a/tests/series/test_image.py +++ b/tests/series/test_image.py @@ -9,9 +9,11 @@ import pytest from PIL import Image, ImageSequence -from daft.datatype import DaftExtension, DataType +from daft.datatype import DataType, get_super_ext_type from daft.series import Series +DaftExtension = get_super_ext_type() + MODE_TO_NP_DTYPE = { "L": np.uint8, "LA": np.uint8, diff --git a/tests/series/test_tensor.py b/tests/series/test_tensor.py index fbea2ba64a..5877295a58 100644 --- a/tests/series/test_tensor.py +++ b/tests/series/test_tensor.py @@ -6,13 +6,14 @@ import pyarrow as pa import pytest -from daft.datatype import DaftExtension, DataType +from daft.datatype import DataType, get_super_ext_type from daft.series import Series from daft.utils import pyarrow_supports_fixed_shape_tensor from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES from tests.utils import ANSI_ESCAPE ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) +DaftExtension = get_super_ext_type() @pytest.mark.parametrize("dtype", ARROW_INT_TYPES + ARROW_FLOAT_TYPES) diff --git a/tests/table/table_io/test_csv.py b/tests/table/table_io/test_csv.py index dd8291053a..fa8e96fbee 100644 --- a/tests/table/table_io/test_csv.py +++ b/tests/table/table_io/test_csv.py @@ -13,8 +13,8 @@ from daft.datatype import DataType from daft.logical.schema import Schema from daft.runners.partitioning import TableParseCSVOptions, TableReadOptions -from daft.series import ARROW_VERSION from daft.table import MicroPartition, schema_inference, table_io +from daft.utils import get_arrow_version def storage_config_from_use_native_downloader(use_native_downloader: bool) -> StorageConfig: @@ -351,7 +351,8 @@ def test_csv_read_data_custom_comment(use_native_downloader): } ) # Skipping test for arrow < 7.0.0 as comments are not supported in pyarrow - if ARROW_VERSION >= (7, 0, 0): + arrow_version = get_arrow_version() + if arrow_version >= (7, 0, 0): table = table_io.read_csv( file, schema, From 53dec06c4b3c3bfe38baa8e790d15d5daf324dea Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Fri, 20 Sep 2024 14:19:26 -0500 
Subject: [PATCH 03/35] [CHORE]: move list functions from daft-dsl to daft-functions (#2854) prerequisite to adding list functions to sql --- daft/daft/__init__.pyi | 26 ++-- daft/expressions/expressions.py | 22 +-- src/daft-dsl/src/functions/list/chunk.rs | 52 ------- src/daft-dsl/src/functions/list/count.rs | 60 -------- src/daft-dsl/src/functions/list/explode.rs | 36 ----- src/daft-dsl/src/functions/list/max.rs | 44 ------ src/daft-dsl/src/functions/list/mean.rs | 39 ----- src/daft-dsl/src/functions/list/mod.rs | 139 ------------------ src/daft-dsl/src/functions/list/sum.rs | 40 ----- src/daft-dsl/src/functions/mod.rs | 7 +- src/daft-dsl/src/python.rs | 56 +------ src/daft-functions/src/lib.rs | 2 + src/daft-functions/src/list/chunk.rs | 69 +++++++++ src/daft-functions/src/list/count.rs | 75 ++++++++++ src/daft-functions/src/list/explode.rs | 64 ++++++++ .../src}/list/get.rs | 46 ++++-- .../src}/list/join.rs | 46 ++++-- src/daft-functions/src/list/max.rs | 72 +++++++++ src/daft-functions/src/list/mean.rs | 68 +++++++++ .../src}/list/min.rs | 41 +++++- src/daft-functions/src/list/mod.rs | 40 +++++ .../src}/list/slice.rs | 50 +++++-- src/daft-functions/src/list/sum.rs | 72 +++++++++ src/daft-plan/Cargo.toml | 2 +- src/daft-plan/src/logical_ops/explode.rs | 2 +- src/daft-sql/src/planner.rs | 4 +- src/daft-table/src/ops/explode.rs | 34 ++--- 27 files changed, 658 insertions(+), 550 deletions(-) delete mode 100644 src/daft-dsl/src/functions/list/chunk.rs delete mode 100644 src/daft-dsl/src/functions/list/count.rs delete mode 100644 src/daft-dsl/src/functions/list/explode.rs delete mode 100644 src/daft-dsl/src/functions/list/max.rs delete mode 100644 src/daft-dsl/src/functions/list/mean.rs delete mode 100644 src/daft-dsl/src/functions/list/mod.rs delete mode 100644 src/daft-dsl/src/functions/list/sum.rs create mode 100644 src/daft-functions/src/list/chunk.rs create mode 100644 src/daft-functions/src/list/count.rs create mode 100644 src/daft-functions/src/list/explode.rs rename src/{daft-dsl/src/functions => daft-functions/src}/list/get.rs (52%) rename src/{daft-dsl/src/functions => daft-functions/src}/list/join.rs (64%) create mode 100644 src/daft-functions/src/list/max.rs create mode 100644 src/daft-functions/src/list/mean.rs rename src/{daft-dsl/src/functions => daft-functions/src}/list/min.rs (50%) create mode 100644 src/daft-functions/src/list/mod.rs rename src/{daft-dsl/src/functions => daft-functions/src}/list/slice.rs (51%) create mode 100644 src/daft-functions/src/list/sum.rs diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 98f91fadb7..b08a0633e4 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1073,7 +1073,6 @@ class PyExpr: def any_value(self, ignore_nulls: bool) -> PyExpr: ... def agg_list(self) -> PyExpr: ... def agg_concat(self) -> PyExpr: ... - def explode(self) -> PyExpr: ... def __abs__(self) -> PyExpr: ... def __add__(self, other: PyExpr) -> PyExpr: ... def __sub__(self, other: PyExpr) -> PyExpr: ... @@ -1142,15 +1141,6 @@ class PyExpr: def utf8_to_date(self, format: str) -> PyExpr: ... def utf8_to_datetime(self, format: str, timezone: str | None = None) -> PyExpr: ... def utf8_normalize(self, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PyExpr: ... - def list_join(self, delimiter: PyExpr) -> PyExpr: ... - def list_count(self, mode: CountMode) -> PyExpr: ... - def list_get(self, idx: PyExpr, default: PyExpr) -> PyExpr: ... - def list_sum(self) -> PyExpr: ... - def list_mean(self) -> PyExpr: ... 
- def list_min(self) -> PyExpr: ... - def list_max(self) -> PyExpr: ... - def list_slice(self, start: PyExpr, end: PyExpr | None = None) -> PyExpr: ... - def list_chunk(self, size: int) -> PyExpr: ... def struct_get(self, name: str) -> PyExpr: ... def map_get(self, key: PyExpr) -> PyExpr: ... def partitioning_days(self) -> PyExpr: ... @@ -1236,7 +1226,6 @@ def minhash( def sql(sql: str, catalog: PyCatalog, daft_planning_config: PyDaftPlanningConfig) -> LogicalPlanBuilder: ... def sql_expr(sql: str) -> PyExpr: ... def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ... -def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ... def cbrt(expr: PyExpr) -> PyExpr: ... def to_struct(inputs: list[PyExpr]) -> PyExpr: ... @@ -1262,6 +1251,21 @@ def fill_nan(expr: PyExpr, fill_value: PyExpr) -> PyExpr: ... # --- def json_query(expr: PyExpr, query: str) -> PyExpr: ... +# --- +# expr.list namespace +# --- +def explode(expr: PyExpr) -> PyExpr: ... +def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ... +def list_join(expr: PyExpr, delimiter: PyExpr) -> PyExpr: ... +def list_count(expr: PyExpr, mode: CountMode) -> PyExpr: ... +def list_get(expr: PyExpr, idx: PyExpr, default: PyExpr) -> PyExpr: ... +def list_sum(expr: PyExpr) -> PyExpr: ... +def list_mean(expr: PyExpr) -> PyExpr: ... +def list_min(expr: PyExpr) -> PyExpr: ... +def list_max(expr: PyExpr) -> PyExpr: ... +def list_slice(expr: PyExpr, start: PyExpr, end: PyExpr | None = None) -> PyExpr: ... +def list_chunk(expr: PyExpr, size: int) -> PyExpr: ... + class PyCatalog: @staticmethod def new() -> PyCatalog: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index ede945f586..4ecc99108e 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -884,7 +884,7 @@ def agg_concat(self) -> Expression: return Expression._from_pyexpr(expr) def _explode(self) -> Expression: - expr = self._expr.explode() + expr = native.explode(self._expr) return Expression._from_pyexpr(expr) def if_else(self, if_true: Expression, if_false: Expression) -> Expression: @@ -2919,7 +2919,7 @@ def join(self, delimiter: str | Expression) -> Expression: Expression: a String expression which is every element of the list joined on the delimiter """ delimiter_expr = Expression._to_expression(delimiter) - return Expression._from_pyexpr(self._expr.list_join(delimiter_expr._expr)) + return Expression._from_pyexpr(native.list_join(self._expr, delimiter_expr._expr)) def count(self, mode: CountMode = CountMode.Valid) -> Expression: """Counts the number of elements in each list @@ -2930,7 +2930,7 @@ def count(self, mode: CountMode = CountMode.Valid) -> Expression: Returns: Expression: a UInt64 expression which is the length of each list """ - return Expression._from_pyexpr(self._expr.list_count(mode)) + return Expression._from_pyexpr(native.list_count(self._expr, mode)) def lengths(self) -> Expression: """Gets the length of each list @@ -2938,7 +2938,7 @@ def lengths(self) -> Expression: Returns: Expression: a UInt64 expression which is the length of each list """ - return Expression._from_pyexpr(self._expr.list_count(CountMode.All)) + return Expression._from_pyexpr(native.list_count(self._expr, CountMode.All)) def get(self, idx: int | Expression, default: object = None) -> Expression: """Gets the element at an index in each list @@ -2952,7 +2952,7 @@ def get(self, idx: int | Expression, default: object = None) -> Expression: """ idx_expr = Expression._to_expression(idx) 
default_expr = lit(default) - return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr, default_expr._expr)) + return Expression._from_pyexpr(native.list_get(self._expr, idx_expr._expr, default_expr._expr)) def slice(self, start: int | Expression, end: int | Expression | None = None) -> Expression: """Gets a subset of each list @@ -2966,7 +2966,7 @@ def slice(self, start: int | Expression, end: int | Expression | None = None) -> """ start_expr = Expression._to_expression(start) end_expr = Expression._to_expression(end) - return Expression._from_pyexpr(self._expr.list_slice(start_expr._expr, end_expr._expr)) + return Expression._from_pyexpr(native.list_slice(self._expr, start_expr._expr, end_expr._expr)) def chunk(self, size: int) -> Expression: """Splits each list into chunks of the given size @@ -2978,7 +2978,7 @@ def chunk(self, size: int) -> Expression: """ if not (isinstance(size, int) and size > 0): raise ValueError(f"Invalid value for `size`: {size}") - return Expression._from_pyexpr(self._expr.list_chunk(size)) + return Expression._from_pyexpr(native.list_chunk(self._expr, size)) def sum(self) -> Expression: """Sums each list. Empty lists and lists with all nulls yield null. @@ -2986,7 +2986,7 @@ def sum(self) -> Expression: Returns: Expression: an expression with the type of the list values """ - return Expression._from_pyexpr(self._expr.list_sum()) + return Expression._from_pyexpr(native.list_sum(self._expr)) def mean(self) -> Expression: """Calculates the mean of each list. If no non-null values in a list, the result is null. @@ -2994,7 +2994,7 @@ def mean(self) -> Expression: Returns: Expression: a Float64 expression with the type of the list values """ - return Expression._from_pyexpr(self._expr.list_mean()) + return Expression._from_pyexpr(native.list_mean(self._expr)) def min(self) -> Expression: """Calculates the minimum of each list. If no non-null values in a list, the result is null. @@ -3002,7 +3002,7 @@ def min(self) -> Expression: Returns: Expression: a Float64 expression with the type of the list values """ - return Expression._from_pyexpr(self._expr.list_min()) + return Expression._from_pyexpr(native.list_min(self._expr)) def max(self) -> Expression: """Calculates the maximum of each list. If no non-null values in a list, the result is null. @@ -3010,7 +3010,7 @@ def max(self) -> Expression: Returns: Expression: a Float64 expression with the type of the list values """ - return Expression._from_pyexpr(self._expr.list_max()) + return Expression._from_pyexpr(native.list_max(self._expr)) def sort(self, desc: bool | Expression = False) -> Expression: """Sorts the inner lists of a list column. 
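Note (not part of the patch, for reviewers): the user-facing Expression.list API is unchanged by this move; only the bindings now route through the daft.daft native functions (native.list_sum, native.list_count, ...) instead of methods on PyExpr. A minimal sketch of exercising the rewired methods from Python, with made-up data:

    import daft
    from daft import col

    df = daft.from_pydict({"vals": [[1, 2, 3], [4, 5], None]})

    # Each call below now dispatches to a ScalarFunction in daft-functions
    # rather than a PyExpr method.
    df = df.select(
        col("vals").list.sum().alias("sum"),
        col("vals").list.lengths().alias("len"),
        col("vals").list.get(0, default=-1).alias("first"),
    )
    print(df.collect())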
diff --git a/src/daft-dsl/src/functions/list/chunk.rs b/src/daft-dsl/src/functions/list/chunk.rs deleted file mode 100644 index 60f4567828..0000000000 --- a/src/daft-dsl/src/functions/list/chunk.rs +++ /dev/null @@ -1,52 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, ListExpr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ChunkEvaluator {} - -impl FunctionEvaluator for ChunkEvaluator { - fn fn_name(&self) -> &'static str { - "chunk" - } - - fn to_field( - &self, - inputs: &[ExprRef], - schema: &Schema, - expr: &FunctionExpr, - ) -> DaftResult { - let size = match expr { - FunctionExpr::List(ListExpr::Chunk(size)) => size, - _ => panic!("Expected Chunk Expr, got {expr}"), - }; - match inputs { - [input] => { - let input_field = input.to_field(schema)?; - Ok(input_field - .to_exploded_field()? - .to_fixed_size_list_field(*size)? - .to_list_field()?) - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - let size = match expr { - FunctionExpr::List(ListExpr::Chunk(size)) => size, - _ => panic!("Expected Chunk Expr, got {expr}"), - }; - match inputs { - [input] => input.list_chunk(*size), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/list/count.rs b/src/daft-dsl/src/functions/list/count.rs deleted file mode 100644 index 8392e7353f..0000000000 --- a/src/daft-dsl/src/functions/list/count.rs +++ /dev/null @@ -1,60 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, ListExpr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct CountEvaluator {} - -impl FunctionEvaluator for CountEvaluator { - fn fn_name(&self) -> &'static str { - "count" - } - - fn to_field( - &self, - inputs: &[ExprRef], - schema: &Schema, - expr: &FunctionExpr, - ) -> DaftResult { - match inputs { - [input] => { - let input_field = input.to_field(schema)?; - - match input_field.dtype { - DataType::List(_) | DataType::FixedSizeList(_, _) => match expr { - FunctionExpr::List(ListExpr::Count(_)) => { - Ok(Field::new(input.name(), DataType::UInt64)) - } - _ => panic!("Expected List Count Expr, got {expr}"), - }, - _ => Err(DaftError::TypeError(format!( - "Expected input to be a list type, received: {}", - input_field.dtype - ))), - } - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [input] => { - let mode = match expr { - FunctionExpr::List(ListExpr::Count(mode)) => mode, - _ => panic!("Expected List Count Expr, got {expr}"), - }; - - Ok(input.list_count(*mode)?.into_series()) - } - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/list/explode.rs b/src/daft-dsl/src/functions/list/explode.rs deleted file mode 100644 index 6065ec4486..0000000000 --- a/src/daft-dsl/src/functions/list/explode.rs +++ /dev/null @@ -1,36 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ExplodeEvaluator {} - -impl FunctionEvaluator for 
ExplodeEvaluator { - fn fn_name(&self) -> &'static str { - "explode" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => { - let field = input.to_field(schema)?; - field.to_exploded_field() - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => input.explode(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/list/max.rs b/src/daft-dsl/src/functions/list/max.rs deleted file mode 100644 index f77031a25f..0000000000 --- a/src/daft-dsl/src/functions/list/max.rs +++ /dev/null @@ -1,44 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct MaxEvaluator {} - -impl FunctionEvaluator for MaxEvaluator { - fn fn_name(&self) -> &'static str { - "max" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => { - let field = input.to_field(schema)?.to_exploded_field()?; - - if field.dtype.is_numeric() { - Ok(field) - } else { - Err(DaftError::TypeError(format!( - "Expected input to be numeric, got {}", - field.dtype - ))) - } - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => Ok(input.list_max()?), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/list/mean.rs b/src/daft-dsl/src/functions/list/mean.rs deleted file mode 100644 index 9d7a64e050..0000000000 --- a/src/daft-dsl/src/functions/list/mean.rs +++ /dev/null @@ -1,39 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::{datatypes::try_mean_supertype, prelude::*}; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct MeanEvaluator {} - -impl FunctionEvaluator for MeanEvaluator { - fn fn_name(&self) -> &'static str { - "mean" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => { - let inner_field = input.to_field(schema)?.to_exploded_field()?; - Ok(Field::new( - inner_field.name.as_str(), - try_mean_supertype(&inner_field.dtype)?, - )) - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => Ok(input.list_mean()?), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/list/mod.rs b/src/daft-dsl/src/functions/list/mod.rs deleted file mode 100644 index 4e1cdc8a9e..0000000000 --- a/src/daft-dsl/src/functions/list/mod.rs +++ /dev/null @@ -1,139 +0,0 @@ -mod chunk; -mod count; -mod explode; -mod get; -mod join; -mod max; -mod mean; -mod min; -mod slice; -mod sum; - -use chunk::ChunkEvaluator; -use count::CountEvaluator; -use daft_core::count_mode::CountMode; -use explode::ExplodeEvaluator; -use get::GetEvaluator; -use join::JoinEvaluator; -use 
max::MaxEvaluator; -use mean::MeanEvaluator; -use min::MinEvaluator; -use serde::{Deserialize, Serialize}; -use slice::SliceEvaluator; -use sum::SumEvaluator; - -use super::FunctionEvaluator; -use crate::{Expr, ExprRef}; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub enum ListExpr { - Explode, - Join, - Count(CountMode), - Get, - Sum, - Mean, - Min, - Max, - Slice, - Chunk(usize), -} - -impl ListExpr { - #[inline] - pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use ListExpr::*; - match self { - Explode => &ExplodeEvaluator {}, - Join => &JoinEvaluator {}, - Count(_) => &CountEvaluator {}, - Get => &GetEvaluator {}, - Sum => &SumEvaluator {}, - Mean => &MeanEvaluator {}, - Min => &MinEvaluator {}, - Max => &MaxEvaluator {}, - Slice => &SliceEvaluator {}, - Chunk(_) => &ChunkEvaluator {}, - } - } -} - -pub fn explode(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::List(ListExpr::Explode), - inputs: vec![input], - } - .into() -} - -pub fn join(input: ExprRef, delimiter: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::List(ListExpr::Join), - inputs: vec![input, delimiter], - } - .into() -} - -pub fn count(input: ExprRef, mode: CountMode) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::List(ListExpr::Count(mode)), - inputs: vec![input], - } - .into() -} - -pub fn get(input: ExprRef, idx: ExprRef, default: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::List(ListExpr::Get), - inputs: vec![input, idx, default], - } - .into() -} - -pub fn sum(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::List(ListExpr::Sum), - inputs: vec![input], - } - .into() -} - -pub fn mean(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::List(ListExpr::Mean), - inputs: vec![input], - } - .into() -} - -pub fn min(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::List(ListExpr::Min), - inputs: vec![input], - } - .into() -} - -pub fn max(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::List(ListExpr::Max), - inputs: vec![input], - } - .into() -} - -pub fn slice(input: ExprRef, start: ExprRef, end: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::List(ListExpr::Slice), - inputs: vec![input, start, end], - } - .into() -} - -pub fn chunk(input: ExprRef, size: usize) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::List(ListExpr::Chunk(size)), - inputs: vec![input], - } - .into() -} diff --git a/src/daft-dsl/src/functions/list/sum.rs b/src/daft-dsl/src/functions/list/sum.rs deleted file mode 100644 index 42710fa85f..0000000000 --- a/src/daft-dsl/src/functions/list/sum.rs +++ /dev/null @@ -1,40 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::{datatypes::try_sum_supertype, prelude::*}; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct SumEvaluator {} - -impl FunctionEvaluator for SumEvaluator { - fn fn_name(&self) -> &'static str { - "sum" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => { - let inner_field = input.to_field(schema)?.to_exploded_field()?; - - Ok(Field::new( - inner_field.name.as_str(), - try_sum_supertype(&inner_field.dtype)?, - )) - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: 
&FunctionExpr) -> DaftResult { - match inputs { - [input] => Ok(input.list_sum()?), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index f0f64e0f63..216c79a2f9 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -1,4 +1,3 @@ -pub mod list; pub mod map; pub mod numeric; pub mod partitioning; @@ -19,8 +18,8 @@ pub use scalar::*; use serde::{Deserialize, Serialize}; use self::{ - list::ListExpr, map::MapExpr, numeric::NumericExpr, partitioning::PartitioningExpr, - sketch::SketchExpr, struct_::StructExpr, temporal::TemporalExpr, utf8::Utf8Expr, + map::MapExpr, numeric::NumericExpr, partitioning::PartitioningExpr, sketch::SketchExpr, + struct_::StructExpr, temporal::TemporalExpr, utf8::Utf8Expr, }; use crate::{Expr, ExprRef, Operator}; @@ -32,7 +31,6 @@ pub enum FunctionExpr { Numeric(NumericExpr), Utf8(Utf8Expr), Temporal(TemporalExpr), - List(ListExpr), Map(MapExpr), Sketch(SketchExpr), Struct(StructExpr), @@ -59,7 +57,6 @@ impl FunctionExpr { Numeric(expr) => expr.get_evaluator(), Utf8(expr) => expr.get_evaluator(), Temporal(expr) => expr.get_evaluator(), - List(expr) => expr.get_evaluator(), Map(expr) => expr.get_evaluator(), Sketch(expr) => expr.get_evaluator(), Struct(expr) => expr.get_evaluator(), diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 3aacabb6f1..f9ebbce6c3 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -1,3 +1,5 @@ +#![allow(non_snake_case)] + use std::{ collections::{hash_map::DefaultHasher, HashMap}, hash::{Hash, Hasher}, @@ -458,28 +460,19 @@ impl PyExpr { Ok(self.expr.clone().agg_concat().into()) } - pub fn explode(&self) -> PyResult { - use functions::list::explode; - Ok(explode(self.into()).into()) - } - pub fn __abs__(&self) -> PyResult { use functions::numeric::abs; Ok(abs(self.into()).into()) } - pub fn __add__(&self, other: &Self) -> PyResult { Ok(crate::binary_op(crate::Operator::Plus, self.into(), other.expr.clone()).into()) } - pub fn __sub__(&self, other: &Self) -> PyResult { Ok(crate::binary_op(crate::Operator::Minus, self.into(), other.expr.clone()).into()) } - pub fn __mul__(&self, other: &Self) -> PyResult { Ok(crate::binary_op(crate::Operator::Multiply, self.into(), other.expr.clone()).into()) } - pub fn __floordiv__(&self, other: &Self) -> PyResult { Ok(crate::binary_op( crate::Operator::FloorDivide, @@ -788,51 +781,6 @@ impl PyExpr { Ok(normalize(self.into(), opts).into()) } - pub fn list_join(&self, delimiter: &Self) -> PyResult { - use crate::functions::list::join; - Ok(join(self.into(), delimiter.into()).into()) - } - - pub fn list_count(&self, mode: CountMode) -> PyResult { - use crate::functions::list::count; - Ok(count(self.into(), mode).into()) - } - - pub fn list_get(&self, idx: &Self, default: &Self) -> PyResult { - use crate::functions::list::get; - Ok(get(self.into(), idx.into(), default.into()).into()) - } - - pub fn list_sum(&self) -> PyResult { - use crate::functions::list::sum; - Ok(sum(self.into()).into()) - } - - pub fn list_mean(&self) -> PyResult { - use crate::functions::list::mean; - Ok(mean(self.into()).into()) - } - - pub fn list_min(&self) -> PyResult { - use crate::functions::list::min; - Ok(min(self.into()).into()) - } - - pub fn list_max(&self) -> PyResult { - use crate::functions::list::max; - Ok(max(self.into()).into()) - } - - pub fn list_slice(&self, start: &Self, end: &Self) -> PyResult { - 
use crate::functions::list::slice; - Ok(slice(self.into(), start.into(), end.into()).into()) - } - - pub fn list_chunk(&self, size: usize) -> PyResult { - use crate::functions::list::chunk; - Ok(chunk(self.into(), size).into()) - } - pub fn struct_get(&self, name: &str) -> PyResult { use crate::functions::struct_::get; Ok(get(self.into(), name).into()) diff --git a/src/daft-functions/src/lib.rs b/src/daft-functions/src/lib.rs index d55d17a5f8..f35edde44a 100644 --- a/src/daft-functions/src/lib.rs +++ b/src/daft-functions/src/lib.rs @@ -4,6 +4,7 @@ pub mod distance; pub mod float; pub mod hash; pub mod image; +pub mod list; pub mod list_sort; pub mod minhash; pub mod numeric; @@ -50,6 +51,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction_bound!(uri::python::url_upload, parent)?)?; image::register_modules(parent)?; float::register_modules(parent)?; + list::register_modules(parent)?; Ok(()) } diff --git a/src/daft-functions/src/list/chunk.rs b/src/daft-functions/src/list/chunk.rs new file mode 100644 index 0000000000..39743e80b9 --- /dev/null +++ b/src/daft-functions/src/list/chunk.rs @@ -0,0 +1,69 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ListChunk { + pub size: usize, +} + +#[typetag::serde] +impl ScalarUDF for ListChunk { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "chunk" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [input] => { + let input_field = input.to_field(schema)?; + Ok(input_field + .to_exploded_field()? + .to_fixed_size_list_field(self.size)? + .to_list_field()?) 
+ } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [input] => input.list_chunk(self.size), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +pub fn list_chunk(expr: ExprRef, size: usize) -> ExprRef { + ScalarFunction::new(ListChunk { size }, vec![expr]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_chunk")] +pub fn py_list_chunk(expr: PyExpr, size: usize) -> PyResult { + Ok(list_chunk(expr.into(), size).into()) +} diff --git a/src/daft-functions/src/list/count.rs b/src/daft-functions/src/list/count.rs new file mode 100644 index 0000000000..d00600163c --- /dev/null +++ b/src/daft-functions/src/list/count.rs @@ -0,0 +1,75 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{CountMode, DataType, Field, Schema}, + series::{IntoSeries, Series}, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ListCount { + mode: CountMode, +} + +#[typetag::serde] +impl ScalarUDF for ListCount { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "count" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [input] => { + let input_field = input.to_field(schema)?; + + match input_field.dtype { + DataType::List(_) | DataType::FixedSizeList(_, _) => { + Ok(Field::new(input.name(), DataType::UInt64)) + } + _ => Err(DaftError::TypeError(format!( + "Expected input to be a list type, received: {}", + input_field.dtype + ))), + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [input] => Ok(input.list_count(self.mode)?.into_series()), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} + +pub fn list_count(expr: ExprRef, mode: CountMode) -> ExprRef { + ScalarFunction::new(ListCount { mode }, vec![expr]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_count")] +pub fn py_list_count(expr: PyExpr, mode: CountMode) -> PyResult { + Ok(list_count(expr.into(), mode).into()) +} diff --git a/src/daft-functions/src/list/explode.rs b/src/daft-functions/src/list/explode.rs new file mode 100644 index 0000000000..a2232b33f9 --- /dev/null +++ b/src/daft-functions/src/list/explode.rs @@ -0,0 +1,64 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Explode {} + +#[typetag::serde] +impl ScalarUDF for Explode { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "explode" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [input] => { + let field = input.to_field(schema)?; + 
field.to_exploded_field() + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [input] => input.explode(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} + +pub fn explode(expr: ExprRef) -> ExprRef { + ScalarFunction::new(Explode {}, vec![expr]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "explode")] +pub fn py_explode(expr: PyExpr) -> PyResult { + Ok(explode(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/list/get.rs b/src/daft-functions/src/list/get.rs similarity index 52% rename from src/daft-dsl/src/functions/list/get.rs rename to src/daft-functions/src/list/get.rs index 3325603e13..15f088ce0c 100644 --- a/src/daft-dsl/src/functions/list/get.rs +++ b/src/daft-functions/src/list/get.rs @@ -1,17 +1,28 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ListGet {} -pub(super) struct GetEvaluator {} +#[typetag::serde] +impl ScalarUDF for ListGet { + fn as_any(&self) -> &dyn std::any::Any { + self + } -impl FunctionEvaluator for GetEvaluator { - fn fn_name(&self) -> &'static str { - "get" + fn name(&self) -> &'static str { + "list_get" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [input, idx, default] => { let input_field = input.to_field(schema)?; @@ -37,7 +48,7 @@ impl FunctionEvaluator for GetEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [input, idx, default] => Ok(input.list_get(idx, default)?), _ => Err(DaftError::ValueError(format!( @@ -47,3 +58,20 @@ impl FunctionEvaluator for GetEvaluator { } } } + +pub fn list_get(expr: ExprRef, idx: ExprRef, default_value: ExprRef) -> ExprRef { + ScalarFunction::new(ListGet {}, vec![expr, idx, default_value]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_get")] +pub fn py_list_get(expr: PyExpr, idx: PyExpr, default_value: PyExpr) -> PyResult { + Ok(list_get(expr.into(), idx.into(), default_value.into()).into()) +} diff --git a/src/daft-dsl/src/functions/list/join.rs b/src/daft-functions/src/list/join.rs similarity index 64% rename from src/daft-dsl/src/functions/list/join.rs rename to src/daft-functions/src/list/join.rs index af51ef886f..83d2f87efb 100644 --- a/src/daft-dsl/src/functions/list/join.rs +++ b/src/daft-functions/src/list/join.rs @@ -1,17 +1,28 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::{IntoSeries, Series}, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use 
crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ListJoin {} -pub(super) struct JoinEvaluator {} +#[typetag::serde] +impl ScalarUDF for ListJoin { + fn as_any(&self) -> &dyn std::any::Any { + self + } -impl FunctionEvaluator for JoinEvaluator { - fn fn_name(&self) -> &'static str { - "join" + fn name(&self) -> &'static str { + "list_join" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [input, delimiter] => { let input_field = input.to_field(schema)?; @@ -45,7 +56,7 @@ impl FunctionEvaluator for JoinEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [input, delimiter] => { let delimiter = delimiter.utf8().unwrap(); @@ -58,3 +69,20 @@ impl FunctionEvaluator for JoinEvaluator { } } } + +pub fn list_join(expr: ExprRef, delim: ExprRef) -> ExprRef { + ScalarFunction::new(ListJoin {}, vec![expr, delim]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_join")] +pub fn py_list_join(expr: PyExpr, delim: PyExpr) -> PyResult { + Ok(list_join(expr.into(), delim.into()).into()) +} diff --git a/src/daft-functions/src/list/max.rs b/src/daft-functions/src/list/max.rs new file mode 100644 index 0000000000..22621eb7f9 --- /dev/null +++ b/src/daft-functions/src/list/max.rs @@ -0,0 +1,72 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ListMax {} + +#[typetag::serde] +impl ScalarUDF for ListMax { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "list_max" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [input] => { + let field = input.to_field(schema)?.to_exploded_field()?; + + if field.dtype.is_numeric() { + Ok(field) + } else { + Err(DaftError::TypeError(format!( + "Expected input to be numeric, got {}", + field.dtype + ))) + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [input] => Ok(input.list_max()?), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} + +pub fn list_max(expr: ExprRef) -> ExprRef { + ScalarFunction::new(ListMax {}, vec![expr]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_max")] +pub fn py_list_max(expr: PyExpr) -> PyResult { + Ok(list_max(expr.into()).into()) +} diff --git a/src/daft-functions/src/list/mean.rs b/src/daft-functions/src/list/mean.rs new file mode 100644 index 0000000000..16a817a9c3 --- /dev/null +++ b/src/daft-functions/src/list/mean.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + datatypes::try_mean_supertype, + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, 
+ ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ListMean {} + +#[typetag::serde] +impl ScalarUDF for ListMean { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "list_mean" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [input] => { + let inner_field = input.to_field(schema)?.to_exploded_field()?; + Ok(Field::new( + inner_field.name.as_str(), + try_mean_supertype(&inner_field.dtype)?, + )) + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [input] => Ok(input.list_mean()?), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} + +pub fn list_mean(expr: ExprRef) -> ExprRef { + ScalarFunction::new(ListMean {}, vec![expr]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_mean")] +pub fn py_list_mean(expr: PyExpr) -> PyResult { + Ok(list_mean(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/list/min.rs b/src/daft-functions/src/list/min.rs similarity index 50% rename from src/daft-dsl/src/functions/list/min.rs rename to src/daft-functions/src/list/min.rs index 14a073ab6b..8386b38410 100644 --- a/src/daft-dsl/src/functions/list/min.rs +++ b/src/daft-functions/src/list/min.rs @@ -1,17 +1,25 @@ use common_error::{DaftError, DaftResult}; use daft_core::prelude::*; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ListMin {} -pub(super) struct MinEvaluator {} +#[typetag::serde] +impl ScalarUDF for ListMin { + fn as_any(&self) -> &dyn std::any::Any { + self + } -impl FunctionEvaluator for MinEvaluator { - fn fn_name(&self) -> &'static str { - "min" + fn name(&self) -> &'static str { + "list_min" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [input] => { let field = input.to_field(schema)?.to_exploded_field()?; @@ -32,7 +40,7 @@ impl FunctionEvaluator for MinEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [input] => Ok(input.list_min()?), _ => Err(DaftError::ValueError(format!( @@ -42,3 +50,20 @@ impl FunctionEvaluator for MinEvaluator { } } } + +pub fn list_min(expr: ExprRef) -> ExprRef { + ScalarFunction::new(ListMin {}, vec![expr]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_min")] +pub fn py_list_min(expr: PyExpr) -> PyResult { + Ok(list_min(expr.into()).into()) +} diff --git a/src/daft-functions/src/list/mod.rs b/src/daft-functions/src/list/mod.rs new file mode 100644 index 0000000000..fd8007b4b5 --- /dev/null +++ b/src/daft-functions/src/list/mod.rs @@ -0,0 +1,40 @@ +mod chunk; +mod count; +mod explode; +mod get; +mod join; +mod max; +mod mean; +mod min; +mod slice; +mod sum; +pub 
use chunk::list_chunk as chunk; +pub use count::list_count as count; +pub use explode::explode; +pub use get::list_get as get; +pub use join::list_join as join; +pub use max::list_max as max; +pub use mean::list_mean as mean; +pub use min::list_min as min; +#[cfg(feature = "python")] +use pyo3::prelude::*; +pub use slice::list_slice as slice; +pub use sum::list_sum as sum; + +#[cfg(feature = "python")] +pub fn register_modules(parent: &Bound) -> PyResult<()> { + parent.add_function(wrap_pyfunction_bound!(explode::py_explode, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(chunk::py_list_chunk, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(count::py_list_count, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(get::py_list_get, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(join::py_list_join, parent)?)?; + + parent.add_function(wrap_pyfunction_bound!(max::py_list_max, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(min::py_list_min, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(mean::py_list_mean, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(min::py_list_min, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(slice::py_list_slice, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(sum::py_list_sum, parent)?)?; + + Ok(()) +} diff --git a/src/daft-dsl/src/functions/list/slice.rs b/src/daft-functions/src/list/slice.rs similarity index 51% rename from src/daft-dsl/src/functions/list/slice.rs rename to src/daft-functions/src/list/slice.rs index e56dd88aa9..f62e47474d 100644 --- a/src/daft-dsl/src/functions/list/slice.rs +++ b/src/daft-functions/src/list/slice.rs @@ -1,17 +1,28 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ListSlice {} -pub(super) struct SliceEvaluator {} +#[typetag::serde] +impl ScalarUDF for ListSlice { + fn as_any(&self) -> &dyn std::any::Any { + self + } -impl FunctionEvaluator for SliceEvaluator { - fn fn_name(&self) -> &'static str { - "slice" + fn name(&self) -> &'static str { + "list_slice" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [input, start, end] => { let input_field = input.to_field(schema)?; @@ -40,13 +51,30 @@ impl FunctionEvaluator for SliceEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { - [input, start, end] => input.list_slice(start, end), + [input, start, end] => Ok(input.list_slice(start, end)?), _ => Err(DaftError::ValueError(format!( - "Expected 3 input args, got {}", + "Expected 1 input arg, got {}", inputs.len() ))), } } } + +pub fn list_slice(expr: ExprRef, start: ExprRef, end: ExprRef) -> ExprRef { + ScalarFunction::new(ListSlice {}, vec![expr, start, end]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_slice")] +pub fn py_list_slice(expr: PyExpr, start: PyExpr, end: PyExpr) -> PyResult { + Ok(list_slice(expr.into(), 
start.into(), end.into()).into()) +} diff --git a/src/daft-functions/src/list/sum.rs b/src/daft-functions/src/list/sum.rs new file mode 100644 index 0000000000..82883faf26 --- /dev/null +++ b/src/daft-functions/src/list/sum.rs @@ -0,0 +1,72 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ListSum {} + +#[typetag::serde] +impl ScalarUDF for ListSum { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "list_sum" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [input] => { + let field = input.to_field(schema)?.to_exploded_field()?; + + if field.dtype.is_numeric() { + Ok(field) + } else { + Err(DaftError::TypeError(format!( + "Expected input to be numeric, got {}", + field.dtype + ))) + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [input] => Ok(input.list_sum()?), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} + +pub fn list_sum(expr: ExprRef) -> ExprRef { + ScalarFunction::new(ListSum {}, vec![expr]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_sum")] +pub fn py_list_sum(expr: PyExpr) -> PyResult { + Ok(list_sum(expr.into()).into()) +} diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index 1c5d224d89..d2cd422dba 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -23,6 +23,7 @@ common-resource-request = {path = "../common/resource-request", default-features common-treenode = {path = "../common/treenode", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} +daft-functions = {path = "../daft-functions", default-features = false} daft-scan = {path = "../daft-scan", default-features = false} daft-schema = {path = "../daft-schema", default-features = false} daft-table = {path = "../daft-table", default-features = false} @@ -35,7 +36,6 @@ snafu = {workspace = true} [dev-dependencies] daft-dsl = {path = "../daft-dsl", features = ["test-utils"]} -daft-functions = {path = "../daft-functions", default-features = false} pretty_assertions = {workspace = true} rstest = {workspace = true} test-log = {workspace = true} diff --git a/src/daft-plan/src/logical_ops/explode.rs b/src/daft-plan/src/logical_ops/explode.rs index f779f58871..66e5ad0a5c 100644 --- a/src/daft-plan/src/logical_ops/explode.rs +++ b/src/daft-plan/src/logical_ops/explode.rs @@ -32,7 +32,7 @@ impl Explode { let explode_exprs = to_explode .iter() .cloned() - .map(daft_dsl::functions::list::explode) + .map(daft_functions::list::explode) .collect::>(); let exploded_schema = { let explode_schema = { diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index b0f95d971c..ff92b38b85 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -702,7 +702,7 @@ impl SQLPlanner { Subscript::Index { index } => { let index = self.plan_expr(index)?; let expr = self.plan_expr(expr)?; - 
Ok(daft_dsl::functions::list::get(expr, index, null_lit())) + Ok(daft_functions::list::get(expr, index, null_lit())) } Subscript::Slice { lower_bound, @@ -717,7 +717,7 @@ impl SQLPlanner { let lower = self.plan_expr(lower)?; let upper = self.plan_expr(upper)?; let expr = self.plan_expr(expr)?; - Ok(daft_dsl::functions::list::slice(expr, lower, upper)) + Ok(daft_functions::list::slice(expr, lower, upper)) } _ => { unsupported_sql_err!("slice with only one bound not yet supported"); diff --git a/src/daft-table/src/ops/explode.rs b/src/daft-table/src/ops/explode.rs index 85c694c3c2..2c0fc0fee3 100644 --- a/src/daft-table/src/ops/explode.rs +++ b/src/daft-table/src/ops/explode.rs @@ -27,31 +27,29 @@ impl Table { ))); } - use daft_dsl::functions::{list::ListExpr, FunctionExpr}; - let mut evaluated_columns = Vec::with_capacity(exprs.len()); for expr in exprs { match expr.as_ref() { - Expr::Function { - func: FunctionExpr::List(ListExpr::Explode), - inputs, - } => { - if inputs.len() != 1 { - return Err(DaftError::ValueError(format!("ListExpr::Explode function expression must have one input only, received: {}", inputs.len()))); - } - let expr = inputs.first().unwrap(); - let exploded_name = expr.name(); - let evaluated = self.eval_expression(expr)?; - if !matches!( - evaluated.data_type(), - DataType::List(..) | DataType::FixedSizeList(..) - ) { - return Err(DaftError::ValueError(format!( + Expr::ScalarFunction(func) => { + if func.name() == "explode" { + let inputs = &func.inputs; + if inputs.len() != 1 { + return Err(DaftError::ValueError(format!("ListExpr::Explode function expression must have one input only, received: {}", inputs.len()))); + } + let expr = inputs.first().unwrap(); + let exploded_name = expr.name(); + let evaluated = self.eval_expression(expr)?; + if !matches!( + evaluated.data_type(), + DataType::List(..) | DataType::FixedSizeList(..) + ) { + return Err(DaftError::ValueError(format!( "Expected Expression for series: `{exploded_name}` to be a List Type, but is {}", evaluated.data_type() ))); + } + evaluated_columns.push(evaluated); } - evaluated_columns.push(evaluated); } _ => { return Err(DaftError::ValueError( From c481f1b0ad5c842eb5fcb277f79913f6582957f5 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 20 Sep 2024 12:41:57 -0700 Subject: [PATCH 04/35] [BUG] Fix concat expression typing (#2868) https://github.com/Eventual-Inc/Daft/issues/2863 Co-authored-by: Colin Ho --- daft/expressions/expressions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 4ecc99108e..986bad70fd 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1980,7 +1980,7 @@ def split(self, pattern: str | Expression, regex: bool = False) -> Expression: pattern_expr = Expression._to_expression(pattern) return Expression._from_pyexpr(self._expr.utf8_split(pattern_expr._expr, regex)) - def concat(self, other: str) -> Expression: + def concat(self, other: str | Expression) -> Expression: """Concatenates two string expressions together .. NOTE:: @@ -2012,7 +2012,8 @@ def concat(self, other: str) -> Expression: Expression: a String expression which is `self` concatenated with `other` """ # Delegate to + operator implementation. 
- return Expression._from_pyexpr(self._expr) + other + other_expr = Expression._to_expression(other) + return Expression._from_pyexpr(self._expr) + other_expr def extract(self, pattern: str | Expression, index: int = 0) -> Expression: r"""Extracts the specified match group from the first regex match in each string in a string column. From 688150fb346f091f77454f29502a5f12da2b7462 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Fri, 20 Sep 2024 14:42:13 -0500 Subject: [PATCH 05/35] [FEAT]: `shuffle_join_default_partitions` param (#2844) addresses https://github.com/Eventual-Inc/Daft/issues/2817 --- daft/context.py | 3 + daft/daft/__init__.pyi | 3 + src/common/daft-config/src/lib.rs | 2 + src/common/daft-config/src/python.rs | 12 ++ .../src/physical_planner/translate.rs | 119 +++++++++++++----- tests/dataframe/test_joins.py | 95 +++++++++++++- 6 files changed, 201 insertions(+), 33 deletions(-) diff --git a/daft/context.py b/daft/context.py index 10be0264eb..61b69284af 100644 --- a/daft/context.py +++ b/daft/context.py @@ -308,6 +308,7 @@ def set_execution_config( csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, + shuffle_join_default_partitions: int | None = None, read_sql_partition_size_bytes: int | None = None, enable_aqe: bool | None = None, enable_native_executor: bool | None = None, @@ -344,6 +345,7 @@ def set_execution_config( csv_target_filesize: Target File Size when writing out CSV Files. Defaults to 512MB csv_inflation_factor: Inflation Factor of CSV files (In-Memory-Size / File-Size) ratio. Defaults to 0.5 shuffle_aggregation_default_partitions: Minimum number of partitions to create when performing aggregations. Defaults to 200, unless the number of input partitions is less than 200. + shuffle_join_default_partitions: Minimum number of partitions to create when performing joins. Defaults to 16, unless the number of input partitions is greater than 16. read_sql_partition_size_bytes: Target size of partition when reading from SQL databases. Defaults to 512MB enable_aqe: Enables Adaptive Query Execution, Defaults to False enable_native_executor: Enables new local executor. Defaults to False @@ -369,6 +371,7 @@ def set_execution_config( csv_target_filesize=csv_target_filesize, csv_inflation_factor=csv_inflation_factor, shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions, + shuffle_join_default_partitions=shuffle_join_default_partitions, read_sql_partition_size_bytes=read_sql_partition_size_bytes, enable_aqe=enable_aqe, enable_native_executor=enable_native_executor, diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index b08a0633e4..83f1eff059 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1753,6 +1753,7 @@ class PyDaftExecutionConfig: csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, + shuffle_join_default_partitions: int | None = None, read_sql_partition_size_bytes: int | None = None, enable_aqe: bool | None = None, enable_native_executor: bool | None = None, @@ -1785,6 +1786,8 @@ class PyDaftExecutionConfig: @property def shuffle_aggregation_default_partitions(self) -> int: ... @property + def shuffle_join_default_partitions(self) -> int: ... + @property def read_sql_partition_size_bytes(self) -> int: ... @property def enable_aqe(self) -> bool: ... 
diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index dcaef0a2f8..153d2a80c5 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -52,6 +52,7 @@ pub struct DaftExecutionConfig { pub csv_target_filesize: usize, pub csv_inflation_factor: f64, pub shuffle_aggregation_default_partitions: usize, + pub shuffle_join_default_partitions: usize, pub read_sql_partition_size_bytes: usize, pub enable_aqe: bool, pub enable_native_executor: bool, @@ -75,6 +76,7 @@ impl Default for DaftExecutionConfig { csv_target_filesize: 512 * 1024 * 1024, // 512MB csv_inflation_factor: 0.5, shuffle_aggregation_default_partitions: 200, + shuffle_join_default_partitions: 16, read_sql_partition_size_bytes: 512 * 1024 * 1024, // 512MB enable_aqe: false, enable_native_executor: false, diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 5dda71eda8..818934261a 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -94,6 +94,7 @@ impl PyDaftExecutionConfig { csv_target_filesize: Option, csv_inflation_factor: Option, shuffle_aggregation_default_partitions: Option, + shuffle_join_default_partitions: Option, read_sql_partition_size_bytes: Option, enable_aqe: Option, enable_native_executor: Option, @@ -143,10 +144,16 @@ impl PyDaftExecutionConfig { if let Some(csv_inflation_factor) = csv_inflation_factor { config.csv_inflation_factor = csv_inflation_factor; } + if let Some(shuffle_aggregation_default_partitions) = shuffle_aggregation_default_partitions { config.shuffle_aggregation_default_partitions = shuffle_aggregation_default_partitions; } + + if let Some(shuffle_join_default_partitions) = shuffle_join_default_partitions { + config.shuffle_join_default_partitions = shuffle_join_default_partitions; + } + if let Some(read_sql_partition_size_bytes) = read_sql_partition_size_bytes { config.read_sql_partition_size_bytes = read_sql_partition_size_bytes; } @@ -231,6 +238,11 @@ impl PyDaftExecutionConfig { Ok(self.config.shuffle_aggregation_default_partitions) } + #[getter] + fn get_shuffle_join_default_partitions(&self) -> PyResult { + Ok(self.config.shuffle_join_default_partitions) + } + #[getter] fn get_read_sql_partition_size_bytes(&self) -> PyResult { Ok(self.config.read_sql_partition_size_bytes) diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 639c571871..408d4f62a6 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -571,6 +571,7 @@ pub(super) fn translate_single_logical_node( "Sort-merge join currently only supports inner joins".to_string(), )); } + let num_partitions = max(num_partitions, cfg.shuffle_join_default_partitions); let needs_presort = if cfg.sort_merge_join_sort_with_aligned_boundaries { // Use the special-purpose presorting that ensures join inputs are sorted with aligned @@ -616,7 +617,6 @@ pub(super) fn translate_single_logical_node( // allow for leniency in partition size to avoid minor repartitions let num_left_partitions = left_clustering_spec.num_partitions(); let num_right_partitions = right_clustering_spec.num_partitions(); - let num_partitions = match ( is_left_hash_partitioned, is_right_hash_partitioned, @@ -637,6 +637,7 @@ pub(super) fn translate_single_logical_node( } (_, _, a, b) => max(a, b), }; + let num_partitions = max(num_partitions, cfg.shuffle_join_default_partitions); if num_left_partitions != num_partitions 
|| (num_partitions > 1 && !is_left_hash_partitioned) @@ -1076,6 +1077,13 @@ mod tests { Self::Reversed(v) => Self::Reversed(v * x), } } + fn unwrap(&self) -> usize { + match self { + Self::Good(v) => *v, + Self::Bad(v) => *v, + Self::Reversed(v) => *v, + } + } } fn force_repartition( @@ -1128,21 +1136,31 @@ mod tests { fn check_physical_matches( plan: PhysicalPlanRef, + left_partition_size: usize, + right_partition_size: usize, left_repartitions: bool, right_repartitions: bool, + shuffle_join_default_partitions: usize, ) -> bool { match plan.as_ref() { PhysicalPlan::HashJoin(HashJoin { left, right, .. }) => { - let left_works = match (left.as_ref(), left_repartitions) { + let left_works = match ( + left.as_ref(), + left_repartitions || left_partition_size < shuffle_join_default_partitions, + ) { (PhysicalPlan::ReduceMerge(_), true) => true, (PhysicalPlan::Project(_), false) => true, _ => false, }; - let right_works = match (right.as_ref(), right_repartitions) { + let right_works = match ( + right.as_ref(), + right_repartitions || right_partition_size < shuffle_join_default_partitions, + ) { (PhysicalPlan::ReduceMerge(_), true) => true, (PhysicalPlan::Project(_), false) => true, _ => false, }; + left_works && right_works } _ => false, @@ -1152,7 +1170,7 @@ mod tests { /// Tests a variety of settings regarding hash join repartitioning. #[test] fn repartition_hash_join_tests() -> DaftResult<()> { - use RepartitionOptions::*; + use RepartitionOptions::{Bad, Good, Reversed}; let cases = vec![ (Good(30), Good(30), false, false), (Good(30), Good(40), true, false), @@ -1170,9 +1188,17 @@ mod tests { let cfg: Arc = DaftExecutionConfig::default().into(); for (l_opts, r_opts, l_exp, r_exp) in cases { for mult in [1, 10] { - let plan = - get_hash_join_plan(cfg.clone(), l_opts.scale_by(mult), r_opts.scale_by(mult))?; - if !check_physical_matches(plan, l_exp, r_exp) { + let l_opts = l_opts.scale_by(mult); + let r_opts = r_opts.scale_by(mult); + let plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?; + if !check_physical_matches( + plan, + l_opts.unwrap(), + r_opts.unwrap(), + l_exp, + r_exp, + cfg.shuffle_join_default_partitions, + ) { panic!( "Failed hash join test on case ({:?}, {:?}, {}, {}) with mult {}", l_opts, r_opts, l_exp, r_exp, mult @@ -1180,9 +1206,15 @@ mod tests { } // reversed direction - let plan = - get_hash_join_plan(cfg.clone(), r_opts.scale_by(mult), l_opts.scale_by(mult))?; - if !check_physical_matches(plan, r_exp, l_exp) { + let plan = get_hash_join_plan(cfg.clone(), r_opts.clone(), l_opts.clone())?; + if !check_physical_matches( + plan, + l_opts.unwrap(), + r_opts.unwrap(), + r_exp, + l_exp, + cfg.shuffle_join_default_partitions, + ) { panic!( "Failed hash join test on case ({:?}, {:?}, {}, {}) with mult {}", r_opts, l_opts, r_exp, l_exp, mult @@ -1199,27 +1231,38 @@ mod tests { let mut cfg = DaftExecutionConfig::default(); cfg.hash_join_partition_size_leniency = 0.8; let cfg = Arc::new(cfg); + let (l_opts, r_opts) = (RepartitionOptions::Good(30), RepartitionOptions::Bad(40)); + let physical_plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?; + assert!(check_physical_matches( + physical_plan, + l_opts.unwrap(), + r_opts.unwrap(), + true, + true, + cfg.shuffle_join_default_partitions + )); - let physical_plan = get_hash_join_plan( - cfg.clone(), - RepartitionOptions::Good(20), - RepartitionOptions::Bad(40), - )?; - assert!(check_physical_matches(physical_plan, true, true)); - - let physical_plan = get_hash_join_plan( - cfg.clone(), - 
RepartitionOptions::Good(20), - RepartitionOptions::Bad(25), - )?; - assert!(check_physical_matches(physical_plan, false, true)); + let (l_opts, r_opts) = (RepartitionOptions::Good(20), RepartitionOptions::Bad(25)); + let physical_plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?; + assert!(check_physical_matches( + physical_plan, + l_opts.unwrap(), + r_opts.unwrap(), + false, + true, + cfg.shuffle_join_default_partitions + )); - let physical_plan = get_hash_join_plan( - cfg.clone(), - RepartitionOptions::Good(20), - RepartitionOptions::Bad(26), - )?; - assert!(check_physical_matches(physical_plan, true, true)); + let (l_opts, r_opts) = (RepartitionOptions::Good(20), RepartitionOptions::Bad(26)); + let physical_plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?; + assert!(check_physical_matches( + physical_plan, + l_opts.unwrap(), + r_opts.unwrap(), + true, + true, + cfg.shuffle_join_default_partitions + )); Ok(()) } @@ -1237,7 +1280,14 @@ mod tests { let cfg: Arc = DaftExecutionConfig::default().into(); for (l_opts, r_opts, l_exp, r_exp) in cases { let plan = get_hash_join_plan(cfg.clone(), l_opts, r_opts)?; - if !check_physical_matches(plan, l_exp, r_exp) { + if !check_physical_matches( + plan, + l_opts.unwrap(), + r_opts.unwrap(), + l_exp, + r_exp, + cfg.shuffle_join_default_partitions, + ) { panic!( "Failed single partition hash join test on case ({:?}, {:?}, {}, {})", l_opts, r_opts, l_exp, r_exp @@ -1246,7 +1296,14 @@ mod tests { // reversed direction let plan = get_hash_join_plan(cfg.clone(), r_opts, l_opts)?; - if !check_physical_matches(plan, r_exp, l_exp) { + if !check_physical_matches( + plan, + l_opts.unwrap(), + r_opts.unwrap(), + r_exp, + l_exp, + cfg.shuffle_join_default_partitions, + ) { panic!( "Failed single partition hash join test on case ({:?}, {:?}, {}, {})", r_opts, l_opts, r_exp, l_exp diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index b0bdbf9df4..8ccc3f72cd 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -3,14 +3,16 @@ import pyarrow as pa import pytest -from daft import col, context +import daft +from daft import col +from daft.context import get_context from daft.datatype import DataType from daft.errors import ExpressionTypeError from tests.utils import sort_arrow_table def skip_invalid_join_strategies(join_strategy, join_type): - if context.get_context().daft_execution_config.enable_native_executor is True: + if get_context().daft_execution_config.enable_native_executor is True: if join_type == "outer" or join_strategy not in [None, "hash"]: pytest.skip("Native executor fails for these tests") else: @@ -1075,3 +1077,92 @@ def test_join_same_name_alias_with_compute(join_strategy, join_type, expected, m assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "a") == sort_arrow_table( pa.Table.from_pydict(expected), "a" ) + + +# the partition size should be the max(shuffle_join_default_partitions, max(left_partition_size, right_partition_size)) +@pytest.mark.parametrize("shuffle_join_default_partitions", [None, 20]) +def test_join_result_partitions_smaller_than_input(shuffle_join_default_partitions): + skip_invalid_join_strategies("hash", "inner") + if shuffle_join_default_partitions is None: + min_partitions = get_context().daft_execution_config.shuffle_join_default_partitions + else: + min_partitions = shuffle_join_default_partitions + + with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions): + 
right_partition_size = 50 + for left_partition_size in [1, min_partitions, min_partitions + 1]: + df_left = daft.from_pydict( + {"group": [i for i in range(min_partitions + 1)], "value": [i for i in range(min_partitions + 1)]} + ) + df_left = df_left.into_partitions(left_partition_size) + + df_right = daft.from_pydict( + {"group": [i for i in range(right_partition_size)], "value": [i for i in range(right_partition_size)]} + ) + + df_right = df_right.into_partitions(right_partition_size) + + actual = df_left.join(df_right, on="group", how="inner", strategy="hash").collect() + n_partitions = actual.num_partitions() + expected_n_partitions = max(min_partitions, left_partition_size, right_partition_size) + assert n_partitions == expected_n_partitions + + +def test_join_right_single_partition(): + skip_invalid_join_strategies("hash", "inner") + shuffle_join_default_partitions = 16 + df_left = daft.from_pydict({"group": [i for i in range(300)], "value": [i for i in range(300)]}).repartition( + 300, "group" + ) + + df_right = daft.from_pydict({"group": [i for i in range(100)], "value": [i for i in range(100)]}).repartition( + 1, "group" + ) + + with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions): + actual = df_left.join(df_right, on="group", how="inner", strategy="hash").collect() + n_partitions = actual.num_partitions() + assert n_partitions == 300 + + +def test_join_right_smaller_than_cfg(): + skip_invalid_join_strategies("hash", "inner") + shuffle_join_default_partitions = 200 + df_left = daft.from_pydict({"group": [i for i in range(199)], "value": [i for i in range(199)]}).repartition( + 199, "group" + ) + + df_right = daft.from_pydict({"group": [i for i in range(100)], "value": [i for i in range(100)]}).repartition( + 100, "group" + ) + + with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions): + actual = df_left.join(df_right, on="group", how="inner", strategy="hash").collect() + n_partitions = actual.num_partitions() + assert n_partitions == 200 + + +# for sort_merge, the result partitions should always be max(shuffle_join_default_partitions, max(left_partition_size, right_partition_size)) +@pytest.mark.parametrize("shuffle_join_default_partitions", [None, 20]) +def test_join_result_partitions_for_sortmerge(shuffle_join_default_partitions): + skip_invalid_join_strategies("sort_merge", "inner") + + if shuffle_join_default_partitions is None: + min_partitions = get_context().daft_execution_config.shuffle_join_default_partitions + else: + min_partitions = shuffle_join_default_partitions + + with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions): + for partition_size in [1, min_partitions, min_partitions + 1]: + df_left = daft.from_pydict( + {"group": [i for i in range(min_partitions + 1)], "value": [i for i in range(min_partitions + 1)]} + ) + df_left = df_left.into_partitions(partition_size) + + df_right = daft.from_pydict({"group": [i for i in range(50)], "value": [i for i in range(50)]}) + + df_right = df_right.into_partitions(50) + + actual = df_left.join(df_right, on="group", how="inner", strategy="sort_merge").collect() + + assert actual.num_partitions() == max(min_partitions, partition_size, 50) From 48a123afacdf826911f095ec7970ee6c997be86f Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Fri, 20 Sep 2024 14:59:33 -0500 Subject: [PATCH 06/35] [FEAT]: sql list operations (#2856) depends on https://github.com/Eventual-Inc/Daft/pull/2854 you can see the relevant 
diff [here](https://github.com/universalmind303/Daft/compare/sql-lists...universalmind303:Daft:sql-lists-2) --- src/daft-functions/src/lib.rs | 6 +- src/daft-functions/src/list/count.rs | 2 +- src/daft-functions/src/list/mod.rs | 24 +- .../src/{list_sort.rs => list/sort.rs} | 28 +- src/daft-sql/src/modules/list.rs | 256 +++++++++++++++++- tests/sql/test_list_exprs.py | 120 ++++++++ 6 files changed, 404 insertions(+), 32 deletions(-) rename src/daft-functions/src/{list_sort.rs => list/sort.rs} (76%) create mode 100644 tests/sql/test_list_exprs.py diff --git a/src/daft-functions/src/lib.rs b/src/daft-functions/src/lib.rs index f35edde44a..452c0e1cc5 100644 --- a/src/daft-functions/src/lib.rs +++ b/src/daft-functions/src/lib.rs @@ -5,7 +5,6 @@ pub mod float; pub mod hash; pub mod image; pub mod list; -pub mod list_sort; pub mod minhash; pub mod numeric; pub mod to_struct; @@ -29,10 +28,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent )?)?; parent.add_function(wrap_pyfunction_bound!(hash::python::hash, parent)?)?; - parent.add_function(wrap_pyfunction_bound!( - list_sort::python::list_sort, - parent - )?)?; + parent.add_function(wrap_pyfunction_bound!(minhash::python::minhash, parent)?)?; parent.add_function(wrap_pyfunction_bound!(numeric::cbrt::python::cbrt, parent)?)?; parent.add_function(wrap_pyfunction_bound!( diff --git a/src/daft-functions/src/list/count.rs b/src/daft-functions/src/list/count.rs index d00600163c..08e344e04a 100644 --- a/src/daft-functions/src/list/count.rs +++ b/src/daft-functions/src/list/count.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct ListCount { - mode: CountMode, + pub mode: CountMode, } #[typetag::serde] diff --git a/src/daft-functions/src/list/mod.rs b/src/daft-functions/src/list/mod.rs index fd8007b4b5..2ba3f197be 100644 --- a/src/daft-functions/src/list/mod.rs +++ b/src/daft-functions/src/list/mod.rs @@ -7,19 +7,22 @@ mod max; mod mean; mod min; mod slice; +mod sort; mod sum; -pub use chunk::list_chunk as chunk; -pub use count::list_count as count; -pub use explode::explode; -pub use get::list_get as get; -pub use join::list_join as join; -pub use max::list_max as max; -pub use mean::list_mean as mean; -pub use min::list_min as min; + +pub use chunk::{list_chunk as chunk, ListChunk}; +pub use count::{list_count as count, ListCount}; +pub use explode::{explode, Explode}; +pub use get::{list_get as get, ListGet}; +pub use join::{list_join as join, ListJoin}; +pub use max::{list_max as max, ListMax}; +pub use mean::{list_mean as mean, ListMean}; +pub use min::{list_min as min, ListMin}; #[cfg(feature = "python")] use pyo3::prelude::*; -pub use slice::list_slice as slice; -pub use sum::list_sum as sum; +pub use slice::{list_slice as slice, ListSlice}; +pub use sort::{list_sort as sort, ListSort}; +pub use sum::{list_sum as sum, ListSum}; #[cfg(feature = "python")] pub fn register_modules(parent: &Bound) -> PyResult<()> { @@ -35,6 +38,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction_bound!(min::py_list_min, parent)?)?; parent.add_function(wrap_pyfunction_bound!(slice::py_list_slice, parent)?)?; parent.add_function(wrap_pyfunction_bound!(sum::py_list_sum, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(sort::py_list_sort, parent)?)?; Ok(()) } diff --git a/src/daft-functions/src/list_sort.rs b/src/daft-functions/src/list/sort.rs similarity index 76% rename from src/daft-functions/src/list_sort.rs 
rename to src/daft-functions/src/list/sort.rs index 34037f5a43..3d75e3fa48 100644 --- a/src/daft-functions/src/list_sort.rs +++ b/src/daft-functions/src/list/sort.rs @@ -2,15 +2,15 @@ use common_error::{DaftError, DaftResult}; use daft_core::prelude::*; use daft_dsl::{ functions::{ScalarFunction, ScalarUDF}, - ExprRef, + lit, ExprRef, }; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -struct ListSortFunction {} +pub struct ListSort {} #[typetag::serde] -impl ScalarUDF for ListSortFunction { +impl ScalarUDF for ListSort { fn as_any(&self) -> &dyn std::any::Any { self } @@ -51,18 +51,20 @@ impl ScalarUDF for ListSortFunction { } } -pub fn list_sort(input: ExprRef, desc: ExprRef) -> ExprRef { - ScalarFunction::new(ListSortFunction {}, vec![input, desc]).into() +pub fn list_sort(input: ExprRef, desc: Option) -> ExprRef { + let desc = desc.unwrap_or_else(|| lit(false)); + ScalarFunction::new(ListSort {}, vec![input, desc]).into() } #[cfg(feature = "python")] -pub mod python { - use daft_dsl::python::PyExpr; - use pyo3::{pyfunction, PyResult}; +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; - #[pyfunction] - pub fn list_sort(expr: PyExpr, desc: PyExpr) -> PyResult { - let expr = super::list_sort(expr.into(), desc.into()); - Ok(expr.into()) - } +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_sort")] +pub fn py_list_sort(expr: PyExpr, desc: PyExpr) -> PyResult { + Ok(list_sort(expr.into(), Some(desc.into())).into()) } diff --git a/src/daft-sql/src/modules/list.rs b/src/daft-sql/src/modules/list.rs index 6f16d6693c..b9e52d9748 100644 --- a/src/daft-sql/src/modules/list.rs +++ b/src/daft-sql/src/modules/list.rs @@ -1,11 +1,261 @@ +use daft_core::prelude::CountMode; +use daft_dsl::{lit, Expr, LiteralValue}; + use super::SQLModule; -use crate::functions::SQLFunctions; +use crate::{ + error::PlannerError, + functions::{SQLFunction, SQLFunctions}, + unsupported_sql_err, +}; pub struct SQLModuleList; impl SQLModule for SQLModuleList { - fn register(_parent: &mut SQLFunctions) { - // use FunctionExpr::List as f; + fn register(parent: &mut SQLFunctions) { + parent.add_fn("list_chunk", SQLListChunk); + parent.add_fn("list_count", SQLListCount); + parent.add_fn("explode", SQLExplode); + parent.add_fn("unnest", SQLExplode); + // this is commonly called `array_to_string` in other SQL dialects + parent.add_fn("array_to_string", SQLListJoin); + // but we also want to support our `list_join` alias as well + parent.add_fn("list_join", SQLListJoin); + parent.add_fn("list_max", SQLListMax); + parent.add_fn("list_min", SQLListMin); + parent.add_fn("list_sum", SQLListSum); + parent.add_fn("list_mean", SQLListMean); + parent.add_fn("list_slice", SQLListSlice); + parent.add_fn("list_sort", SQLListSort); + // TODO } } + +pub struct SQLListChunk; + +impl SQLFunction for SQLListChunk { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input, chunk_size] => { + let input = planner.plan_function_arg(input)?; + let chunk_size = planner + .plan_function_arg(chunk_size) + .and_then(|arg| match arg.as_ref() { + Expr::Literal(LiteralValue::Int64(n)) => Ok(*n as usize), + _ => unsupported_sql_err!("Expected chunk size to be a number"), + })?; + Ok(daft_functions::list::chunk(input, chunk_size)) + } + _ => unsupported_sql_err!( + "invalid arguments for list_chunk. 
Expected list_chunk(expr, chunk_size)" + ), + } + } +} + +pub struct SQLListCount; + +impl SQLFunction for SQLListCount { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(daft_functions::list::count(input, CountMode::Valid)) + } + [input, count_mode] => { + let input = planner.plan_function_arg(input)?; + let mode = + planner + .plan_function_arg(count_mode) + .and_then(|arg| match arg.as_ref() { + Expr::Literal(LiteralValue::Utf8(s)) => { + s.parse().map_err(PlannerError::from) + } + _ => unsupported_sql_err!("Expected mode to be a string"), + })?; + Ok(daft_functions::list::count(input, mode)) + } + _ => unsupported_sql_err!("invalid arguments for list_count. Expected either list_count(expr) or list_count(expr, mode)"), + } + } +} + +pub struct SQLExplode; + +impl SQLFunction for SQLExplode { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(daft_functions::list::explode(input)) + } + _ => unsupported_sql_err!("Expected 1 argument"), + } + } +} + +pub struct SQLListJoin; + +impl SQLFunction for SQLListJoin { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input, separator] => { + let input = planner.plan_function_arg(input)?; + let separator = planner.plan_function_arg(separator)?; + Ok(daft_functions::list::join(input, separator)) + } + _ => unsupported_sql_err!( + "invalid arguments for list_join. Expected list_join(expr, separator)" + ), + } + } +} + +pub struct SQLListMax; + +impl SQLFunction for SQLListMax { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(daft_functions::list::max(input)) + } + _ => unsupported_sql_err!("invalid arguments for list_max. Expected list_max(expr)"), + } + } +} + +pub struct SQLListMean; + +impl SQLFunction for SQLListMean { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(daft_functions::list::mean(input)) + } + _ => unsupported_sql_err!("invalid arguments for list_mean. Expected list_mean(expr)"), + } + } +} + +pub struct SQLListMin; + +impl SQLFunction for SQLListMin { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(daft_functions::list::min(input)) + } + _ => unsupported_sql_err!("invalid arguments for list_min. Expected list_min(expr)"), + } + } +} + +pub struct SQLListSum; + +impl SQLFunction for SQLListSum { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(daft_functions::list::sum(input)) + } + _ => unsupported_sql_err!("invalid arguments for list_sum. 
Expected list_sum(expr)"), + } + } +} + +pub struct SQLListSlice; + +impl SQLFunction for SQLListSlice { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input, start, end] => { + let input = planner.plan_function_arg(input)?; + let start = planner.plan_function_arg(start)?; + let end = planner.plan_function_arg(end)?; + Ok(daft_functions::list::slice(input, start, end)) + } + _ => unsupported_sql_err!( + "invalid arguments for list_slice. Expected list_slice(expr, start, end)" + ), + } + } +} + +pub struct SQLListSort; + +impl SQLFunction for SQLListSort { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(daft_functions::list::sort(input, None)) + } + [input, order] => { + let input = planner.plan_function_arg(input)?; + use sqlparser::ast::{ + Expr::Identifier as SQLIdent, FunctionArg::Unnamed, + FunctionArgExpr::Expr as SQLExpr, + }; + + let order = match order { + Unnamed(SQLExpr(SQLIdent(ident))) => { + match ident.value.to_lowercase().as_str() { + "asc" => lit(false), + "desc" => lit(true), + _ => unsupported_sql_err!("invalid order for list_sort"), + } + } + _ => unsupported_sql_err!("invalid order for list_sort"), + }; + Ok(daft_functions::list::sort(input, Some(order))) + } + _ => unsupported_sql_err!( + "invalid arguments for list_sort. Expected list_sort(expr, ASC|DESC)" + ), + } + } +} diff --git a/tests/sql/test_list_exprs.py b/tests/sql/test_list_exprs.py new file mode 100644 index 0000000000..9b76735e44 --- /dev/null +++ b/tests/sql/test_list_exprs.py @@ -0,0 +1,120 @@ +import pyarrow as pa +import pytest + +import daft +from daft import col, context +from daft.daft import CountMode +from daft.sql.sql import SQLCatalog + + +def test_list_chunk(): + df = daft.from_pydict( + { + "col": pa.array([], type=pa.list_(pa.int64())), + "fixed_col": pa.array([], type=pa.list_(pa.int64(), 2)), + } + ) + catalog = SQLCatalog({"test": df}) + expected = df.select( + col("col").list.chunk(1).alias("col1"), + col("col").list.chunk(2).alias("col2"), + col("col").list.chunk(1000).alias("col3"), + col("fixed_col").list.chunk(1).alias("fixed_col1"), + col("fixed_col").list.chunk(2).alias("fixed_col2"), + col("fixed_col").list.chunk(1000).alias("fixed_col3"), + ) + + actual = daft.sql( + """ + SELECT + list_chunk(col, 1) as col1, + list_chunk(col, 2) as col2, + list_chunk(col, 1000) as col3, + list_chunk(fixed_col, 1) as fixed_col1, + list_chunk(fixed_col, 2) as fixed_col2, + list_chunk(fixed_col, 1000) as fixed_col3 + FROM test + """, + catalog=catalog, + ).collect() + assert actual.to_pydict() == expected.to_pydict() + + +def test_list_counts(): + df = daft.from_pydict({"col": [[1, 2, 3], [1, 2], [1, None, 4], []]}) + catalog = SQLCatalog({"test": df}) + expected = df.select( + col("col").list.count().alias("count_valid"), + col("col").list.count(CountMode.All).alias("count_all"), + col("col").list.count(CountMode.Null).alias("count_null"), + ).collect() + actual = daft.sql( + """ + SELECT + list_count(col) as count_valid, + list_count(col, 'all') as count_all, + list_count(col, 'null') as count_null + FROM test + """, + catalog=catalog, + ).collect() + assert actual.to_pydict() == expected.to_pydict() + + +def test_list_explode(): + if context.get_context().daft_execution_config.enable_native_executor 
is True: + pytest.skip("Native executor fails for these tests") + df = daft.from_pydict({"col": [[1, 2, 3], [1, 2], [1, None, 4], []]}) + catalog = SQLCatalog({"test": df}) + expected = df.explode(col("col")) + actual = daft.sql("SELECT unnest(col) as col FROM test", catalog=catalog).collect() + assert actual.to_pydict() == expected.to_pydict() + # test with alias + actual = daft.sql("SELECT explode(col) as col FROM test", catalog=catalog).collect() + assert actual.to_pydict() == expected.to_pydict() + + +def test_list_join(): + df = daft.from_pydict({"col": [None, [], ["a"], [None], ["a", "a"], ["a", None], ["a", None, "a"]]}) + catalog = SQLCatalog({"test": df}) + expected = df.select(col("col").list.join(",")) + actual = daft.sql("SELECT list_join(col, ',') FROM test", catalog=catalog).collect() + assert actual.to_pydict() == expected.to_pydict() + # make sure it works with the `array_to_string` function too + actual = daft.sql("SELECT array_to_string(col, ',') FROM test", catalog=catalog).collect() + assert actual.to_pydict() == expected.to_pydict() + + +def test_various_list_ops(): + df = daft.from_pydict({"col": [[1, 2, 3], [1, 2], [1, None, 4], []]}) + catalog = SQLCatalog({"test": df}) + expected = df.select( + col("col").list.min().alias("min"), + col("col").list.max().alias("max"), + col("col").list.mean().alias("mean"), + col("col").list.sum().alias("sum"), + col("col").list.sort().alias("sort"), + col("col").list.sort(True).alias("sort_desc"), + col("col").list.sort(False).alias("sort_asc"), + col("col").list.sort(True).alias("sort_desc_upper"), + col("col").list.sort(False).alias("sort_asc_upper"), + col("col").list.slice(1, 2).alias("slice"), + ).collect() + actual = daft.sql( + """ + SELECT + list_min(col) as min, + list_max(col) as max, + list_mean(col) as mean, + list_sum(col) as sum, + list_sort(col) as sort, + list_sort(col, desc) as sort_desc, + list_sort(col, asc) as sort_asc, + list_sort(col, DESC) as sort_desc_upper, + list_sort(col, ASC) as sort_asc_upper, + list_slice(col, 1, 2) as slice + FROM test + """, + catalog=catalog, + ).collect() + assert actual.to_pydict() == expected.to_pydict() From c5b70621d7410ca7a2c4f3f840174764bae4b27c Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Fri, 20 Sep 2024 15:26:06 -0500 Subject: [PATCH 07/35] [FEAT]: SQL temporal functions (#2858) --- Cargo.lock | 1 + daft/daft/__init__.pyi | 24 ++- daft/expressions/expressions.py | 20 +- src/daft-dsl/src/functions/mod.rs | 5 +- src/daft-dsl/src/functions/temporal/date.rs | 42 ---- src/daft-dsl/src/functions/temporal/day.rs | 42 ---- .../src/functions/temporal/day_of_week.rs | 42 ---- src/daft-dsl/src/functions/temporal/hour.rs | 42 ---- src/daft-dsl/src/functions/temporal/minute.rs | 42 ---- src/daft-dsl/src/functions/temporal/mod.rs | 135 ------------ src/daft-dsl/src/functions/temporal/month.rs | 42 ---- src/daft-dsl/src/functions/temporal/second.rs | 42 ---- src/daft-dsl/src/functions/temporal/time.rs | 49 ----- .../src/functions/temporal/truncate.rs | 56 ----- src/daft-dsl/src/functions/temporal/year.rs | 42 ---- src/daft-dsl/src/python.rs | 50 ----- src/daft-functions/Cargo.toml | 1 + src/daft-functions/src/lib.rs | 2 + src/daft-functions/src/temporal/mod.rs | 195 ++++++++++++++++++ src/daft-functions/src/temporal/truncate.rs | 78 +++++++ src/daft-sql/src/modules/temporal.rs | 62 +++++- tests/sql/test_temporal_exprs.py | 50 +++++ 22 files changed, 410 insertions(+), 654 deletions(-) delete mode 100644 src/daft-dsl/src/functions/temporal/date.rs delete mode 100644 
src/daft-dsl/src/functions/temporal/day.rs delete mode 100644 src/daft-dsl/src/functions/temporal/day_of_week.rs delete mode 100644 src/daft-dsl/src/functions/temporal/hour.rs delete mode 100644 src/daft-dsl/src/functions/temporal/minute.rs delete mode 100644 src/daft-dsl/src/functions/temporal/mod.rs delete mode 100644 src/daft-dsl/src/functions/temporal/month.rs delete mode 100644 src/daft-dsl/src/functions/temporal/second.rs delete mode 100644 src/daft-dsl/src/functions/temporal/time.rs delete mode 100644 src/daft-dsl/src/functions/temporal/truncate.rs delete mode 100644 src/daft-dsl/src/functions/temporal/year.rs create mode 100644 src/daft-functions/src/temporal/mod.rs create mode 100644 src/daft-functions/src/temporal/truncate.rs create mode 100644 tests/sql/test_temporal_exprs.py diff --git a/Cargo.lock b/Cargo.lock index 90e4745ba8..4815488c2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1825,6 +1825,7 @@ dependencies = [ "daft-image", "daft-io", "futures", + "paste", "pyo3", "serde", "snafu", diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 83f1eff059..da4d3e6abc 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1103,16 +1103,6 @@ class PyExpr: def __repr__(self) -> str: ... def __hash__(self) -> int: ... def __reduce__(self) -> tuple: ... - def dt_date(self) -> PyExpr: ... - def dt_day(self) -> PyExpr: ... - def dt_hour(self) -> PyExpr: ... - def dt_minute(self) -> PyExpr: ... - def dt_second(self) -> PyExpr: ... - def dt_time(self) -> PyExpr: ... - def dt_month(self) -> PyExpr: ... - def dt_year(self) -> PyExpr: ... - def dt_day_of_week(self) -> PyExpr: ... - def dt_truncate(self, interval: str, relative_to: PyExpr) -> PyExpr: ... def utf8_endswith(self, pattern: PyExpr) -> PyExpr: ... def utf8_startswith(self, pattern: PyExpr) -> PyExpr: ... def utf8_contains(self, pattern: PyExpr) -> PyExpr: ... @@ -1251,6 +1241,20 @@ def fill_nan(expr: PyExpr, fill_value: PyExpr) -> PyExpr: ... # --- def json_query(expr: PyExpr, query: str) -> PyExpr: ... +# --- +# expr.dt namespace +# --- +def dt_date(expr: PyExpr) -> PyExpr: ... +def dt_day(expr: PyExpr) -> PyExpr: ... +def dt_hour(expr: PyExpr) -> PyExpr: ... +def dt_minute(expr: PyExpr) -> PyExpr: ... +def dt_second(expr: PyExpr) -> PyExpr: ... +def dt_time(expr: PyExpr) -> PyExpr: ... +def dt_month(expr: PyExpr) -> PyExpr: ... +def dt_year(expr: PyExpr) -> PyExpr: ... +def dt_day_of_week(expr: PyExpr) -> PyExpr: ... +def dt_truncate(expr: PyExpr, interval: str, relative_to: PyExpr) -> PyExpr: ... 
+ # --- # expr.list namespace # --- diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 986bad70fd..d5482220df 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1491,7 +1491,7 @@ def date(self) -> Expression: Returns: Expression: a Date expression """ - return Expression._from_pyexpr(self._expr.dt_date()) + return Expression._from_pyexpr(native.dt_date(self._expr)) def day(self) -> Expression: """Retrieves the day for a datetime column @@ -1526,7 +1526,7 @@ def day(self) -> Expression: Returns: Expression: a UInt32 expression with just the day extracted from a datetime column """ - return Expression._from_pyexpr(self._expr.dt_day()) + return Expression._from_pyexpr(native.dt_day(self._expr)) def hour(self) -> Expression: """Retrieves the day for a datetime column @@ -1561,7 +1561,7 @@ def hour(self) -> Expression: Returns: Expression: a UInt32 expression with just the day extracted from a datetime column """ - return Expression._from_pyexpr(self._expr.dt_hour()) + return Expression._from_pyexpr(native.dt_hour(self._expr)) def minute(self) -> Expression: """Retrieves the minute for a datetime column @@ -1596,7 +1596,7 @@ def minute(self) -> Expression: Returns: Expression: a UInt32 expression with just the minute extracted from a datetime column """ - return Expression._from_pyexpr(self._expr.dt_minute()) + return Expression._from_pyexpr(native.dt_minute(self._expr)) def second(self) -> Expression: """Retrieves the second for a datetime column @@ -1631,7 +1631,7 @@ def second(self) -> Expression: Returns: Expression: a UInt32 expression with just the second extracted from a datetime column """ - return Expression._from_pyexpr(self._expr.dt_second()) + return Expression._from_pyexpr(native.dt_second(self._expr)) def time(self) -> Expression: """Retrieves the time for a datetime column @@ -1666,7 +1666,7 @@ def time(self) -> Expression: Returns: Expression: a Time expression """ - return Expression._from_pyexpr(self._expr.dt_time()) + return Expression._from_pyexpr(native.dt_time(self._expr)) def month(self) -> Expression: """Retrieves the month for a datetime column @@ -1699,7 +1699,7 @@ def month(self) -> Expression: Returns: Expression: a UInt32 expression with just the month extracted from a datetime column """ - return Expression._from_pyexpr(self._expr.dt_month()) + return Expression._from_pyexpr(native.dt_month(self._expr)) def year(self) -> Expression: """Retrieves the year for a datetime column @@ -1733,7 +1733,7 @@ def year(self) -> Expression: Returns: Expression: a UInt32 expression with just the year extracted from a datetime column """ - return Expression._from_pyexpr(self._expr.dt_year()) + return Expression._from_pyexpr(native.dt_year(self._expr)) def day_of_week(self) -> Expression: """Retrieves the day of the week for a datetime column, starting at 0 for Monday and ending at 6 for Sunday @@ -1766,7 +1766,7 @@ def day_of_week(self) -> Expression: Returns: Expression: a UInt32 expression with just the day_of_week extracted from a datetime column """ - return Expression._from_pyexpr(self._expr.dt_day_of_week()) + return Expression._from_pyexpr(native.dt_day_of_week(self._expr)) def truncate(self, interval: str, relative_to: Expression | None = None) -> Expression: """Truncates the datetime column to the specified interval @@ -1804,7 +1804,7 @@ def truncate(self, interval: str, relative_to: Expression | None = None) -> Expr Expression: a DateTime expression truncated to the specified interval """ relative_to = 
Expression._to_expression(relative_to) - return Expression._from_pyexpr(self._expr.dt_truncate(interval, relative_to._expr)) + return Expression._from_pyexpr(native.dt_truncate(self._expr, interval, relative_to._expr)) class ExpressionStringNamespace(ExpressionNamespace): diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 216c79a2f9..4ff2375b68 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -4,7 +4,6 @@ pub mod partitioning; pub mod scalar; pub mod sketch; pub mod struct_; -pub mod temporal; pub mod utf8; use std::{ @@ -19,7 +18,7 @@ use serde::{Deserialize, Serialize}; use self::{ map::MapExpr, numeric::NumericExpr, partitioning::PartitioningExpr, sketch::SketchExpr, - struct_::StructExpr, temporal::TemporalExpr, utf8::Utf8Expr, + struct_::StructExpr, utf8::Utf8Expr, }; use crate::{Expr, ExprRef, Operator}; @@ -30,7 +29,6 @@ use python::PythonUDF; pub enum FunctionExpr { Numeric(NumericExpr), Utf8(Utf8Expr), - Temporal(TemporalExpr), Map(MapExpr), Sketch(SketchExpr), Struct(StructExpr), @@ -56,7 +54,6 @@ impl FunctionExpr { match self { Numeric(expr) => expr.get_evaluator(), Utf8(expr) => expr.get_evaluator(), - Temporal(expr) => expr.get_evaluator(), Map(expr) => expr.get_evaluator(), Sketch(expr) => expr.get_evaluator(), Struct(expr) => expr.get_evaluator(), diff --git a/src/daft-dsl/src/functions/temporal/date.rs b/src/daft-dsl/src/functions/temporal/date.rs deleted file mode 100644 index 0d7b70ab83..0000000000 --- a/src/daft-dsl/src/functions/temporal/date.rs +++ /dev/null @@ -1,42 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct DateEvaluator {} - -impl FunctionEvaluator for DateEvaluator { - fn fn_name(&self) -> &'static str { - "date" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => { - Ok(Field::new(field.name, DataType::Date)) - } - Ok(field) => Err(DaftError::TypeError(format!( - "Expected input to date to be temporal, got {}", - field.dtype - ))), - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => input.dt_date(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/temporal/day.rs b/src/daft-dsl/src/functions/temporal/day.rs deleted file mode 100644 index bb06d6f5fa..0000000000 --- a/src/daft-dsl/src/functions/temporal/day.rs +++ /dev/null @@ -1,42 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct DayEvaluator {} - -impl FunctionEvaluator for DayEvaluator { - fn fn_name(&self) -> &'static str { - "day" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => { - Ok(Field::new(field.name, DataType::UInt32)) - } - Ok(field) => Err(DaftError::TypeError(format!( - "Expected input to day to be temporal, got {}", - field.dtype - ))), - Err(e) => Err(e), - 
}, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => input.dt_day(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/temporal/day_of_week.rs b/src/daft-dsl/src/functions/temporal/day_of_week.rs deleted file mode 100644 index 23fe76947b..0000000000 --- a/src/daft-dsl/src/functions/temporal/day_of_week.rs +++ /dev/null @@ -1,42 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct DayOfWeekEvaluator {} - -impl FunctionEvaluator for DayOfWeekEvaluator { - fn fn_name(&self) -> &'static str { - "day_of_week" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => { - Ok(Field::new(field.name, DataType::UInt32)) - } - Ok(field) => Err(DaftError::TypeError(format!( - "Expected input to day to be temporal, got {}", - field.dtype - ))), - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => input.dt_day_of_week(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/temporal/hour.rs b/src/daft-dsl/src/functions/temporal/hour.rs deleted file mode 100644 index e3f775577a..0000000000 --- a/src/daft-dsl/src/functions/temporal/hour.rs +++ /dev/null @@ -1,42 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct HourEvaluator {} - -impl FunctionEvaluator for HourEvaluator { - fn fn_name(&self) -> &'static str { - "hour" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => { - Ok(Field::new(field.name, DataType::UInt32)) - } - Ok(field) => Err(DaftError::TypeError(format!( - "Expected input to hour to be temporal, got {}", - field.dtype - ))), - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => input.dt_hour(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/temporal/minute.rs b/src/daft-dsl/src/functions/temporal/minute.rs deleted file mode 100644 index ffdf5f29ea..0000000000 --- a/src/daft-dsl/src/functions/temporal/minute.rs +++ /dev/null @@ -1,42 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct MinuteEvaluator {} - -impl FunctionEvaluator for MinuteEvaluator { - fn fn_name(&self) -> &'static str { - "minute" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: 
&FunctionExpr) -> DaftResult { - match inputs { - [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => { - Ok(Field::new(field.name, DataType::UInt32)) - } - Ok(field) => Err(DaftError::TypeError(format!( - "Expected input to minute to be temporal, got {}", - field.dtype - ))), - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => input.dt_minute(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/temporal/mod.rs b/src/daft-dsl/src/functions/temporal/mod.rs deleted file mode 100644 index 668d72f164..0000000000 --- a/src/daft-dsl/src/functions/temporal/mod.rs +++ /dev/null @@ -1,135 +0,0 @@ -mod date; -mod day; -mod day_of_week; -mod hour; -mod minute; -mod month; -mod second; -mod time; -mod truncate; -mod year; - -use serde::{Deserialize, Serialize}; - -use super::FunctionEvaluator; -use crate::{ - functions::temporal::{ - date::DateEvaluator, day::DayEvaluator, day_of_week::DayOfWeekEvaluator, - hour::HourEvaluator, minute::MinuteEvaluator, month::MonthEvaluator, - second::SecondEvaluator, time::TimeEvaluator, truncate::TruncateEvaluator, - year::YearEvaluator, - }, - Expr, ExprRef, -}; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub enum TemporalExpr { - Day, - Hour, - Minute, - Second, - Month, - Year, - DayOfWeek, - Date, - Time, - Truncate(String), -} - -impl TemporalExpr { - #[inline] - pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use TemporalExpr::*; - match self { - Day => &DayEvaluator {}, - Hour => &HourEvaluator {}, - Month => &MonthEvaluator {}, - Year => &YearEvaluator {}, - DayOfWeek => &DayOfWeekEvaluator {}, - Date => &DateEvaluator {}, - Minute => &MinuteEvaluator {}, - Second => &SecondEvaluator {}, - Time => &TimeEvaluator {}, - Truncate(..) 
=> &TruncateEvaluator {}, - } - } -} - -pub fn date(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Temporal(TemporalExpr::Date), - inputs: vec![input], - } - .into() -} - -pub fn day(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Temporal(TemporalExpr::Day), - inputs: vec![input], - } - .into() -} - -pub fn hour(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Temporal(TemporalExpr::Hour), - inputs: vec![input], - } - .into() -} - -pub fn minute(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Temporal(TemporalExpr::Minute), - inputs: vec![input], - } - .into() -} - -pub fn second(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Temporal(TemporalExpr::Second), - inputs: vec![input], - } - .into() -} - -pub fn time(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Temporal(TemporalExpr::Time), - inputs: vec![input], - } - .into() -} - -pub fn month(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Temporal(TemporalExpr::Month), - inputs: vec![input], - } - .into() -} - -pub fn year(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Temporal(TemporalExpr::Year), - inputs: vec![input], - } - .into() -} - -pub fn day_of_week(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Temporal(TemporalExpr::DayOfWeek), - inputs: vec![input], - } - .into() -} - -pub fn truncate(input: ExprRef, freq: &str, relative_to: ExprRef) -> Expr { - Expr::Function { - func: super::FunctionExpr::Temporal(TemporalExpr::Truncate(freq.to_string())), - inputs: vec![input, relative_to], - } -} diff --git a/src/daft-dsl/src/functions/temporal/month.rs b/src/daft-dsl/src/functions/temporal/month.rs deleted file mode 100644 index b93af55090..0000000000 --- a/src/daft-dsl/src/functions/temporal/month.rs +++ /dev/null @@ -1,42 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct MonthEvaluator {} - -impl FunctionEvaluator for MonthEvaluator { - fn fn_name(&self) -> &'static str { - "month" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => { - Ok(Field::new(field.name, DataType::UInt32)) - } - Ok(field) => Err(DaftError::TypeError(format!( - "Expected input to month to be temporal, got {}", - field.dtype - ))), - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => input.dt_month(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/temporal/second.rs b/src/daft-dsl/src/functions/temporal/second.rs deleted file mode 100644 index 490803f5a2..0000000000 --- a/src/daft-dsl/src/functions/temporal/second.rs +++ /dev/null @@ -1,42 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct SecondEvaluator {} - -impl FunctionEvaluator for SecondEvaluator { - fn fn_name(&self) -> &'static str { - 
"second" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => { - Ok(Field::new(field.name, DataType::UInt32)) - } - Ok(field) => Err(DaftError::TypeError(format!( - "Expected input to second to be temporal, got {}", - field.dtype - ))), - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => input.dt_second(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/temporal/time.rs b/src/daft-dsl/src/functions/temporal/time.rs deleted file mode 100644 index 32ff78e459..0000000000 --- a/src/daft-dsl/src/functions/temporal/time.rs +++ /dev/null @@ -1,49 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct TimeEvaluator {} - -impl FunctionEvaluator for TimeEvaluator { - fn fn_name(&self) -> &'static str { - "time" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => match input.to_field(schema) { - Ok(field) => match field.dtype { - DataType::Time(_) => Ok(field), - DataType::Timestamp(tu, _) => { - let tu = match tu { - TimeUnit::Nanoseconds => TimeUnit::Nanoseconds, - _ => TimeUnit::Microseconds, - }; - Ok(Field::new(field.name, DataType::Time(tu))) - } - _ => Err(DaftError::TypeError(format!( - "Expected input to time to be time or timestamp, got {}", - field.dtype - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => input.dt_time(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/temporal/truncate.rs b/src/daft-dsl/src/functions/temporal/truncate.rs deleted file mode 100644 index 785486dc7f..0000000000 --- a/src/daft-dsl/src/functions/temporal/truncate.rs +++ /dev/null @@ -1,56 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, TemporalExpr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct TruncateEvaluator {} - -impl FunctionEvaluator for TruncateEvaluator { - fn fn_name(&self) -> &'static str { - "truncate" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input, relative_to] => match (input.to_field(schema), relative_to.to_field(schema)) { - (Ok(input_field), Ok(relative_to_field)) - if input_field.dtype.is_temporal() - && (relative_to_field.dtype.is_temporal() - || relative_to_field.dtype.is_null()) => - { - Ok(Field::new(input_field.name, input_field.dtype)) - } - (Ok(input_field), Ok(relative_to_field)) => Err(DaftError::TypeError(format!( - "Expected temporal input args, got {} and {}", - input_field.dtype, relative_to_field.dtype - ))), - (Err(e), _) | (_, Err(e)) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn 
evaluate(&self, inputs: &[Series], func: &FunctionExpr) -> DaftResult { - match inputs { - [input, relative_to] => { - let freq = match func { - FunctionExpr::Temporal(TemporalExpr::Truncate(freq)) => freq, - _ => { - return Err(DaftError::ValueError( - "Expected Temporal function".to_string(), - )) - } - }; - input.dt_truncate(freq, relative_to) - } - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/temporal/year.rs b/src/daft-dsl/src/functions/temporal/year.rs deleted file mode 100644 index 5557926a04..0000000000 --- a/src/daft-dsl/src/functions/temporal/year.rs +++ /dev/null @@ -1,42 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct YearEvaluator {} - -impl FunctionEvaluator for YearEvaluator { - fn fn_name(&self) -> &'static str { - "year" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => match input.to_field(schema) { - Ok(field) if field.dtype.is_temporal() => { - Ok(Field::new(field.name, DataType::Int32)) - } - Ok(field) => Err(DaftError::TypeError(format!( - "Expected input to year to be temporal, got {}", - field.dtype - ))), - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [input] => input.dt_year(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index f9ebbce6c3..7c5a1d7930 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -572,56 +572,6 @@ impl PyExpr { hasher.finish() } - pub fn dt_date(&self) -> PyResult { - use functions::temporal::date; - Ok(date(self.into()).into()) - } - - pub fn dt_day(&self) -> PyResult { - use functions::temporal::day; - Ok(day(self.into()).into()) - } - - pub fn dt_hour(&self) -> PyResult { - use functions::temporal::hour; - Ok(hour(self.into()).into()) - } - - pub fn dt_minute(&self) -> PyResult { - use functions::temporal::minute; - Ok(minute(self.into()).into()) - } - - pub fn dt_second(&self) -> PyResult { - use functions::temporal::second; - Ok(second(self.into()).into()) - } - - pub fn dt_time(&self) -> PyResult { - use functions::temporal::time; - Ok(time(self.into()).into()) - } - - pub fn dt_month(&self) -> PyResult { - use functions::temporal::month; - Ok(month(self.into()).into()) - } - - pub fn dt_year(&self) -> PyResult { - use functions::temporal::year; - Ok(year(self.into()).into()) - } - - pub fn dt_day_of_week(&self) -> PyResult { - use functions::temporal::day_of_week; - Ok(day_of_week(self.into()).into()) - } - - pub fn dt_truncate(&self, interval: &str, relative_to: &Self) -> PyResult { - use functions::temporal::truncate; - Ok(truncate(self.into(), interval, relative_to.expr.clone()).into()) - } - pub fn utf8_endswith(&self, pattern: &Self) -> PyResult { use crate::functions::utf8::endswith; Ok(endswith(self.into(), pattern.expr.clone()).into()) diff --git a/src/daft-functions/Cargo.toml b/src/daft-functions/Cargo.toml index e411d32614..b965e9f417 100644 --- a/src/daft-functions/Cargo.toml +++ b/src/daft-functions/Cargo.toml @@ -8,6 +8,7 @@ daft-dsl = {path = "../daft-dsl", 
default-features = false} daft-image = {path = "../daft-image", default-features = false} daft-io = {path = "../daft-io", default-features = false} futures = {workspace = true} +paste = "1.0.15" pyo3 = {workspace = true, optional = true} tiktoken-rs = {workspace = true} tokio = {workspace = true} diff --git a/src/daft-functions/src/lib.rs b/src/daft-functions/src/lib.rs index 452c0e1cc5..0976f17c21 100644 --- a/src/daft-functions/src/lib.rs +++ b/src/daft-functions/src/lib.rs @@ -7,6 +7,7 @@ pub mod image; pub mod list; pub mod minhash; pub mod numeric; +pub mod temporal; pub mod to_struct; pub mod tokenize; pub mod uri; @@ -47,6 +48,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction_bound!(uri::python::url_upload, parent)?)?; image::register_modules(parent)?; float::register_modules(parent)?; + temporal::register_modules(parent)?; list::register_modules(parent)?; Ok(()) } diff --git a/src/daft-functions/src/temporal/mod.rs b/src/daft-functions/src/temporal/mod.rs new file mode 100644 index 0000000000..314546fe77 --- /dev/null +++ b/src/daft-functions/src/temporal/mod.rs @@ -0,0 +1,195 @@ +pub mod truncate; +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema, TimeUnit}, + series::Series, +}; +#[cfg(feature = "python")] +use daft_dsl::python::PyExpr; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +#[cfg(feature = "python")] +use pyo3::{prelude::*, pyfunction, PyResult}; +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "python")] +pub fn register_modules(parent: &Bound) -> PyResult<()> { + parent.add_function(wrap_pyfunction_bound!(py_dt_date, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(py_dt_day, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(py_dt_day_of_week, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(py_dt_hour, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(py_dt_minute, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(py_dt_month, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(py_dt_second, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(py_dt_time, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(py_dt_year, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(truncate::py_dt_truncate, parent)?)?; + Ok(()) +} + +macro_rules! impl_temporal { + // pyo3 macro can't handle any expressions other than a 'literal', so we have to redundantly pass it in via $py_name + ($name:ident, $dt:ident, $py_name:literal, $dtype:ident) => { + paste::paste! 
{ + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] + pub struct $name; + + #[typetag::serde] + impl ScalarUDF for $name { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + stringify!([ < $name:snake:lower > ]) + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [input] => match input.to_field(schema) { + Ok(field) if field.dtype.is_temporal() => { + Ok(Field::new(field.name, DataType::$dtype)) + } + Ok(field) => Err(DaftError::TypeError(format!( + "Expected input to {} to be temporal, got {}", + self.name(), + field.dtype + ))), + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [input] => input.$dt(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + } + + pub fn $dt(input: ExprRef) -> ExprRef { + ScalarFunction::new($name {}, vec![input]).into() + } + + #[pyfunction] + #[pyo3(name = $py_name)] + #[cfg(feature = "python")] + pub fn [](expr: PyExpr) -> PyResult { + Ok($dt(expr.into()).into()) + } + } + }; +} + +impl_temporal!(Date, dt_date, "dt_date", Date); +impl_temporal!(Day, dt_day, "dt_day", UInt32); +impl_temporal!(Hour, dt_hour, "dt_hour", UInt32); +impl_temporal!(DayOfWeek, dt_day_of_week, "dt_day_of_week", UInt32); +impl_temporal!(Minute, dt_minute, "dt_minute", UInt32); +impl_temporal!(Month, dt_month, "dt_month", UInt32); +impl_temporal!(Second, dt_second, "dt_second", UInt32); +impl_temporal!(Year, dt_year, "dt_year", Int32); + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Time; + +#[typetag::serde] +impl ScalarUDF for Time { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "time" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [input] => match input.to_field(schema) { + Ok(field) => match field.dtype { + DataType::Time(_) => Ok(field), + DataType::Timestamp(tu, _) => { + let tu = match tu { + TimeUnit::Nanoseconds => TimeUnit::Nanoseconds, + _ => TimeUnit::Microseconds, + }; + Ok(Field::new(field.name, DataType::Time(tu))) + } + _ => Err(DaftError::TypeError(format!( + "Expected input to time to be time or timestamp, got {}", + field.dtype + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [input] => input.dt_time(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} + +pub fn dt_time(input: ExprRef) -> ExprRef { + ScalarFunction::new(Time {}, vec![input]).into() +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "dt_time")] +pub fn py_dt_time(expr: PyExpr) -> PyResult { + Ok(dt_time(expr.into()).into()) +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use super::truncate::Truncate; + + #[test] + fn test_fn_name() { + use super::*; + let cases: Vec<(Arc, &str)> = vec![ + (Arc::new(Date), "date"), + (Arc::new(Day), "day"), + (Arc::new(Hour), "hour"), + (Arc::new(DayOfWeek), "day_of_week"), + (Arc::new(Minute), "minute"), + (Arc::new(Month), "month"), + (Arc::new(Second), "second"), + (Arc::new(Time), "time"), + (Arc::new(Year), "year"), + ( + Arc::new(Truncate { + interval: 
"".into(), + }), + "truncate", + ), + ]; + + for (f, name) in cases { + assert_eq!(f.name(), name); + } + } +} diff --git a/src/daft-functions/src/temporal/truncate.rs b/src/daft-functions/src/temporal/truncate.rs new file mode 100644 index 0000000000..2453a966ab --- /dev/null +++ b/src/daft-functions/src/temporal/truncate.rs @@ -0,0 +1,78 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::prelude::*; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Truncate { + pub(super) interval: String, +} + +#[typetag::serde] +impl ScalarUDF for Truncate { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "truncate" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [input, relative_to] => { + let input_field = input.to_field(schema)?; + let relative_to_field = relative_to.to_field(schema)?; + if input_field.dtype.is_temporal() + && (relative_to_field.dtype.is_temporal() || relative_to_field.dtype.is_null()) + { + Ok(Field::new(input_field.name, input_field.dtype)) + } else { + Err(DaftError::TypeError(format!( + "Expected temporal input args, got {} and {}", + input_field.dtype, relative_to_field.dtype + ))) + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [input, relative_to] => input.dt_truncate(&self.interval, relative_to), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} + +pub fn dt_truncate>(input: ExprRef, interval: S, relative_to: ExprRef) -> ExprRef { + ScalarFunction::new( + Truncate { + interval: interval.into(), + }, + vec![input, relative_to], + ) + .into() +} + +#[cfg(feature = "python")] +use daft_dsl::python::PyExpr; +#[cfg(feature = "python")] +use pyo3::{pyfunction, PyResult}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "dt_truncate")] +pub fn py_dt_truncate(expr: PyExpr, interval: &str, relative_to: PyExpr) -> PyResult { + Ok(dt_truncate(expr.into(), interval, relative_to.into()).into()) +} diff --git a/src/daft-sql/src/modules/temporal.rs b/src/daft-sql/src/modules/temporal.rs index 7bced4f980..58687724fa 100644 --- a/src/daft-sql/src/modules/temporal.rs +++ b/src/daft-sql/src/modules/temporal.rs @@ -1,11 +1,65 @@ +use daft_dsl::ExprRef; +use daft_functions::temporal::*; +use sqlparser::ast::FunctionArg; + use super::SQLModule; -use crate::functions::SQLFunctions; +use crate::{ + error::SQLPlannerResult, + functions::{SQLFunction, SQLFunctions}, + unsupported_sql_err, +}; pub struct SQLModuleTemporal; impl SQLModule for SQLModuleTemporal { - fn register(_parent: &mut SQLFunctions) { - // use FunctionExpr::Temporal as f; - // TODO + fn register(parent: &mut SQLFunctions) { + parent.add_fn("date", SQLDate); + parent.add_fn("day", SQLDay); + parent.add_fn("dayofweek", SQLDayOfWeek); + parent.add_fn("hour", SQLHour); + parent.add_fn("minute", SQLMinute); + parent.add_fn("month", SQLMonth); + parent.add_fn("second", SQLSecond); + parent.add_fn("year", SQLYear); + parent.add_fn("time", SQLTime); + + // TODO: Add truncate + // Our `dt_truncate` function has vastly different semantics than SQL `DATE_TRUNCATE` function. } } + +macro_rules! 
temporal { + ($name:ident, $fn_name:ident) => { + pub struct $name; + + impl SQLFunction for $name { + fn to_expr( + &self, + inputs: &[FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + + Ok($fn_name(input)) + } + _ => unsupported_sql_err!( + "Invalid arguments for {}: '{inputs:?}'", + stringify!($fn_name) + ), + } + } + } + }; +} + +temporal!(SQLDate, dt_date); +temporal!(SQLDay, dt_day); +temporal!(SQLDayOfWeek, dt_day_of_week); +temporal!(SQLHour, dt_hour); +temporal!(SQLMinute, dt_minute); +temporal!(SQLMonth, dt_month); +temporal!(SQLSecond, dt_second); +temporal!(SQLYear, dt_year); +temporal!(SQLTime, dt_time); diff --git a/tests/sql/test_temporal_exprs.py b/tests/sql/test_temporal_exprs.py new file mode 100644 index 0000000000..5383fd20aa --- /dev/null +++ b/tests/sql/test_temporal_exprs.py @@ -0,0 +1,50 @@ +import datetime + +import daft +from daft.sql.sql import SQLCatalog + + +def test_temporals(): + df = daft.from_pydict( + { + "datetimes": [ + datetime.datetime(2021, 1, 1, 23, 59, 58), + datetime.datetime(2021, 1, 2, 0, 0, 0), + datetime.datetime(2021, 1, 2, 1, 2, 3), + datetime.datetime(2021, 1, 2, 1, 2, 3), + datetime.datetime(1999, 1, 1, 1, 1, 1), + None, + ] + } + ) + catalog = SQLCatalog({"test": df}) + print(df) + + expected = df.select( + daft.col("datetimes").dt.date().alias("date"), + daft.col("datetimes").dt.day().alias("day"), + daft.col("datetimes").dt.day_of_week().alias("day_of_week"), + daft.col("datetimes").dt.hour().alias("hour"), + daft.col("datetimes").dt.minute().alias("minute"), + daft.col("datetimes").dt.month().alias("month"), + daft.col("datetimes").dt.second().alias("second"), + daft.col("datetimes").dt.year().alias("year"), + ).collect() + + actual = daft.sql( + """ + SELECT + date(datetimes) as date, + day(datetimes) as day, + dayofweek(datetimes) as day_of_week, + hour(datetimes) as hour, + minute(datetimes) as minute, + month(datetimes) as month, + second(datetimes) as second, + year(datetimes) as year, + FROM test + """, + catalog=catalog, + ).collect() + + assert actual.to_pydict() == expected.to_pydict() From 2c13f17d71757ebe4947efcb6f5a6862fc926d41 Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Fri, 20 Sep 2024 14:00:45 -0700 Subject: [PATCH 08/35] [FEAT] Iceberg partitioned writes (#2842) This also changes the behavior of some of the partitioning functions in ways that I would consider as bug fixes. However @samster25 maybe you should take a look at them to make sure their new behavior is correct. The changes: - truncate function now renames columns to `{}_trunc` instead of `{}_truncate` to match Spark behavior - day partitioning now returns an Int32Array instead of a DateArray. I believe the past behavior was there to match pyiceberg, but talking to the pyiceberg team, this actually seems like a bug. I plan on making a PR to pyiceberg to fix this, but I have also moved and fixed the buggy logic to our codebase so that it works with past versions of pyiceberg as well. 
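For illustration, a minimal sketch of the new partitioning behavior (the example DataFrame and literal values below are mine, not part of this patch; the output column names, dtypes, and day values follow the updated tests in tests/cookbook/test_write.py, and the truncate results assume Iceberg's `v - (v % W)` semantics):

```python
import datetime

import daft
from daft import col

df = daft.from_pydict(
    {
        "date": [datetime.date(2024, 1, 1), datetime.date(2024, 2, 1)],
        "id": [7, 23],
    }
)

# days() now yields an Int32 column of days since epoch named "date_days"
# (previously a Date column), and iceberg_truncate() now names its output
# "id_trunc" (previously "id_truncate").
result = df.select(
    col("date").partitioning.days(),
    col("id").partitioning.iceberg_truncate(10),
).to_pydict()

# Expected: {"date_days": [19723, 19754], "id_trunc": [0, 20]}
```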
--- daft/daft/__init__.pyi | 3 +- daft/dataframe/dataframe.py | 37 +- daft/execution/execution_step.py | 5 +- daft/execution/physical_plan.py | 5 +- daft/execution/rust_physical_plan_shim.py | 5 +- daft/expressions/expressions.py | 2 +- daft/iceberg/iceberg_scan.py | 17 +- daft/iceberg/iceberg_write.py | 239 ++++++++++++ daft/logical/builder.py | 4 +- daft/table/partitioning.py | 65 ++++ daft/table/table_io.py | 289 ++++++--------- src/daft-core/src/array/ops/truncate.rs | 9 + src/daft-core/src/series/ops/partitioning.rs | 31 +- .../src/functions/partitioning/evaluators.rs | 11 +- src/daft-plan/src/builder.rs | 8 +- src/daft-plan/src/logical_ops/sink.rs | 27 +- src/daft-plan/src/sink_info.rs | 9 +- src/daft-scheduler/src/scheduler.rs | 14 +- tests/cookbook/test_write.py | 6 +- tests/io/iceberg/test_iceberg_writes.py | 348 ++++++++++++++++-- tests/series/test_partitioning.py | 37 +- 21 files changed, 873 insertions(+), 298 deletions(-) create mode 100644 daft/iceberg/iceberg_write.py create mode 100644 daft/table/partitioning.py diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index da4d3e6abc..d5ef5873c3 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -13,6 +13,7 @@ from daft.udf import PartialStatefulUDF, PartialStatelessUDF if TYPE_CHECKING: import pyarrow as pa + from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties @@ -1699,7 +1700,7 @@ class LogicalPlanBuilder: self, table_name: str, table_location: str, - spec_id: int, + partition_spec: IcebergPartitionSpec, iceberg_schema: IcebergSchema, iceberg_properties: IcebergTableProperties, catalog_columns: list[str], diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 9ea9fff5f9..a4e48caba2 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -647,13 +647,13 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> DataFrame: The operations that occurred with this write. 
""" - if len(table.spec().fields) > 0: - raise ValueError("Cannot write to partitioned Iceberg tables") - import pyarrow as pa import pyiceberg from packaging.version import parse + if len(table.spec().fields) > 0 and parse(pyiceberg.__version__) < parse("0.7.0"): + raise ValueError("pyiceberg>=0.7.0 is required to write to a partitioned table") + if parse(pyiceberg.__version__) < parse("0.6.0"): raise ValueError(f"Write Iceberg is only supported on pyiceberg>=0.6.0, found {pyiceberg.__version__}") @@ -683,12 +683,18 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> else: deleted_files = [] + schema = table.schema() + partitioning: Dict[str, list] = {schema.find_field(field.source_id).name: [] for field in table.spec().fields} + for data_file in data_files: operations.append("ADD") path.append(data_file.file_path) rows.append(data_file.record_count) size.append(data_file.file_size_in_bytes) + for field in partitioning.keys(): + partitioning[field].append(getattr(data_file.partition, field, None)) + for pf in deleted_files: data_file = pf.file operations.append("DELETE") @@ -696,6 +702,9 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> rows.append(data_file.record_count) size.append(data_file.file_size_in_bytes) + for field in partitioning.keys(): + partitioning[field].append(getattr(data_file.partition, field, None)) + if parse(pyiceberg.__version__) >= parse("0.7.0"): from pyiceberg.table import ALWAYS_TRUE, PropertyUtil, TableProperties @@ -735,19 +744,23 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> merge.commit() + with_operations = { + "operation": pa.array(operations, type=pa.string()), + "rows": pa.array(rows, type=pa.int64()), + "file_size": pa.array(size, type=pa.int64()), + "file_name": pa.array([fp for fp in path], type=pa.string()), + } + + if partitioning: + with_operations["partitioning"] = pa.StructArray.from_arrays( + partitioning.values(), names=partitioning.keys() + ) + from daft import from_pydict - with_operations = from_pydict( - { - "operation": pa.array(operations, type=pa.string()), - "rows": pa.array(rows, type=pa.int64()), - "file_size": pa.array(size, type=pa.int64()), - "file_name": pa.array([os.path.basename(fp) for fp in path], type=pa.string()), - } - ) # NOTE: We are losing the history of the plan here. 
# This is due to the fact that the logical plan of the write_iceberg returns datafiles but we want to return the above data - return with_operations + return from_pydict(with_operations) @DataframePublicAPI def write_deltalake( diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 95b57e9d10..7693a0a84c 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: import pathlib + from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties @@ -390,7 +391,7 @@ class WriteIceberg(SingleOutputInstruction): base_path: str iceberg_schema: IcebergSchema iceberg_properties: IcebergTableProperties - spec_id: int + partition_spec: IcebergPartitionSpec io_config: IOConfig | None def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: @@ -418,7 +419,7 @@ def _handle_file_write(self, input: MicroPartition) -> MicroPartition: base_path=self.base_path, schema=self.iceberg_schema, properties=self.iceberg_properties, - spec_id=self.spec_id, + partition_spec=self.partition_spec, io_config=self.io_config, ) diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 34da186731..273ee0dc49 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -53,6 +53,7 @@ if TYPE_CHECKING: import pathlib + from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties @@ -120,7 +121,7 @@ def iceberg_write( base_path: str, iceberg_schema: IcebergSchema, iceberg_properties: IcebergTableProperties, - spec_id: int, + partition_spec: IcebergPartitionSpec, io_config: IOConfig | None, ) -> InProgressPhysicalPlan[PartitionT]: """Write the results of `child_plan` into pyiceberg data files described by `write_info`.""" @@ -131,7 +132,7 @@ def iceberg_write( base_path=base_path, iceberg_schema=iceberg_schema, iceberg_properties=iceberg_properties, - spec_id=spec_id, + partition_spec=partition_spec, io_config=io_config, ), ) diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 3c85ad4149..351d27f3bb 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -19,6 +19,7 @@ from daft.runners.partitioning import PartitionT if TYPE_CHECKING: + from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties @@ -344,7 +345,7 @@ def write_iceberg( base_path: str, iceberg_schema: IcebergSchema, iceberg_properties: IcebergTableProperties, - spec_id: int, + partition_spec: IcebergPartitionSpec, io_config: IOConfig | None, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: return physical_plan.iceberg_write( @@ -352,7 +353,7 @@ def write_iceberg( base_path=base_path, iceberg_schema=iceberg_schema, iceberg_properties=iceberg_properties, - spec_id=spec_id, + partition_spec=partition_spec, io_config=io_config, ) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index d5482220df..a2232d9a24 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -3307,7 +3307,7 @@ def days(self) -> Expression: """Partitioning Transform that returns the 
number of days since epoch (1970-01-01) Returns: - Expression: Date Type Expression + Expression: Int32 Expression in days """ return Expression._from_pyexpr(self._expr.partitioning_days()) diff --git a/daft/iceberg/iceberg_scan.py b/daft/iceberg/iceberg_scan.py index ee60eda888..ac47116689 100644 --- a/daft/iceberg/iceberg_scan.py +++ b/daft/iceberg/iceberg_scan.py @@ -45,10 +45,8 @@ def _iceberg_partition_field_to_daft_partition_field( source_name, DataType.from_arrow_type(schema_to_pyarrow(iceberg_schema.find_type(source_name))) ) transform = pfield.transform - iceberg_result_type = transform.result_type(source_field.field_type) - arrow_result_type = schema_to_pyarrow(iceberg_result_type) - daft_result_type = DataType.from_arrow_type(arrow_result_type) - result_field = Field.create(name, daft_result_type) + source_type = DataType.from_arrow_type(schema_to_pyarrow(source_field.field_type)) + from pyiceberg.transforms import ( BucketTransform, DayTransform, @@ -62,22 +60,33 @@ def _iceberg_partition_field_to_daft_partition_field( tfm = None if isinstance(transform, IdentityTransform): tfm = PartitionTransform.identity() + result_type = source_type elif isinstance(transform, YearTransform): tfm = PartitionTransform.year() + result_type = DataType.int32() elif isinstance(transform, MonthTransform): tfm = PartitionTransform.month() + result_type = DataType.int32() elif isinstance(transform, DayTransform): tfm = PartitionTransform.day() + # pyiceberg uses date as the result type of a day transform, which is incorrect + # so we cannot use transform.result_type() here + result_type = DataType.int32() elif isinstance(transform, HourTransform): tfm = PartitionTransform.hour() + result_type = DataType.int32() elif isinstance(transform, BucketTransform): n = transform.num_buckets tfm = PartitionTransform.iceberg_bucket(n) + result_type = DataType.int32() elif isinstance(transform, TruncateTransform): w = transform.width tfm = PartitionTransform.iceberg_truncate(w) + result_type = source_type else: warnings.warn(f"{transform} not implemented, Please make an issue!") + result_type = source_type + result_field = Field.create(name, result_type) return make_partition_field(result_field, daft_field, transform=tfm) diff --git a/daft/iceberg/iceberg_write.py b/daft/iceberg/iceberg_write.py new file mode 100644 index 0000000000..8bc4d1431b --- /dev/null +++ b/daft/iceberg/iceberg_write.py @@ -0,0 +1,239 @@ +import datetime +import uuid +import warnings +from typing import TYPE_CHECKING, Any, Iterator, List, Tuple + +from daft import Expression, col +from daft.table import MicroPartition +from daft.table.partitioning import PartitionedTable, partition_strings_to_path, partition_values_to_string + +if TYPE_CHECKING: + import pyarrow as pa + from pyiceberg.manifest import DataFile + from pyiceberg.partitioning import PartitionField as IcebergPartitionField + from pyiceberg.schema import Schema as IcebergSchema + from pyiceberg.table import TableProperties as IcebergTableProperties + from pyiceberg.typedef import Record as IcebergRecord + + +def add_missing_columns(table: MicroPartition, schema: "pa.Schema") -> MicroPartition: + """Add null values for columns in the schema that are missing from the table.""" + + import pyarrow as pa + + existing_columns = set(table.column_names()) + + columns = {} + for name in schema.names: + if name in existing_columns: + columns[name] = table.get_column(name) + else: + columns[name] = pa.nulls(len(table), type=schema.field(name).type) + + return 
MicroPartition.from_pydict(columns) + + +def coerce_pyarrow_table_to_schema(pa_table: "pa.Table", schema: "pa.Schema") -> "pa.Table": + """Coerces a PyArrow table to the supplied schema + + 1. For each field in `pa_table`, cast it to the field in `input_schema` if one with a matching name + is available + 2. Reorder the fields in the casted table to the supplied schema, dropping any fields in `pa_table` + that do not exist in the supplied schema + 3. If any fields in the supplied schema are not present, add a null array of the correct type + + This ensures that we populate field_id for iceberg as well as fill in null values where needed + This might break for nested fields with large_strings + we should test that behavior + + Args: + pa_table (pa.Table): Table to coerce + schema (pa.Schema): Iceberg schema to coerce to + + Returns: + pa.Table: Table with schema == `schema` + """ + import pyarrow as pa + + input_schema_names = set(schema.names) + + # Perform casting of types to provided schema's types + cast_to_schema = [ + (schema.field(inferred_field.name) if inferred_field.name in input_schema_names else inferred_field) + for inferred_field in pa_table.schema + ] + casted_table = pa_table.cast(pa.schema(cast_to_schema)) + + # Reorder and pad columns with a null column where necessary + pa_table_column_names = set(casted_table.column_names) + columns = [] + for name in schema.names: + if name in pa_table_column_names: + columns.append(casted_table[name]) + else: + columns.append(pa.nulls(len(casted_table), type=schema.field(name).type)) + return pa.table(columns, schema=schema) + + +def partition_field_to_expr(field: "IcebergPartitionField", schema: "IcebergSchema") -> Expression: + from pyiceberg.transforms import ( + BucketTransform, + DayTransform, + HourTransform, + IdentityTransform, + MonthTransform, + TruncateTransform, + YearTransform, + ) + + part_col = col(schema.find_field(field.source_id).name) + + if isinstance(field.transform, IdentityTransform): + transform_expr = part_col + elif isinstance(field.transform, YearTransform): + transform_expr = part_col.partitioning.years() + elif isinstance(field.transform, MonthTransform): + transform_expr = part_col.partitioning.months() + elif isinstance(field.transform, DayTransform): + transform_expr = part_col.partitioning.days() + elif isinstance(field.transform, HourTransform): + transform_expr = part_col.partitioning.hours() + elif isinstance(field.transform, BucketTransform): + transform_expr = part_col.partitioning.iceberg_bucket(field.transform.num_buckets) + elif isinstance(field.transform, TruncateTransform): + transform_expr = part_col.partitioning.iceberg_truncate(field.transform.width) + else: + warnings.warn(f"{field.transform} not implemented, Please make an issue!") + transform_expr = part_col + + # currently the partitioning expressions change the name of the column + # so we need to alias it back to the original column name + return transform_expr + + +def to_partition_representation(value: Any): + """ + Converts a partition value to the format expected by Iceberg metadata. + Most transforms already do this, but the identity transforms preserve the original value type so we need to convert it. 
+ """ + + if value is None: + return None + + if isinstance(value, datetime.datetime): + # Convert to microseconds since epoch + return (value - datetime.datetime(1970, 1, 1)) // datetime.timedelta(microseconds=1) + elif isinstance(value, datetime.date): + # Convert to days since epoch + return (value - datetime.date(1970, 1, 1)) // datetime.timedelta(days=1) + elif isinstance(value, datetime.time): + # Convert to microseconds since midnight + return (value.hour * 60 * 60 + value.minute * 60 + value.second) * 1_000_000 + value.microsecond + elif isinstance(value, uuid.UUID): + return str(value) + else: + return value + + +class IcebergWriteVisitors: + class FileVisitor: + def __init__(self, parent: "IcebergWriteVisitors", partition_record: "IcebergRecord"): + self.parent = parent + self.partition_record = partition_record + + def __call__(self, written_file): + import pyiceberg + from packaging.version import parse + from pyiceberg.io.pyarrow import ( + compute_statistics_plan, + parquet_path_to_id_mapping, + ) + from pyiceberg.manifest import DataFile, DataFileContent + from pyiceberg.manifest import FileFormat as IcebergFileFormat + + file_path = f"{self.parent.protocol}://{written_file.path}" + size = written_file.size + metadata = written_file.metadata + + kwargs = { + "content": DataFileContent.DATA, + "file_path": file_path, + "file_format": IcebergFileFormat.PARQUET, + "partition": self.partition_record, + "file_size_in_bytes": size, + # After this has been fixed: + # https://github.com/apache/iceberg-python/issues/271 + # "sort_order_id": task.sort_order_id, + "sort_order_id": None, + # Just copy these from the table for now + "spec_id": self.parent.spec_id, + "equality_ids": None, + "key_metadata": None, + } + + if parse(pyiceberg.__version__) >= parse("0.7.0"): + from pyiceberg.io.pyarrow import data_file_statistics_from_parquet_metadata + + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(self.parent.schema, self.parent.properties), + parquet_column_mapping=parquet_path_to_id_mapping(self.parent.schema), + ) + + data_file = DataFile( + **{ + **kwargs, + **statistics.to_serialized_dict(), + } + ) + else: + from pyiceberg.io.pyarrow import fill_parquet_file_metadata + + data_file = DataFile(**kwargs) + + fill_parquet_file_metadata( + data_file=data_file, + parquet_metadata=metadata, + stats_columns=compute_statistics_plan(self.parent.schema, self.parent.properties), + parquet_column_mapping=parquet_path_to_id_mapping(self.parent.schema), + ) + + self.parent.data_files.append(data_file) + + def __init__(self, protocol: str, spec_id: int, schema: "IcebergSchema", properties: "IcebergTableProperties"): + self.data_files: List[DataFile] = [] + self.protocol = protocol + self.spec_id = spec_id + self.schema = schema + self.properties = properties + + def visitor(self, partition_record: "IcebergRecord") -> "IcebergWriteVisitors.FileVisitor": + return self.FileVisitor(self, partition_record) + + def to_metadata(self) -> MicroPartition: + return MicroPartition.from_pydict({"data_file": self.data_files}) + + +def partitioned_table_to_iceberg_iter( + partitioned: PartitionedTable, root_path: str, schema: "pa.Schema" +) -> Iterator[Tuple["pa.Table", str, "IcebergRecord"]]: + from pyiceberg.typedef import Record as IcebergRecord + + partition_values = partitioned.partition_values() + + if partition_values: + partition_strings = partition_values_to_string(partition_values, partition_null_fallback="null").to_pylist() + 
partition_values_list = partition_values.to_pylist() + + for table, part_vals, part_strs in zip(partitioned.partitions(), partition_values_list, partition_strings): + iceberg_part_vals = {k: to_partition_representation(v) for k, v in part_vals.items()} + part_record = IcebergRecord(**iceberg_part_vals) + part_path = partition_strings_to_path(root_path, part_strs) + + arrow_table = coerce_pyarrow_table_to_schema(table.to_arrow(), schema) + + yield arrow_table, part_path, part_record + else: + arrow_table = coerce_pyarrow_table_to_schema(partitioned.table.to_arrow(), schema) + + yield arrow_table, root_path, IcebergRecord() diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 218e521d6d..31b347295d 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -289,12 +289,12 @@ def write_iceberg(self, table: IcebergTable) -> LogicalPlanBuilder: name = ".".join(table.name()) location = f"{table.location()}/data" - spec_id = table.spec().spec_id + partition_spec = table.spec() schema = table.schema() props = table.properties columns = [col.name for col in schema.columns] io_config = _convert_iceberg_file_io_properties_to_io_config(table.io.properties) - builder = self._builder.iceberg_write(name, location, spec_id, schema, props, columns, io_config) + builder = self._builder.iceberg_write(name, location, partition_spec, schema, props, columns, io_config) return LogicalPlanBuilder(builder) def write_deltalake( diff --git a/daft/table/partitioning.py b/daft/table/partitioning.py new file mode 100644 index 0000000000..fac841d346 --- /dev/null +++ b/daft/table/partitioning.py @@ -0,0 +1,65 @@ +from typing import Dict, List, Optional + +from daft.expressions import ExpressionsProjection +from daft.series import Series + +from .micropartition import MicroPartition + + +def partition_strings_to_path(root_path: str, parts: Dict[str, str]): + postfix = "/".join(f"{key}={value}" for key, value in parts.items()) + return f"{root_path}/{postfix}" + + +def partition_values_to_string( + partition_values: MicroPartition, partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__" +) -> MicroPartition: + """Convert partition values to human-readable string representation, filling nulls with `partition_null_fallback`.""" + default_part = Series.from_pylist([partition_null_fallback]) + pkey_names = partition_values.column_names() + + partition_strings = {} + + for c in pkey_names: + column = partition_values.get_column(c) + string_names = column._to_str_values() + null_filled = column.is_null().if_else(default_part, string_names) + partition_strings[c] = null_filled.to_pylist() + + return MicroPartition.from_pydict(partition_strings) + + +class PartitionedTable: + def __init__(self, table: MicroPartition, partition_keys: Optional[ExpressionsProjection]): + self.table = table + self.partition_keys = partition_keys + self._partitions = None + self._partition_values = None + + def _create_partitions(self): + if self.partition_keys is None or len(self.partition_keys) == 0: + self._partitions = [self.table] + self._partition_values = None + else: + self._partitions, self._partition_values = self.table.partition_by_value(partition_keys=self.partition_keys) + + def partitions(self) -> List[MicroPartition]: + """ + Returns a list of MicroPartitions representing the table partitioned by the partition keys. + + If the table is not partitioned, returns the original table as the single element in the list. 
+ """ + if self._partitions is None: + self._create_partitions() + return self._partitions # type: ignore + + def partition_values(self) -> Optional[MicroPartition]: + """ + Returns the partition values, with each row corresponding to the partition at the same index in PartitionedTable.partitions(). + + If the table is not partitioned, returns None. + + """ + if self._partition_values is None: + self._create_partitions() + return self._partition_values diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 97c274a1f3..ee366946b2 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -6,7 +6,7 @@ import random import time from functools import partial -from typing import IO, TYPE_CHECKING, Any, Union +from typing import IO, TYPE_CHECKING, Any, Iterator, Union from uuid import uuid4 from daft.context import get_context @@ -23,7 +23,6 @@ PythonStorageConfig, StorageConfig, ) -from daft.datatype import DataType from daft.dependencies import pa, pacsv, pads, pajson, pq from daft.expressions import ExpressionsProjection from daft.filesystem import ( @@ -38,13 +37,17 @@ TableReadOptions, ) from daft.series import Series -from daft.table import MicroPartition +from daft.sql.sql_connection import SQLConnection + +from .micropartition import MicroPartition +from .partitioning import PartitionedTable, partition_strings_to_path, partition_values_to_string FileInput = Union[pathlib.Path, str, IO[bytes]] if TYPE_CHECKING: from collections.abc import Callable, Generator + from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties @@ -397,6 +400,53 @@ def read_csv( return _cast_table_to_schema(daft_table, read_options=read_options, schema=schema) +def partitioned_table_to_hive_iter(partitioned: PartitionedTable, root_path: str) -> Iterator[tuple[pa.Table, str]]: + partition_values = partitioned.partition_values() + + if partition_values: + partition_strings = partition_values_to_string(partition_values).to_pylist() + + for part_table, part_strs in zip(partitioned.partitions(), partition_strings): + part_path = partition_strings_to_path(root_path, part_strs) + arrow_table = part_table.to_arrow() + + yield arrow_table, part_path + else: + yield partitioned.table.to_arrow(), root_path + + +class TabularWriteVisitors: + class FileVisitor: + def __init__(self, parent: TabularWriteVisitors, idx: int): + self.parent = parent + self.idx = idx + + def __call__(self, written_file): + self.parent.paths.append(written_file.path) + self.parent.partition_indices.append(self.idx) + + def __init__(self, partition_values: MicroPartition | None, path_key: str = "path"): + self.paths: list[str] = [] + self.partition_indices: list[int] = [] + self.partition_values = partition_values + self.path_key = path_key + + def visitor(self, partition_idx: int) -> TabularWriteVisitors.FileVisitor: + return self.FileVisitor(self, partition_idx) + + def to_metadata(self) -> MicroPartition: + metadata: dict[str, Any] = {self.path_key: self.paths} + + if self.partition_values: + partition_indices = Series.from_pylist(self.partition_indices) + partition_values_for_paths = self.partition_values.take(partition_indices) + + for c in partition_values_for_paths.column_names(): + metadata[c] = partition_values_for_paths.get_column(c) + + return MicroPartition.from_pydict(metadata) + + def write_tabular( table: MicroPartition, file_format: FileFormat, @@ -405,7 +455,6 @@ def 
write_tabular( partition_cols: ExpressionsProjection | None = None, compression: str | None = None, io_config: IOConfig | None = None, - partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__", ) -> MicroPartition: [resolved_path], fs = _resolve_paths_and_filesystem(path, io_config=io_config) if isinstance(path, pathlib.Path): @@ -418,35 +467,6 @@ def write_tabular( is_local_fs = canonicalized_protocol == "file" - tables_to_write: list[MicroPartition] - part_keys_postfix_per_table: list[str | None] - partition_values = None - if partition_cols and len(partition_cols) > 0: - default_part = Series.from_pylist([partition_null_fallback]) - split_tables, partition_values = table.partition_by_value(partition_keys=partition_cols) - assert len(split_tables) == len(partition_values) - pkey_names = partition_values.column_names() - - values_string_values = [] - - for c in pkey_names: - column = partition_values.get_column(c) - string_names = column._to_str_values() - null_filled = column.is_null().if_else(default_part, string_names) - values_string_values.append(null_filled.to_pylist()) - - part_keys_postfix_per_table = [] - for i in range(len(partition_values)): - postfix = "/".join(f"{pkey}={values[i]}" for pkey, values in zip(pkey_names, values_string_values)) - part_keys_postfix_per_table.append(postfix) - tables_to_write = split_tables - else: - tables_to_write = [table] - part_keys_postfix_per_table = [None] - - visited_paths = [] - partition_idx = [] - execution_config = get_context().daft_execution_config TARGET_ROW_GROUP_SIZE = execution_config.parquet_target_row_group_size @@ -465,107 +485,56 @@ def write_tabular( else: raise ValueError(f"Unsupported file format {file_format}") - for i, (tab, pf) in enumerate(zip(tables_to_write, part_keys_postfix_per_table)): - full_path = resolved_path - if pf is not None and len(pf) > 0: - full_path = f"{full_path}/{pf}" + partitioned = PartitionedTable(table, partition_cols) + + # I kept this from our original code, but idk why it's the first column name -kevin + path_key = schema.column_names()[0] - arrow_table = tab.to_arrow() + visitors = TabularWriteVisitors(partitioned.partition_values(), path_key) - size_bytes = arrow_table.nbytes + for i, (part_table, part_path) in enumerate(partitioned_table_to_hive_iter(partitioned, resolved_path)): + size_bytes = part_table.nbytes target_num_files = max(math.ceil(size_bytes / target_file_size / inflation_factor), 1) - num_rows = len(arrow_table) + num_rows = len(part_table) rows_per_file = max(math.ceil(num_rows / target_num_files), 1) target_row_groups = max(math.ceil(size_bytes / TARGET_ROW_GROUP_SIZE / inflation_factor), 1) rows_per_row_group = max(min(math.ceil(num_rows / target_row_groups), rows_per_file), 1) - def file_visitor(written_file, i=i): - visited_paths.append(written_file.path) - partition_idx.append(i) - _write_tabular_arrow_table( - arrow_table=arrow_table, - schema=arrow_table.schema, - full_path=full_path, + arrow_table=part_table, + schema=part_table.schema, + full_path=part_path, format=format, opts=opts, fs=fs, rows_per_file=rows_per_file, rows_per_row_group=rows_per_row_group, create_dir=is_local_fs, - file_visitor=file_visitor, + file_visitor=visitors.visitor(i), ) - data_dict: dict[str, Any] = { - schema.column_names()[0]: Series.from_pylist(visited_paths, name=schema.column_names()[0]).cast( - DataType.string() - ) - } - - if partition_values is not None: - partition_idx_series = Series.from_pylist(partition_idx).cast(DataType.int64()) - for c_name in 
partition_values.column_names(): - data_dict[c_name] = partition_values.get_column(c_name).take(partition_idx_series) - return MicroPartition.from_pydict(data_dict) - - -def coerce_pyarrow_table_to_schema(pa_table: pa.Table, input_schema: pa.Schema) -> pa.Table: - """Coerces a PyArrow table to the supplied schema - - 1. For each field in `pa_table`, cast it to the field in `input_schema` if one with a matching name - is available - 2. Reorder the fields in the casted table to the supplied schema, dropping any fields in `pa_table` - that do not exist in the supplied schema - 3. If any fields in the supplied schema are not present, add a null array of the correct type - - Args: - pa_table (pa.Table): Table to coerce - input_schema (pa.Schema): Schema to coerce to - - Returns: - pa.Table: Table with schema == `input_schema` - """ - input_schema_names = set(input_schema.names) - - # Perform casting of types to provided schema's types - cast_to_schema = [ - (input_schema.field(inferred_field.name) if inferred_field.name in input_schema_names else inferred_field) - for inferred_field in pa_table.schema - ] - casted_table = pa_table.cast(pa.schema(cast_to_schema)) - - # Reorder and pad columns with a null column where necessary - pa_table_column_names = set(casted_table.column_names) - columns = [] - for name in input_schema.names: - if name in pa_table_column_names: - columns.append(casted_table[name]) - else: - columns.append(pa.nulls(len(casted_table), type=input_schema.field(name).type)) - return pa.table(columns, schema=input_schema) + return visitors.to_metadata() def write_iceberg( - mp: MicroPartition, + table: MicroPartition, base_path: str, schema: IcebergSchema, properties: IcebergTableProperties, - spec_id: int | None, + partition_spec: IcebergPartitionSpec, io_config: IOConfig | None = None, ): - import pyiceberg - from packaging.version import parse - from pyiceberg.io.pyarrow import ( - compute_statistics_plan, - parquet_path_to_id_mapping, - schema_to_pyarrow, + from pyiceberg.io.pyarrow import schema_to_pyarrow + + from daft.iceberg.iceberg_write import ( + IcebergWriteVisitors, + add_missing_columns, + partition_field_to_expr, + partitioned_table_to_iceberg_iter, ) - from pyiceberg.manifest import DataFile, DataFileContent - from pyiceberg.manifest import FileFormat as IcebergFileFormat - from pyiceberg.typedef import Record [resolved_path], fs = _resolve_paths_and_filesystem(base_path, io_config=io_config) if isinstance(base_path, pathlib.Path): @@ -576,58 +545,6 @@ def write_iceberg( protocol = get_protocol_from_path(path_str) canonicalized_protocol = canonicalize_protocol(protocol) - data_files = [] - - def file_visitor(written_file, protocol=protocol): - file_path = f"{protocol}://{written_file.path}" - size = written_file.size - metadata = written_file.metadata - - kwargs = { - "content": DataFileContent.DATA, - "file_path": file_path, - "file_format": IcebergFileFormat.PARQUET, - "partition": Record(), - "file_size_in_bytes": size, - # After this has been fixed: - # https://github.com/apache/iceberg-python/issues/271 - # "sort_order_id": task.sort_order_id, - "sort_order_id": None, - # Just copy these from the table for now - "spec_id": spec_id, - "equality_ids": None, - "key_metadata": None, - } - - if parse(pyiceberg.__version__) >= parse("0.7.0"): - from pyiceberg.io.pyarrow import data_file_statistics_from_parquet_metadata - - statistics = data_file_statistics_from_parquet_metadata( - parquet_metadata=metadata, - stats_columns=compute_statistics_plan(schema, 
properties), - parquet_column_mapping=parquet_path_to_id_mapping(schema), - ) - - data_file = DataFile( - **{ - **kwargs, - **statistics.to_serialized_dict(), - } - ) - else: - from pyiceberg.io.pyarrow import fill_parquet_file_metadata - - data_file = DataFile(**kwargs) - - fill_parquet_file_metadata( - data_file=data_file, - parquet_metadata=metadata, - stats_columns=compute_statistics_plan(schema, properties), - parquet_column_mapping=parquet_path_to_id_mapping(schema), - ) - - data_files.append(data_file) - is_local_fs = canonicalized_protocol == "file" execution_config = get_context().daft_execution_config @@ -637,43 +554,45 @@ def file_visitor(written_file, protocol=protocol): target_file_size = 512 * 1024 * 1024 TARGET_ROW_GROUP_SIZE = 128 * 1024 * 1024 - arrow_table = mp.to_arrow() + format = pads.ParquetFileFormat() - file_schema = schema_to_pyarrow(schema) + opts = format.make_write_options(compression="zstd", use_compliant_nested_type=False) - # This ensures that we populate field_id for iceberg as well as fill in null values where needed - # This might break for nested fields with large_strings - # we should test that behavior - arrow_table = coerce_pyarrow_table_to_schema(arrow_table, file_schema) + file_schema = schema_to_pyarrow(schema) - size_bytes = arrow_table.nbytes + partition_keys = ExpressionsProjection([partition_field_to_expr(field, schema) for field in partition_spec.fields]) - target_num_files = max(math.ceil(size_bytes / target_file_size / inflation_factor), 1) - num_rows = len(arrow_table) + table = add_missing_columns(table, file_schema) + partitioned = PartitionedTable(table, partition_keys) + visitors = IcebergWriteVisitors(protocol, partition_spec.spec_id, schema, properties) - rows_per_file = max(math.ceil(num_rows / target_num_files), 1) + for part_table, part_path, part_record in partitioned_table_to_iceberg_iter( + partitioned, resolved_path, file_schema + ): + size_bytes = part_table.nbytes - target_row_groups = max(math.ceil(size_bytes / TARGET_ROW_GROUP_SIZE / inflation_factor), 1) - rows_per_row_group = max(min(math.ceil(num_rows / target_row_groups), rows_per_file), 1) + target_num_files = max(math.ceil(size_bytes / target_file_size / inflation_factor), 1) + num_rows = len(part_table) - format = pads.ParquetFileFormat() + rows_per_file = max(math.ceil(num_rows / target_num_files), 1) - opts = format.make_write_options(compression="zstd", use_compliant_nested_type=False) + target_row_groups = max(math.ceil(size_bytes / TARGET_ROW_GROUP_SIZE / inflation_factor), 1) + rows_per_row_group = max(min(math.ceil(num_rows / target_row_groups), rows_per_file), 1) - _write_tabular_arrow_table( - arrow_table=arrow_table, - schema=file_schema, - full_path=resolved_path, - format=format, - opts=opts, - fs=fs, - rows_per_file=rows_per_file, - rows_per_row_group=rows_per_row_group, - create_dir=is_local_fs, - file_visitor=file_visitor, - ) + _write_tabular_arrow_table( + arrow_table=part_table, + schema=file_schema, + full_path=part_path, + format=format, + opts=opts, + fs=fs, + rows_per_file=rows_per_file, + rows_per_row_group=rows_per_row_group, + create_dir=is_local_fs, + file_visitor=visitors.visitor(part_record), + ) - return MicroPartition.from_pydict({"data_file": Series.from_pylist(data_files, name="data_file", pyobj="force")}) + return visitors.to_metadata() def write_deltalake( diff --git a/src/daft-core/src/array/ops/truncate.rs b/src/daft-core/src/array/ops/truncate.rs index 3c1eff33b6..83a5c1b0b4 100644 --- a/src/daft-core/src/array/ops/truncate.rs +++ 
b/src/daft-core/src/array/ops/truncate.rs @@ -10,6 +10,7 @@ use crate::{ logical::Decimal128Array, DaftNumericType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Array, }, + prelude::BinaryArray, }; macro_rules! impl_int_truncate { @@ -67,3 +68,11 @@ impl Utf8Array { Ok(Utf8Array::from((self.name(), Box::new(substring)))) } } + +impl BinaryArray { + pub fn iceberg_truncate(&self, w: i64) -> DaftResult { + let as_arrow = self.as_arrow(); + let substring = arrow2::compute::substring::binary_substring(as_arrow, 0, &Some(w)); + Ok(BinaryArray::from((self.name(), Box::new(substring)))) + } +} diff --git a/src/daft-core/src/series/ops/partitioning.rs b/src/daft-core/src/series/ops/partitioning.rs index 73b942ca00..692316bbda 100644 --- a/src/daft-core/src/series/ops/partitioning.rs +++ b/src/daft-core/src/series/ops/partitioning.rs @@ -1,11 +1,6 @@ use common_error::{DaftError, DaftResult}; -use crate::{ - array::ops::as_arrow::AsArrow, - datatypes::{logical::TimestampArray, DataType, Int32Array, Int64Array, TimeUnit}, - series::{array_impl::IntoSeries, Series}, - with_match_integer_daft_types, -}; +use crate::{array::ops::as_arrow::AsArrow, prelude::*, with_match_integer_daft_types}; impl Series { pub fn partitioning_years(&self) -> DaftResult { @@ -47,7 +42,7 @@ impl Series { ((&years_since_1970 * &months_in_year)? + months_of_this_year)? - month_of_epoch } _ => Err(DaftError::ComputeError(format!( - "Can only run partitioning_years() operation on temporal types, got {}", + "Can only run partitioning_months() operation on temporal types, got {}", self.data_type() ))), }?; @@ -57,26 +52,23 @@ impl Series { } pub fn partitioning_days(&self) -> DaftResult { - let result = match self.data_type() { - DataType::Date => Ok(self.clone()), - DataType::Timestamp(_, None) => { - let ts_array = self.downcast::()?; - Ok(ts_array.date()?.into_series()) + let value = match self.data_type() { + DataType::Date => { + let date_array = self.downcast::()?; + Ok(date_array.physical.clone().into_series()) } - - DataType::Timestamp(tu, Some(_)) => { + DataType::Timestamp(tu, _) => { let array = self.cast(&DataType::Timestamp(*tu, None))?; let ts_array = array.downcast::()?; - Ok(ts_array.date()?.into_series()) + Ok(ts_array.date()?.physical.into_series()) } - _ => Err(DaftError::ComputeError(format!( "Can only run partitioning_days() operation on temporal types, got {}", self.data_type() ))), }?; - Ok(result.rename(format!("{}_days", self.name()))) + Ok(value.rename(format!("{}_days", self.name()))) } pub fn partitioning_hours(&self) -> DaftResult { @@ -126,12 +118,13 @@ impl Series { } DataType::Decimal128(..) 
=> Ok(self.decimal128()?.iceberg_truncate(w)?.into_series()), DataType::Utf8 => Ok(self.utf8()?.iceberg_truncate(w)?.into_series()), + DataType::Binary => Ok(self.binary()?.iceberg_truncate(w)?.into_series()), _ => Err(DaftError::ComputeError(format!( - "Can only run partitioning_iceberg_truncate() operation on integers, decimal and string, got {}", + "Can only run partitioning_iceberg_truncate() operation on integers, decimal, string, and binary, got {}", self.data_type() ))), }?; - Ok(trunc.rename(format!("{}_truncate", self.name()))) + Ok(trunc.rename(format!("{}_trunc", self.name()))) } } diff --git a/src/daft-dsl/src/functions/partitioning/evaluators.rs b/src/daft-dsl/src/functions/partitioning/evaluators.rs index a0622b40a4..1ec5a4ab7f 100644 --- a/src/daft-dsl/src/functions/partitioning/evaluators.rs +++ b/src/daft-dsl/src/functions/partitioning/evaluators.rs @@ -52,12 +52,13 @@ macro_rules! impl_func_evaluator_for_partitioning { } }; } -use DataType::{Date, Int32}; + +use DataType::Int32; use crate::functions::FunctionExpr; impl_func_evaluator_for_partitioning!(YearsEvaluator, years, partitioning_years, Int32); impl_func_evaluator_for_partitioning!(MonthsEvaluator, months, partitioning_months, Int32); -impl_func_evaluator_for_partitioning!(DaysEvaluator, days, partitioning_days, Date); +impl_func_evaluator_for_partitioning!(DaysEvaluator, days, partitioning_days, Int32); impl_func_evaluator_for_partitioning!(HoursEvaluator, hours, partitioning_hours, Int32); pub(super) struct IcebergBucketEvaluator {} @@ -125,10 +126,10 @@ impl FunctionEvaluator for IcebergTruncateEvaluator { [input] => match input.to_field(schema) { Ok(field) => match &field.dtype { DataType::Decimal128(_, _) - | DataType::Utf8 => Ok(Field::new(format!("{}_truncate", field.name), field.dtype)), - v if v.is_integer() => Ok(Field::new(format!("{}_truncate", field.name), field.dtype)), + | DataType::Utf8 | DataType::Binary => Ok(Field::new(format!("{}_trunc", field.name), field.dtype)), + v if v.is_integer() => Ok(Field::new(format!("{}_trunc", field.name), field.dtype)), _ => Err(DaftError::TypeError(format!( - "Expected input to IcebergTruncate to be an Integer, Utf8 or Decimal, got {}", + "Expected input to IcebergTruncate to be an Integer, Utf8, Decimal, or Binary, got {}", field.dtype ))), }, diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 2c6b599303..0098d72405 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -400,7 +400,7 @@ impl LogicalPlanBuilder { &self, table_name: String, table_location: String, - spec_id: i64, + partition_spec: PyObject, iceberg_schema: PyObject, iceberg_properties: PyObject, io_config: Option, @@ -410,7 +410,7 @@ impl LogicalPlanBuilder { catalog: crate::sink_info::CatalogType::Iceberg(IcebergCatalogInfo { table_name, table_location, - spec_id, + partition_spec, iceberg_schema, iceberg_properties, io_config, @@ -724,7 +724,7 @@ impl PyLogicalPlanBuilder { &self, table_name: String, table_location: String, - spec_id: i64, + partition_spec: PyObject, iceberg_schema: PyObject, iceberg_properties: PyObject, catalog_columns: Vec, @@ -735,7 +735,7 @@ impl PyLogicalPlanBuilder { .iceberg_write( table_name, table_location, - spec_id, + partition_spec, iceberg_schema, iceberg_properties, io_config.map(|cfg| cfg.config), diff --git a/src/daft-plan/src/logical_ops/sink.rs b/src/daft-plan/src/logical_ops/sink.rs index 69f331fcb7..2a23292c44 100644 --- a/src/daft-plan/src/logical_ops/sink.rs +++ b/src/daft-plan/src/logical_ops/sink.rs @@ 
-4,6 +4,8 @@ use common_error::DaftResult; use daft_core::prelude::*; use daft_dsl::resolve_exprs; +#[cfg(feature = "python")] +use crate::sink_info::CatalogType; use crate::{sink_info::SinkInfo, LogicalPlan, OutputFileInfo}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -42,7 +44,8 @@ impl Sink { io_config: io_config.clone(), })) } - _ => sink_info, + #[cfg(feature = "python")] + SinkInfo::CatalogInfo(_) => sink_info, }; let fields = match sink_info.as_ref() { @@ -56,11 +59,17 @@ impl Sink { fields } #[cfg(feature = "python")] - SinkInfo::CatalogInfo(..) => { - vec![ - // We have to return datafile since PyIceberg Table is not picklable yet - Field::new("data_file", DataType::Python), - ] + SinkInfo::CatalogInfo(catalog_info) => { + match catalog_info.catalog { + CatalogType::Iceberg(_) => { + vec![ + // We have to return datafile since PyIceberg Table is not picklable yet + Field::new("data_file", DataType::Python), + ] + } + CatalogType::DeltaLake(_) => vec![Field::new("data_file", DataType::Python)], + CatalogType::Lance(_) => vec![Field::new("fragments", DataType::Python)], + } } }; let schema = Schema::new(fields)?.into(); @@ -81,15 +90,15 @@ impl Sink { } #[cfg(feature = "python")] SinkInfo::CatalogInfo(catalog_info) => match &catalog_info.catalog { - crate::sink_info::CatalogType::Iceberg(iceberg_info) => { + CatalogType::Iceberg(iceberg_info) => { res.push(format!("Sink: Iceberg({})", iceberg_info.table_name)); res.extend(iceberg_info.multiline_display()); } - crate::sink_info::CatalogType::DeltaLake(deltalake_info) => { + CatalogType::DeltaLake(deltalake_info) => { res.push(format!("Sink: DeltaLake({})", deltalake_info.path)); res.extend(deltalake_info.multiline_display()); } - crate::sink_info::CatalogType::Lance(lance_info) => { + CatalogType::Lance(lance_info) => { res.push(format!("Sink: Lance({})", lance_info.path)); res.extend(lance_info.multiline_display()); } diff --git a/src/daft-plan/src/sink_info.rs b/src/daft-plan/src/sink_info.rs index e2c4c374e4..b66217d8d2 100644 --- a/src/daft-plan/src/sink_info.rs +++ b/src/daft-plan/src/sink_info.rs @@ -47,7 +47,11 @@ pub enum CatalogType { pub struct IcebergCatalogInfo { pub table_name: String, pub table_location: String, - pub spec_id: i64, + #[serde( + serialize_with = "serialize_py_object", + deserialize_with = "deserialize_py_object" + )] + pub partition_spec: PyObject, #[serde( serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" @@ -67,7 +71,6 @@ impl PartialEq for IcebergCatalogInfo { fn eq(&self, other: &Self) -> bool { self.table_name == other.table_name && self.table_location == other.table_location - && self.spec_id == other.spec_id && self.io_config == other.io_config } } @@ -79,7 +82,6 @@ impl Hash for IcebergCatalogInfo { fn hash(&self, state: &mut H) { self.table_name.hash(state); self.table_location.hash(state); - self.spec_id.hash(state); self.io_config.hash(state); } } @@ -90,7 +92,6 @@ impl IcebergCatalogInfo { let mut res = vec![]; res.push(format!("Table Name = {}", self.table_name)); res.push(format!("Table Location = {}", self.table_location)); - res.push(format!("Spec ID = {}", self.spec_id)); match &self.io_config { None => res.push("IOConfig = None".to_string()), Some(io_config) => res.push(format!("IOConfig = {}", io_config)), diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index f215c8b7cb..2eb66781d9 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -130,6 +130,11 @@ impl 
PartitionIterator { } } +#[cfg(feature = "python")] +fn exprs_to_pyexprs(exprs: &[ExprRef]) -> Vec { + exprs.iter().map(|e| e.clone().into()).collect() +} + #[allow(clippy::too_many_arguments)] #[cfg(feature = "python")] fn tabular_write( @@ -142,11 +147,6 @@ fn tabular_write( partition_cols: &Option>, io_config: &Option, ) -> PyResult { - let part_cols = partition_cols.as_ref().map(|cols| { - cols.iter() - .map(|e| e.clone().into()) - .collect::>() - }); let py_iter = py .import_bound(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? .getattr(pyo3::intern!(py, "write_file"))? @@ -156,7 +156,7 @@ fn tabular_write( PySchema::from(schema.clone()), root_dir, compression.clone(), - part_cols, + partition_cols.as_ref().map(|cols| exprs_to_pyexprs(cols)), io_config .as_ref() .map(|cfg| common_io_config::python::IOConfig { @@ -181,7 +181,7 @@ fn iceberg_write( &iceberg_info.table_location, &iceberg_info.iceberg_schema, &iceberg_info.iceberg_properties, - iceberg_info.spec_id, + &iceberg_info.partition_spec, iceberg_info .io_config .as_ref() diff --git a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py index c209af7df0..46db61d47e 100644 --- a/tests/cookbook/test_write.py +++ b/tests/cookbook/test_write.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import date, datetime +from datetime import datetime import pyarrow as pa import pytest @@ -97,7 +97,7 @@ def test_parquet_write_with_partitioning_readback_values(tmp_path): ( daft.col("date").partitioning.days(), "date_days", - [date(2024, 1, 1), date(2024, 2, 1), date(2024, 3, 1), date(2024, 4, 1), date(2024, 5, 1)], + [19723, 19754, 19783, 19814, 19844], ), (daft.col("date").partitioning.hours(), "date_hours", [473352, 474096, 474792, 475536, 476256]), (daft.col("date").partitioning.months(), "date_months", [648, 649, 650, 651, 652]), @@ -128,7 +128,7 @@ def test_parquet_write_with_iceberg_date_partitioning(exp, key, answer, tmp_path "exp,key,answer", [ (daft.col("id").partitioning.iceberg_bucket(10), "id_bucket", [0, 3, 5, 6, 8]), - (daft.col("id").partitioning.iceberg_truncate(10), "id_truncate", [0, 10, 20, 40]), + (daft.col("id").partitioning.iceberg_truncate(10), "id_trunc", [0, 10, 20, 40]), ], ) def test_parquet_write_with_iceberg_bucket_and_trunc(exp, key, answer, tmp_path): diff --git a/tests/io/iceberg/test_iceberg_writes.py b/tests/io/iceberg/test_iceberg_writes.py index f15247953c..aab68a0c5c 100644 --- a/tests/io/iceberg/test_iceberg_writes.py +++ b/tests/io/iceberg/test_iceberg_writes.py @@ -1,5 +1,8 @@ from __future__ import annotations +import datetime +import decimal + import pyarrow as pa import pytest @@ -18,6 +21,31 @@ from pyiceberg.catalog.sql import SqlCatalog +from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.transforms import ( + BucketTransform, + DayTransform, + HourTransform, + IdentityTransform, + MonthTransform, + TruncateTransform, + YearTransform, +) +from pyiceberg.types import ( + BinaryType, + BooleanType, + DateType, + DecimalType, + DoubleType, + ListType, + LongType, + MapType, + NestedField, + StringType, + StructType, + TimestampType, +) import daft @@ -35,69 +63,113 @@ def local_catalog(tmpdir): return catalog -def test_read_after_write_append(local_catalog): +@pytest.fixture( + scope="function", + params=[ + pytest.param((UNPARTITIONED_PARTITION_SPEC, 1), id="unpartitioned"), + pytest.param( + (PartitionSpec(PartitionField(source_id=1, field_id=1000, 
transform=IdentityTransform(), name="x")), 5), + id="identity_partitioned", + ), + pytest.param( + (PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(4), name="x")), 3), + id="bucket_partitioned", + ), + pytest.param( + (PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(2), name="x")), 3), + id="truncate_partitioned", + ), + ], +) +def simple_local_table(request, local_catalog): + partition_spec, num_partitions = request.param + + schema = Schema( + NestedField(field_id=1, name="x", type=LongType()), + ) + + table = local_catalog.create_table("default.test", schema, partition_spec=partition_spec) + return table, num_partitions + + +def test_read_after_write_append(simple_local_table): + table, num_partitions = simple_local_table + df = daft.from_pydict({"x": [1, 2, 3, 4, 5]}) as_arrow = df.to_arrow() - table = local_catalog.create_table("default.test", as_arrow.schema) result = df.write_iceberg(table) as_dict = result.to_pydict() - assert as_dict["operation"] == ["ADD"] - assert as_dict["rows"] == [5] + assert all(op == "ADD" for op in as_dict["operation"]), as_dict["operation"] + assert sum(as_dict["rows"]) == 5, as_dict["rows"] + assert len(as_dict["operation"]) == num_partitions read_back = daft.read_iceberg(table) - assert as_arrow == read_back.to_arrow() + assert as_arrow == read_back.to_arrow().sort_by("x") + +def test_read_after_write_overwrite(simple_local_table): + table, num_partitions = simple_local_table -def test_read_after_write_overwrite(local_catalog): df = daft.from_pydict({"x": [1, 2, 3, 4, 5]}) as_arrow = df.to_arrow() - table = local_catalog.create_table("default.test", as_arrow.schema) result = df.write_iceberg(table) as_dict = result.to_pydict() - assert as_dict["operation"] == ["ADD"] - assert as_dict["rows"] == [5] + assert all(op == "ADD" for op in as_dict["operation"]), as_dict["operation"] + assert sum(as_dict["rows"]) == 5, as_dict["rows"] + assert len(as_dict["operation"]) == num_partitions # write again (in append) result = df.write_iceberg(table) as_dict = result.to_pydict() - assert as_dict["operation"] == ["ADD"] - assert as_dict["rows"] == [5] + assert all(op == "ADD" for op in as_dict["operation"]), as_dict["operation"] + assert sum(as_dict["rows"]) == 5, as_dict["rows"] + assert len(as_dict["operation"]) == num_partitions read_back = daft.read_iceberg(table) - assert pa.concat_tables([as_arrow, as_arrow]) == read_back.to_arrow() + assert pa.concat_tables([as_arrow, as_arrow]).sort_by("x") == read_back.to_arrow().sort_by("x") # write again (in overwrite) result = df.write_iceberg(table, mode="overwrite") as_dict = result.to_pydict() - assert as_dict["operation"] == ["ADD", "DELETE", "DELETE"] - assert as_dict["rows"] == [5, 5, 5] + assert len(as_dict["operation"]) == 3 * num_partitions + assert all(op == "ADD" for op in as_dict["operation"][:num_partitions]), as_dict["operation"][:num_partitions] + assert sum(as_dict["rows"][:num_partitions]) == 5, as_dict["rows"][:num_partitions] + assert all(op == "DELETE" for op in as_dict["operation"][num_partitions:]), as_dict["operation"][num_partitions:] + assert sum(as_dict["rows"][num_partitions : 2 * num_partitions]) == 5, as_dict["rows"][ + num_partitions : 2 * num_partitions + ] + assert sum(as_dict["rows"][2 * num_partitions :]) == 5, as_dict["rows"][2 * num_partitions :] read_back = daft.read_iceberg(table) - assert as_arrow == read_back.to_arrow() + assert as_arrow == read_back.to_arrow().sort_by("x") -def 
test_read_and_overwrite(local_catalog): +def test_read_and_overwrite(simple_local_table): + table, num_partitions = simple_local_table + df = daft.from_pydict({"x": [1, 2, 3, 4, 5]}) - as_arrow = df.to_arrow() - table = local_catalog.create_table("default.test", as_arrow.schema) result = df.write_iceberg(table) as_dict = result.to_pydict() - assert as_dict["operation"] == ["ADD"] - assert as_dict["rows"] == [5] + assert all(op == "ADD" for op in as_dict["operation"]), as_dict["operation"] + assert sum(as_dict["rows"]) == 5, as_dict["rows"] + assert len(as_dict["operation"]) == num_partitions - df = daft.read_iceberg(table).with_column("x", daft.col("x") + 1) + df = daft.from_pydict({"x": [1, 1, 1, 1, 1]}) result = df.write_iceberg(table, mode="overwrite") as_dict = result.to_pydict() - assert as_dict["operation"] == ["ADD", "DELETE"] - assert as_dict["rows"] == [5, 5] + assert len(as_dict["operation"]) == num_partitions + 1 + assert as_dict["operation"][0] == "ADD" + assert as_dict["rows"][0] == 5 + assert all(op == "DELETE" for op in as_dict["operation"][1:]), as_dict["operation"][1:] + assert sum(as_dict["rows"][1:]) == 5, as_dict["rows"][1:] read_back = daft.read_iceberg(table) - assert daft.from_pydict({"x": [2, 3, 4, 5, 6]}).to_arrow() == read_back.to_arrow() + assert df.to_arrow() == read_back.to_arrow().sort_by("x") -def test_missing_columns_write(local_catalog): +def test_missing_columns_write(simple_local_table): + table, num_partitions = simple_local_table + df = daft.from_pydict({"x": [1, 2, 3, 4, 5]}) - as_arrow = df.to_arrow() - table = local_catalog.create_table("default.test", as_arrow.schema) df = daft.from_pydict({"y": [1, 2, 3, 4, 5]}) result = df.write_iceberg(table) @@ -108,21 +180,22 @@ def test_missing_columns_write(local_catalog): assert read_back.to_pydict() == {"x": [None] * 5} -def test_too_many_columns_write(local_catalog): +def test_too_many_columns_write(simple_local_table): + table, num_partitions = simple_local_table + df = daft.from_pydict({"x": [1, 2, 3, 4, 5]}) as_arrow = df.to_arrow() - table = local_catalog.create_table("default.test", as_arrow.schema) df = daft.from_pydict({"x": [1, 2, 3, 4, 5], "y": [6, 7, 8, 9, 10]}) result = df.write_iceberg(table) as_dict = result.to_pydict() - assert as_dict["operation"] == ["ADD"] - assert as_dict["rows"] == [5] + assert len(as_dict["operation"]) == num_partitions + assert all(op == "ADD" for op in as_dict["operation"]), as_dict["operation"] + assert sum(as_dict["rows"]) == 5, as_dict["rows"] read_back = daft.read_iceberg(table) - assert as_arrow == read_back.to_arrow() + assert as_arrow == read_back.to_arrow().sort_by("x") -@pytest.mark.skip def test_read_after_write_nested_fields(local_catalog): # We need to cast Large Types such as LargeList and LargeString to the i32 variants df = daft.from_pydict({"x": [["a", "b"], ["c", "d", "e"]]}) @@ -134,3 +207,212 @@ def test_read_after_write_nested_fields(local_catalog): assert as_dict["rows"] == [2] read_back = daft.read_iceberg(table) assert as_arrow == read_back.to_arrow() + + +@pytest.fixture +def complex_table() -> tuple[pa.Table, Schema]: + table = pa.table( + { + "int": [1, 2, 3], + "float": [1.1, 2.2, 3.3], + "string": ["foo", "bar", "baz"], + "binary": [b"foo", b"bar", b"baz"], + "boolean": [True, False, True], + "timestamp": [ + datetime.datetime(2024, 2, 10), + datetime.datetime(2024, 2, 11), + datetime.datetime(2024, 2, 12), + ], + "date": [datetime.date(2024, 2, 10), datetime.date(2024, 2, 11), datetime.date(2024, 2, 12)], + "decimal": pa.array( + 
[decimal.Decimal("1234.567"), decimal.Decimal("1233.456"), decimal.Decimal("1232.345")], + type=pa.decimal128(7, 3), + ), + "list": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "struct": [{"x": 1, "y": False}, {"y": True, "z": "foo"}, {"x": 5, "z": "bar"}], + "map": pa.array( + [[("x", 1), ("y", 0)], [("a", 2), ("b", 45)], [("c", 4), ("d", 18)]], + type=pa.map_(pa.string(), pa.int64()), + ), + } + ) + + schema = Schema( + NestedField(field_id=1, name="int", type=LongType()), + NestedField(field_id=2, name="float", type=DoubleType()), + NestedField(field_id=3, name="string", type=StringType()), + NestedField(field_id=4, name="binary", type=BinaryType()), + NestedField(field_id=5, name="boolean", type=BooleanType()), + NestedField(field_id=6, name="timestamp", type=TimestampType()), + NestedField(field_id=7, name="date", type=DateType()), + NestedField(field_id=8, name="decimal", type=DecimalType(7, 3)), + NestedField(field_id=9, name="list", type=ListType(element_id=20, element=LongType())), + NestedField( + field_id=10, + name="struct", + type=StructType( + NestedField(field_id=11, name="x", type=LongType()), + NestedField(field_id=12, name="y", type=BooleanType()), + NestedField(field_id=13, name="z", type=StringType()), + ), + ), + NestedField( + field_id=14, + name="map", + type=MapType(key_id=21, key_type=StringType(), value_id=22, value_type=LongType()), + ), + ) + + return table, schema + + +@pytest.mark.parametrize( + "partition_spec,num_partitions", + [ + pytest.param(UNPARTITIONED_PARTITION_SPEC, 1, id="unpartitioned"), + pytest.param( + PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="int")), + 3, + id="int_identity_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(2), name="int")), + 2, + id="int_bucket_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(2), name="int")), + 2, + id="int_truncate_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="float")), + 3, + id="float_identity_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="string")), + 3, + id="string_identity_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=BucketTransform(2), name="string")), + 2, + id="string_bucket_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=TruncateTransform(2), name="string")), + 2, + id="string_truncate_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=4, field_id=1000, transform=IdentityTransform(), name="binary")), + 3, + id="binary_identity_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=4, field_id=1000, transform=BucketTransform(2), name="binary")), + 2, + id="binary_bucket_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=4, field_id=1000, transform=TruncateTransform(2), name="binary")), + 2, + id="binary_truncate_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=5, field_id=1000, transform=IdentityTransform(), name="boolean")), + 2, + id="bool_identity_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=6, field_id=1000, transform=IdentityTransform(), name="timestamp")), + 3, + id="datetime_identity_partitioned", + ), + 
pytest.param( + PartitionSpec(PartitionField(source_id=6, field_id=1000, transform=BucketTransform(2), name="timestamp")), + 1, + id="datetime_bucket_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=6, field_id=1000, transform=YearTransform(), name="timestamp")), + 1, + id="datetime_year_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=6, field_id=1000, transform=MonthTransform(), name="timestamp")), + 1, + id="datetime_month_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=6, field_id=1000, transform=DayTransform(), name="timestamp")), + 3, + id="datetime_day_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=6, field_id=1000, transform=HourTransform(), name="timestamp")), + 3, + id="datetime_hour_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=7, field_id=1000, transform=IdentityTransform(), name="date")), + 3, + id="date_identity_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=7, field_id=1000, transform=BucketTransform(2), name="date")), + 2, + id="date_bucket_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=7, field_id=1000, transform=YearTransform(), name="date")), + 1, + id="date_year_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=7, field_id=1000, transform=MonthTransform(), name="date")), + 1, + id="date_month_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=7, field_id=1000, transform=DayTransform(), name="date")), + 3, + id="date_day_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=8, field_id=1000, transform=IdentityTransform(), name="decimal")), + 3, + id="decimal_identity_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=8, field_id=1000, transform=BucketTransform(2), name="decimal")), + 1, + id="decimal_bucket_partitioned", + ), + pytest.param( + PartitionSpec(PartitionField(source_id=8, field_id=1000, transform=TruncateTransform(2), name="decimal")), + 3, + id="decimal_truncate_partitioned", + ), + pytest.param( + PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=BucketTransform(2), name="int"), + PartitionField(source_id=3, field_id=1000, transform=TruncateTransform(2), name="string"), + ), + 3, + id="double_partitioned", + ), + ], +) +def test_complex_table_write_read(local_catalog, complex_table, partition_spec, num_partitions): + pa_table, schema = complex_table + table = local_catalog.create_table("default.test", schema, partition_spec=partition_spec) + df = daft.from_arrow(pa_table) + result = df.write_iceberg(table) + as_dict = result.to_pydict() + assert len(as_dict["operation"]) == num_partitions + assert all(op == "ADD" for op in as_dict["operation"]), as_dict["operation"] + assert sum(as_dict["rows"]) == 3, as_dict["rows"] + read_back = daft.read_iceberg(table) + assert df.to_arrow() == read_back.to_arrow().sort_by("int") diff --git a/tests/series/test_partitioning.py b/tests/series/test_partitioning.py index 10e8eb700f..dea9ee90a8 100644 --- a/tests/series/test_partitioning.py +++ b/tests/series/test_partitioning.py @@ -1,9 +1,11 @@ from __future__ import annotations -from datetime import date, datetime +from datetime import date, datetime, time from decimal import Decimal from itertools import product +import pandas as pd +import pyarrow as pa import pytest from daft import DataType, TimeUnit @@ -29,8 +31,8 @@ def test_partitioning_days(input, 
dtype, expected): s = Series.from_pylist(input).cast(dtype) d = s.partitioning.days() - assert d.datatype() == DataType.date() - assert d.cast(DataType.int32()).to_pylist() == expected + assert d.datatype() == DataType.int32() + assert d.to_pylist() == expected @pytest.mark.parametrize( @@ -135,6 +137,35 @@ def test_iceberg_bucketing(input, n): seen[v] = b +@pytest.mark.parametrize( + "input,expected", + [ + (pa.array([34], type=pa.int32()), 2017239379), + (pa.array([34], type=pa.int64()), 2017239379), + (pa.array([Decimal("14.20")]), -500754589), + (pa.array([date.fromisoformat("2017-11-16")]), -653330422), + (pa.array([time.fromisoformat("22:31:08")]), -662762989), + (pa.array([datetime.fromisoformat("2017-11-16T22:31:08")]), -2047944441), + (pa.array([datetime.fromisoformat("2017-11-16T22:31:08.000001")]), -1207196810), + (pa.array([datetime.fromisoformat("2017-11-16T14:31:08-08:00")]), -2047944441), + (pa.array([datetime.fromisoformat("2017-11-16T14:31:08.000001-08:00")]), -1207196810), + (pa.array([datetime.fromisoformat("2017-11-16T22:31:08")], type=pa.timestamp("ns")), -2047944441), + (pa.array([pd.to_datetime("2017-11-16T22:31:08.000001001")], type=pa.timestamp("ns")), -1207196810), + (pa.array([datetime.fromisoformat("2017-11-16T14:31:08-08:00")], type=pa.timestamp("ns")), -2047944441), + (pa.array([pd.to_datetime("2017-11-16T14:31:08.000001001-08:00")], type=pa.timestamp("ns")), -1207196810), + (pa.array(["iceberg"]), 1210000089), + (pa.array([b"\x00\x01\x02\x03"]), -188683207), + ], +) +def test_iceberg_bucketing_hash(input, expected): + # https://iceberg.apache.org/spec/#appendix-b-32-bit-hash-requirements + max_buckets = 2**31 - 1 + s = Series.from_arrow(input) + buckets = s.partitioning.iceberg_bucket(max_buckets) + assert buckets.datatype() == DataType.int32() + assert buckets.to_pylist() == [(expected & max_buckets) % max_buckets] + + def test_iceberg_truncate_decimal(): data = ["12.34", "12.30", "12.29", "0.05", "-0.05"] data = [Decimal(v) for v in data] + [None] From 2ae875fe1d9e55050a20769face7ce3f5a52af9e Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Fri, 20 Sep 2024 15:29:38 -0700 Subject: [PATCH 09/35] [CHORE] Move codspeed interactive tests to local files (#2872) --- tests/benchmarks/test_interactive_reads.py | 51 ++++++++++++---------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/tests/benchmarks/test_interactive_reads.py b/tests/benchmarks/test_interactive_reads.py index e5ee4d801c..03af29f0a7 100644 --- a/tests/benchmarks/test_interactive_reads.py +++ b/tests/benchmarks/test_interactive_reads.py @@ -1,14 +1,14 @@ +import tempfile + +import boto3 import pytest +from botocore import UNSIGNED +from botocore.client import Config import daft from daft.io import IOConfig, S3Config -@pytest.fixture -def files(request): - return request.param - - @pytest.fixture def expected_count(request): return request.param @@ -19,16 +19,26 @@ def io_config(): return IOConfig(s3=S3Config(anonymous=True)) +@pytest.fixture(scope="session") +def files(request): + num_files = request.param + s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) + with tempfile.TemporaryDirectory() as tmpdir: + local_file = f"{tmpdir}/small-fake-data.parquet.parquet" + + s3.download_file("daft-public-data", "test_fixtures/parquet/small-fake-data.parquet", local_file) + if request.param == 1: + yield local_file + else: + yield [local_file] * num_files + + @pytest.mark.benchmark(group="show") @pytest.mark.parametrize( "files", [ - 
pytest.param("s3://daft-public-data/test_fixtures/parquet/small-fake-data.parquet", id="1 Small File"), - pytest.param( - 100 * ["s3://daft-public-data/test_fixtures/parquet/small-fake-data.parquet"], id="100 Small Files" - ), - # pytest.param("s3://daft-public-data/test_fixtures/parquet/large-fake-data.parquet", id="1 Large File"), - # pytest.param(100 * ["s3://daft-public-data/test_fixtures/parquet/large-fake-data.parquet"], id="100 Large File"), + pytest.param(1, id="1 Small File"), + pytest.param(100, id="100 Small Files"), ], indirect=True, # This tells pytest to pass the params to the fixture ) @@ -44,11 +54,8 @@ def f(): @pytest.mark.parametrize( "files", [ - pytest.param("s3://daft-public-data/test_fixtures/parquet/small-fake-data.parquet", id="1 Small File"), - pytest.param( - 100 * ["s3://daft-public-data/test_fixtures/parquet/small-fake-data.parquet"], id="100 Small Files" - ), - # pytest.param("s3://daft-public-data/test_fixtures/parquet/large-fake-data.parquet", id="1 Large File"), + pytest.param(1, id="1 Small File"), + pytest.param(100, id="100 Small Files"), ], indirect=True, # This tells pytest to pass the params to the fixture ) @@ -64,9 +71,8 @@ def f(): @pytest.mark.parametrize( "files, expected_count", [ - pytest.param("s3://daft-public-data/test_fixtures/parquet/small-fake-data.parquet", 1024, id="1 Small File"), - # pytest.param(100*["s3://daft-public-data/test_fixtures/parquet/small-fake-data.parquet"], 100 * 1024, id="100 Small Files"), # Turn this back on after we speed up count - # pytest.param("s3://daft-public-data/test_fixtures/parquet/large-fake-data.parquet", 100000000, id="1 Large File"), # Turn this back on after we speed up count + pytest.param(1, 1024, id="1 Small File"), + pytest.param(100, 102400, id="100 Small Files"), ], indirect=True, # This tells pytest to pass the params to the fixture ) @@ -83,11 +89,8 @@ def f(): @pytest.mark.parametrize( "files", [ - pytest.param("s3://daft-public-data/test_fixtures/parquet/small-fake-data.parquet", id="1 Small File"), - pytest.param( - 100 * ["s3://daft-public-data/test_fixtures/parquet/small-fake-data.parquet"], id="100 Small Files" - ), - # pytest.param("s3://daft-public-data/test_fixtures/parquet/large-fake-data.parquet", id="1 Large File"), + pytest.param(1, id="1 Small File"), + pytest.param(100, id="100 Small Files"), ], indirect=True, # This tells pytest to pass the params to the fixture ) From a3c98a54bad28f84fbe0c6c011cd1339a7fc4212 Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Fri, 20 Sep 2024 16:12:45 -0700 Subject: [PATCH 10/35] [CHORE] Update documentation for config variables (#2874) --- daft/context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/daft/context.py b/daft/context.py index 61b69284af..32ba5a5c65 100644 --- a/daft/context.py +++ b/daft/context.py @@ -299,7 +299,7 @@ def set_execution_config( broadcast_join_size_bytes_threshold: int | None = None, parquet_split_row_groups_max_files: int | None = None, sort_merge_join_sort_with_aligned_boundaries: bool | None = None, - hash_join_partition_size_leniency: bool | None = None, + hash_join_partition_size_leniency: float | None = None, sample_size_for_sort: int | None = None, num_preview_rows: int | None = None, parquet_target_filesize: int | None = None, @@ -344,7 +344,7 @@ def set_execution_config( parquet_inflation_factor: Inflation Factor of parquet files (In-Memory-Size / File-Size) ratio. 
Defaults to 3.0 csv_target_filesize: Target File Size when writing out CSV Files. Defaults to 512MB csv_inflation_factor: Inflation Factor of CSV files (In-Memory-Size / File-Size) ratio. Defaults to 0.5 - shuffle_aggregation_default_partitions: Minimum number of partitions to create when performing aggregations. Defaults to 200, unless the number of input partitions is less than 200. + shuffle_aggregation_default_partitions: Maximum number of partitions to create when performing aggregations. Defaults to 200, unless the number of input partitions is less than 200. shuffle_join_default_partitions: Minimum number of partitions to create when performing joins. Defaults to 16, unless the number of input partitions is greater than 16. read_sql_partition_size_bytes: Target size of partition when reading from SQL databases. Defaults to 512MB enable_aqe: Enables Adaptive Query Execution, Defaults to False From b0f31e3bd641969797403090ef46fa6db39f9488 Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Fri, 20 Sep 2024 16:53:19 -0700 Subject: [PATCH 11/35] Revert "[FEAT]: `shuffle_join_default_partitions` param" (#2873) Reverts Eventual-Inc/Daft#2844 --- daft/context.py | 3 - daft/daft/__init__.pyi | 3 - src/common/daft-config/src/lib.rs | 2 - src/common/daft-config/src/python.rs | 12 -- .../src/physical_planner/translate.rs | 119 +++++------------- tests/dataframe/test_joins.py | 95 +------------- 6 files changed, 33 insertions(+), 201 deletions(-) diff --git a/daft/context.py b/daft/context.py index 32ba5a5c65..caf74ef4d6 100644 --- a/daft/context.py +++ b/daft/context.py @@ -308,7 +308,6 @@ def set_execution_config( csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, - shuffle_join_default_partitions: int | None = None, read_sql_partition_size_bytes: int | None = None, enable_aqe: bool | None = None, enable_native_executor: bool | None = None, @@ -345,7 +344,6 @@ def set_execution_config( csv_target_filesize: Target File Size when writing out CSV Files. Defaults to 512MB csv_inflation_factor: Inflation Factor of CSV files (In-Memory-Size / File-Size) ratio. Defaults to 0.5 shuffle_aggregation_default_partitions: Maximum number of partitions to create when performing aggregations. Defaults to 200, unless the number of input partitions is less than 200. - shuffle_join_default_partitions: Minimum number of partitions to create when performing joins. Defaults to 16, unless the number of input partitions is greater than 16. read_sql_partition_size_bytes: Target size of partition when reading from SQL databases. Defaults to 512MB enable_aqe: Enables Adaptive Query Execution, Defaults to False enable_native_executor: Enables new local executor. 
Defaults to False @@ -371,7 +369,6 @@ def set_execution_config( csv_target_filesize=csv_target_filesize, csv_inflation_factor=csv_inflation_factor, shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions, - shuffle_join_default_partitions=shuffle_join_default_partitions, read_sql_partition_size_bytes=read_sql_partition_size_bytes, enable_aqe=enable_aqe, enable_native_executor=enable_native_executor, diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index d5ef5873c3..5a81e48b79 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1758,7 +1758,6 @@ class PyDaftExecutionConfig: csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, - shuffle_join_default_partitions: int | None = None, read_sql_partition_size_bytes: int | None = None, enable_aqe: bool | None = None, enable_native_executor: bool | None = None, @@ -1791,8 +1790,6 @@ class PyDaftExecutionConfig: @property def shuffle_aggregation_default_partitions(self) -> int: ... @property - def shuffle_join_default_partitions(self) -> int: ... - @property def read_sql_partition_size_bytes(self) -> int: ... @property def enable_aqe(self) -> bool: ... diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index 153d2a80c5..dcaef0a2f8 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -52,7 +52,6 @@ pub struct DaftExecutionConfig { pub csv_target_filesize: usize, pub csv_inflation_factor: f64, pub shuffle_aggregation_default_partitions: usize, - pub shuffle_join_default_partitions: usize, pub read_sql_partition_size_bytes: usize, pub enable_aqe: bool, pub enable_native_executor: bool, @@ -76,7 +75,6 @@ impl Default for DaftExecutionConfig { csv_target_filesize: 512 * 1024 * 1024, // 512MB csv_inflation_factor: 0.5, shuffle_aggregation_default_partitions: 200, - shuffle_join_default_partitions: 16, read_sql_partition_size_bytes: 512 * 1024 * 1024, // 512MB enable_aqe: false, enable_native_executor: false, diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 818934261a..5dda71eda8 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -94,7 +94,6 @@ impl PyDaftExecutionConfig { csv_target_filesize: Option, csv_inflation_factor: Option, shuffle_aggregation_default_partitions: Option, - shuffle_join_default_partitions: Option, read_sql_partition_size_bytes: Option, enable_aqe: Option, enable_native_executor: Option, @@ -144,16 +143,10 @@ impl PyDaftExecutionConfig { if let Some(csv_inflation_factor) = csv_inflation_factor { config.csv_inflation_factor = csv_inflation_factor; } - if let Some(shuffle_aggregation_default_partitions) = shuffle_aggregation_default_partitions { config.shuffle_aggregation_default_partitions = shuffle_aggregation_default_partitions; } - - if let Some(shuffle_join_default_partitions) = shuffle_join_default_partitions { - config.shuffle_join_default_partitions = shuffle_join_default_partitions; - } - if let Some(read_sql_partition_size_bytes) = read_sql_partition_size_bytes { config.read_sql_partition_size_bytes = read_sql_partition_size_bytes; } @@ -238,11 +231,6 @@ impl PyDaftExecutionConfig { Ok(self.config.shuffle_aggregation_default_partitions) } - #[getter] - fn get_shuffle_join_default_partitions(&self) -> PyResult { - Ok(self.config.shuffle_join_default_partitions) - } - #[getter] fn get_read_sql_partition_size_bytes(&self) -> PyResult { 
Ok(self.config.read_sql_partition_size_bytes) diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 408d4f62a6..639c571871 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -571,7 +571,6 @@ pub(super) fn translate_single_logical_node( "Sort-merge join currently only supports inner joins".to_string(), )); } - let num_partitions = max(num_partitions, cfg.shuffle_join_default_partitions); let needs_presort = if cfg.sort_merge_join_sort_with_aligned_boundaries { // Use the special-purpose presorting that ensures join inputs are sorted with aligned @@ -617,6 +616,7 @@ pub(super) fn translate_single_logical_node( // allow for leniency in partition size to avoid minor repartitions let num_left_partitions = left_clustering_spec.num_partitions(); let num_right_partitions = right_clustering_spec.num_partitions(); + let num_partitions = match ( is_left_hash_partitioned, is_right_hash_partitioned, @@ -637,7 +637,6 @@ pub(super) fn translate_single_logical_node( } (_, _, a, b) => max(a, b), }; - let num_partitions = max(num_partitions, cfg.shuffle_join_default_partitions); if num_left_partitions != num_partitions || (num_partitions > 1 && !is_left_hash_partitioned) @@ -1077,13 +1076,6 @@ mod tests { Self::Reversed(v) => Self::Reversed(v * x), } } - fn unwrap(&self) -> usize { - match self { - Self::Good(v) => *v, - Self::Bad(v) => *v, - Self::Reversed(v) => *v, - } - } } fn force_repartition( @@ -1136,31 +1128,21 @@ mod tests { fn check_physical_matches( plan: PhysicalPlanRef, - left_partition_size: usize, - right_partition_size: usize, left_repartitions: bool, right_repartitions: bool, - shuffle_join_default_partitions: usize, ) -> bool { match plan.as_ref() { PhysicalPlan::HashJoin(HashJoin { left, right, .. }) => { - let left_works = match ( - left.as_ref(), - left_repartitions || left_partition_size < shuffle_join_default_partitions, - ) { + let left_works = match (left.as_ref(), left_repartitions) { (PhysicalPlan::ReduceMerge(_), true) => true, (PhysicalPlan::Project(_), false) => true, _ => false, }; - let right_works = match ( - right.as_ref(), - right_repartitions || right_partition_size < shuffle_join_default_partitions, - ) { + let right_works = match (right.as_ref(), right_repartitions) { (PhysicalPlan::ReduceMerge(_), true) => true, (PhysicalPlan::Project(_), false) => true, _ => false, }; - left_works && right_works } _ => false, @@ -1170,7 +1152,7 @@ mod tests { /// Tests a variety of settings regarding hash join repartitioning. 
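// A worked reading of the leniency checks exercised further down in this test module
// (a sketch inferred from the assertions below, not wording taken from the code): with
// hash_join_partition_size_leniency = 0.8, a side that is already hash partitioned into
// `a` partitions keeps its partitioning only when `a` is within the leniency ratio of
// the other side's `b` partitions, i.e. a >= b * 0.8:
//   Good(20) vs Bad(25): 20 >= 25 * 0.8 = 20.0 -> left kept, only the right side repartitions
//   Good(20) vs Bad(26): 20 <  26 * 0.8 = 20.8 -> both sides repartition
//   Good(20) vs Bad(40): 20 <  40 * 0.8 = 32.0 -> both sides repartition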
#[test] fn repartition_hash_join_tests() -> DaftResult<()> { - use RepartitionOptions::{Bad, Good, Reversed}; + use RepartitionOptions::*; let cases = vec![ (Good(30), Good(30), false, false), (Good(30), Good(40), true, false), @@ -1188,17 +1170,9 @@ mod tests { let cfg: Arc = DaftExecutionConfig::default().into(); for (l_opts, r_opts, l_exp, r_exp) in cases { for mult in [1, 10] { - let l_opts = l_opts.scale_by(mult); - let r_opts = r_opts.scale_by(mult); - let plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?; - if !check_physical_matches( - plan, - l_opts.unwrap(), - r_opts.unwrap(), - l_exp, - r_exp, - cfg.shuffle_join_default_partitions, - ) { + let plan = + get_hash_join_plan(cfg.clone(), l_opts.scale_by(mult), r_opts.scale_by(mult))?; + if !check_physical_matches(plan, l_exp, r_exp) { panic!( "Failed hash join test on case ({:?}, {:?}, {}, {}) with mult {}", l_opts, r_opts, l_exp, r_exp, mult @@ -1206,15 +1180,9 @@ mod tests { } // reversed direction - let plan = get_hash_join_plan(cfg.clone(), r_opts.clone(), l_opts.clone())?; - if !check_physical_matches( - plan, - l_opts.unwrap(), - r_opts.unwrap(), - r_exp, - l_exp, - cfg.shuffle_join_default_partitions, - ) { + let plan = + get_hash_join_plan(cfg.clone(), r_opts.scale_by(mult), l_opts.scale_by(mult))?; + if !check_physical_matches(plan, r_exp, l_exp) { panic!( "Failed hash join test on case ({:?}, {:?}, {}, {}) with mult {}", r_opts, l_opts, r_exp, l_exp, mult @@ -1231,38 +1199,27 @@ mod tests { let mut cfg = DaftExecutionConfig::default(); cfg.hash_join_partition_size_leniency = 0.8; let cfg = Arc::new(cfg); - let (l_opts, r_opts) = (RepartitionOptions::Good(30), RepartitionOptions::Bad(40)); - let physical_plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?; - assert!(check_physical_matches( - physical_plan, - l_opts.unwrap(), - r_opts.unwrap(), - true, - true, - cfg.shuffle_join_default_partitions - )); - let (l_opts, r_opts) = (RepartitionOptions::Good(20), RepartitionOptions::Bad(25)); - let physical_plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?; - assert!(check_physical_matches( - physical_plan, - l_opts.unwrap(), - r_opts.unwrap(), - false, - true, - cfg.shuffle_join_default_partitions - )); + let physical_plan = get_hash_join_plan( + cfg.clone(), + RepartitionOptions::Good(20), + RepartitionOptions::Bad(40), + )?; + assert!(check_physical_matches(physical_plan, true, true)); + + let physical_plan = get_hash_join_plan( + cfg.clone(), + RepartitionOptions::Good(20), + RepartitionOptions::Bad(25), + )?; + assert!(check_physical_matches(physical_plan, false, true)); - let (l_opts, r_opts) = (RepartitionOptions::Good(20), RepartitionOptions::Bad(26)); - let physical_plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?; - assert!(check_physical_matches( - physical_plan, - l_opts.unwrap(), - r_opts.unwrap(), - true, - true, - cfg.shuffle_join_default_partitions - )); + let physical_plan = get_hash_join_plan( + cfg.clone(), + RepartitionOptions::Good(20), + RepartitionOptions::Bad(26), + )?; + assert!(check_physical_matches(physical_plan, true, true)); Ok(()) } @@ -1280,14 +1237,7 @@ mod tests { let cfg: Arc = DaftExecutionConfig::default().into(); for (l_opts, r_opts, l_exp, r_exp) in cases { let plan = get_hash_join_plan(cfg.clone(), l_opts, r_opts)?; - if !check_physical_matches( - plan, - l_opts.unwrap(), - r_opts.unwrap(), - l_exp, - r_exp, - cfg.shuffle_join_default_partitions, - ) { + if !check_physical_matches(plan, l_exp, r_exp) { 
panic!( "Failed single partition hash join test on case ({:?}, {:?}, {}, {})", l_opts, r_opts, l_exp, r_exp @@ -1296,14 +1246,7 @@ mod tests { // reversed direction let plan = get_hash_join_plan(cfg.clone(), r_opts, l_opts)?; - if !check_physical_matches( - plan, - l_opts.unwrap(), - r_opts.unwrap(), - r_exp, - l_exp, - cfg.shuffle_join_default_partitions, - ) { + if !check_physical_matches(plan, r_exp, l_exp) { panic!( "Failed single partition hash join test on case ({:?}, {:?}, {}, {})", r_opts, l_opts, r_exp, l_exp diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 8ccc3f72cd..b0bdbf9df4 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -3,16 +3,14 @@ import pyarrow as pa import pytest -import daft -from daft import col -from daft.context import get_context +from daft import col, context from daft.datatype import DataType from daft.errors import ExpressionTypeError from tests.utils import sort_arrow_table def skip_invalid_join_strategies(join_strategy, join_type): - if get_context().daft_execution_config.enable_native_executor is True: + if context.get_context().daft_execution_config.enable_native_executor is True: if join_type == "outer" or join_strategy not in [None, "hash"]: pytest.skip("Native executor fails for these tests") else: @@ -1077,92 +1075,3 @@ def test_join_same_name_alias_with_compute(join_strategy, join_type, expected, m assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "a") == sort_arrow_table( pa.Table.from_pydict(expected), "a" ) - - -# the partition size should be the max(shuffle_join_default_partitions, max(left_partition_size, right_partition_size)) -@pytest.mark.parametrize("shuffle_join_default_partitions", [None, 20]) -def test_join_result_partitions_smaller_than_input(shuffle_join_default_partitions): - skip_invalid_join_strategies("hash", "inner") - if shuffle_join_default_partitions is None: - min_partitions = get_context().daft_execution_config.shuffle_join_default_partitions - else: - min_partitions = shuffle_join_default_partitions - - with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions): - right_partition_size = 50 - for left_partition_size in [1, min_partitions, min_partitions + 1]: - df_left = daft.from_pydict( - {"group": [i for i in range(min_partitions + 1)], "value": [i for i in range(min_partitions + 1)]} - ) - df_left = df_left.into_partitions(left_partition_size) - - df_right = daft.from_pydict( - {"group": [i for i in range(right_partition_size)], "value": [i for i in range(right_partition_size)]} - ) - - df_right = df_right.into_partitions(right_partition_size) - - actual = df_left.join(df_right, on="group", how="inner", strategy="hash").collect() - n_partitions = actual.num_partitions() - expected_n_partitions = max(min_partitions, left_partition_size, right_partition_size) - assert n_partitions == expected_n_partitions - - -def test_join_right_single_partition(): - skip_invalid_join_strategies("hash", "inner") - shuffle_join_default_partitions = 16 - df_left = daft.from_pydict({"group": [i for i in range(300)], "value": [i for i in range(300)]}).repartition( - 300, "group" - ) - - df_right = daft.from_pydict({"group": [i for i in range(100)], "value": [i for i in range(100)]}).repartition( - 1, "group" - ) - - with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions): - actual = df_left.join(df_right, on="group", how="inner", strategy="hash").collect() - n_partitions = 
actual.num_partitions() - assert n_partitions == 300 - - -def test_join_right_smaller_than_cfg(): - skip_invalid_join_strategies("hash", "inner") - shuffle_join_default_partitions = 200 - df_left = daft.from_pydict({"group": [i for i in range(199)], "value": [i for i in range(199)]}).repartition( - 199, "group" - ) - - df_right = daft.from_pydict({"group": [i for i in range(100)], "value": [i for i in range(100)]}).repartition( - 100, "group" - ) - - with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions): - actual = df_left.join(df_right, on="group", how="inner", strategy="hash").collect() - n_partitions = actual.num_partitions() - assert n_partitions == 200 - - -# for sort_merge, the result partitions should always be max(shuffle_join_default_partitions, max(left_partition_size, right_partition_size)) -@pytest.mark.parametrize("shuffle_join_default_partitions", [None, 20]) -def test_join_result_partitions_for_sortmerge(shuffle_join_default_partitions): - skip_invalid_join_strategies("sort_merge", "inner") - - if shuffle_join_default_partitions is None: - min_partitions = get_context().daft_execution_config.shuffle_join_default_partitions - else: - min_partitions = shuffle_join_default_partitions - - with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions): - for partition_size in [1, min_partitions, min_partitions + 1]: - df_left = daft.from_pydict( - {"group": [i for i in range(min_partitions + 1)], "value": [i for i in range(min_partitions + 1)]} - ) - df_left = df_left.into_partitions(partition_size) - - df_right = daft.from_pydict({"group": [i for i in range(50)], "value": [i for i in range(50)]}) - - df_right = df_right.into_partitions(50) - - actual = df_left.join(df_right, on="group", how="inner", strategy="sort_merge").collect() - - assert actual.num_partitions() == max(min_partitions, partition_size, 50) From 134601b8024226a01f0b20a890b434bb279e803b Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Mon, 23 Sep 2024 11:13:47 -0500 Subject: [PATCH 12/35] [CHORE]: move `numeric` out of daft-dsl and into `daft-functions` (#2857) --- Cargo.lock | 1 + daft/daft/__init__.pyi | 52 ++-- daft/expressions/expressions.py | 48 +-- src/daft-dsl/src/functions/mod.rs | 7 +- src/daft-dsl/src/functions/numeric/abs.rs | 40 --- src/daft-dsl/src/functions/numeric/ceil.rs | 40 --- src/daft-dsl/src/functions/numeric/exp.rs | 46 --- src/daft-dsl/src/functions/numeric/floor.rs | 40 --- src/daft-dsl/src/functions/numeric/log.rs | 68 ----- src/daft-dsl/src/functions/numeric/mod.rs | 283 ------------------ src/daft-dsl/src/functions/numeric/round.rs | 44 --- src/daft-dsl/src/functions/numeric/sign.rs | 40 --- src/daft-dsl/src/functions/numeric/sqrt.rs | 37 --- .../src/functions/numeric/trigonometry.rs | 87 ------ src/daft-dsl/src/python.rs | 127 +------- src/daft-functions/Cargo.toml | 1 + src/daft-functions/src/lib.rs | 2 +- src/daft-functions/src/numeric/abs.rs | 48 +++ src/daft-functions/src/numeric/cbrt.rs | 57 ++-- src/daft-functions/src/numeric/ceil.rs | 49 +++ src/daft-functions/src/numeric/exp.rs | 67 +++++ src/daft-functions/src/numeric/floor.rs | 50 ++++ src/daft-functions/src/numeric/log.rs | 142 +++++++++ src/daft-functions/src/numeric/mod.rs | 103 +++++++ src/daft-functions/src/numeric/round.rs | 59 ++++ src/daft-functions/src/numeric/sign.rs | 50 ++++ src/daft-functions/src/numeric/sqrt.rs | 51 ++++ .../src/numeric/trigonometry.rs | 222 ++++++++++++++ src/daft-sql/src/modules/numeric.rs | 150 ++++++---- 
src/daft-sql/src/planner.rs | 6 +- 30 files changed, 1015 insertions(+), 1002 deletions(-) delete mode 100644 src/daft-dsl/src/functions/numeric/abs.rs delete mode 100644 src/daft-dsl/src/functions/numeric/ceil.rs delete mode 100644 src/daft-dsl/src/functions/numeric/exp.rs delete mode 100644 src/daft-dsl/src/functions/numeric/floor.rs delete mode 100644 src/daft-dsl/src/functions/numeric/log.rs delete mode 100644 src/daft-dsl/src/functions/numeric/mod.rs delete mode 100644 src/daft-dsl/src/functions/numeric/round.rs delete mode 100644 src/daft-dsl/src/functions/numeric/sign.rs delete mode 100644 src/daft-dsl/src/functions/numeric/sqrt.rs delete mode 100644 src/daft-dsl/src/functions/numeric/trigonometry.rs create mode 100644 src/daft-functions/src/numeric/abs.rs create mode 100644 src/daft-functions/src/numeric/ceil.rs create mode 100644 src/daft-functions/src/numeric/exp.rs create mode 100644 src/daft-functions/src/numeric/floor.rs create mode 100644 src/daft-functions/src/numeric/log.rs create mode 100644 src/daft-functions/src/numeric/round.rs create mode 100644 src/daft-functions/src/numeric/sign.rs create mode 100644 src/daft-functions/src/numeric/sqrt.rs create mode 100644 src/daft-functions/src/numeric/trigonometry.rs diff --git a/Cargo.lock b/Cargo.lock index 4815488c2b..8dc27f4c89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1819,6 +1819,7 @@ dependencies = [ "base64 0.22.1", "bytes", "common-error", + "common-hashable-float-wrapper", "common-io-config", "daft-core", "daft-dsl", diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 5a81e48b79..7c8f04a7fc 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1040,29 +1040,6 @@ class PySchema: class PyExpr: def alias(self, name: str) -> PyExpr: ... def cast(self, dtype: PyDataType) -> PyExpr: ... - def ceil(self) -> PyExpr: ... - def floor(self) -> PyExpr: ... - def sign(self) -> PyExpr: ... - def round(self, decimal: int) -> PyExpr: ... - def sqrt(self) -> PyExpr: ... - def sin(self) -> PyExpr: ... - def cos(self) -> PyExpr: ... - def tan(self) -> PyExpr: ... - def cot(self) -> PyExpr: ... - def arcsin(self) -> PyExpr: ... - def arccos(self) -> PyExpr: ... - def arctan(self) -> PyExpr: ... - def arctan2(self, other: PyExpr) -> PyExpr: ... - def arctanh(self) -> PyExpr: ... - def arccosh(self) -> PyExpr: ... - def arcsinh(self) -> PyExpr: ... - def degrees(self) -> PyExpr: ... - def radians(self) -> PyExpr: ... - def log2(self) -> PyExpr: ... - def log10(self) -> PyExpr: ... - def log(self, base: float) -> PyExpr: ... - def ln(self) -> PyExpr: ... - def exp(self) -> PyExpr: ... def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ... def count(self, mode: CountMode) -> PyExpr: ... def sum(self) -> PyExpr: ... @@ -1074,7 +1051,6 @@ class PyExpr: def any_value(self, ignore_nulls: bool) -> PyExpr: ... def agg_list(self) -> PyExpr: ... def agg_concat(self) -> PyExpr: ... - def __abs__(self) -> PyExpr: ... def __add__(self, other: PyExpr) -> PyExpr: ... def __sub__(self, other: PyExpr) -> PyExpr: ... def __mul__(self, other: PyExpr) -> PyExpr: ... @@ -1217,9 +1193,35 @@ def minhash( def sql(sql: str, catalog: PyCatalog, daft_planning_config: PyDaftPlanningConfig) -> LogicalPlanBuilder: ... def sql_expr(sql: str) -> PyExpr: ... def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ... -def cbrt(expr: PyExpr) -> PyExpr: ... def to_struct(inputs: list[PyExpr]) -> PyExpr: ... +# expr numeric ops +def abs(expr: PyExpr) -> PyExpr: ... 
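# A short usage sketch of how these scalar stubs surface in the Python API after this
# refactor (assuming the existing daft.from_pydict / DataFrame.select surface; the data
# and column names are made up for illustration). Each Expression method shown later in
# expressions.py simply forwards to the matching native function declared here, e.g.
# Expression.sqrt() calls native.sqrt(self._expr):
#
#     import daft
#
#     df = daft.from_pydict({"x": [1.0, 4.0, 9.0]})
#     df = df.select(daft.col("x").sqrt(), daft.col("x").log(base=10.0))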
+def cbrt(expr: PyExpr) -> PyExpr: ... +def ceil(expr: PyExpr) -> PyExpr: ... +def exp(expr: PyExpr) -> PyExpr: ... +def floor(expr: PyExpr) -> PyExpr: ... +def log2(expr: PyExpr) -> PyExpr: ... +def log10(expr: PyExpr) -> PyExpr: ... +def log(expr: PyExpr, base: float) -> PyExpr: ... +def ln(expr: PyExpr) -> PyExpr: ... +def round(expr: PyExpr, decimal: int) -> PyExpr: ... +def sign(expr: PyExpr) -> PyExpr: ... +def sqrt(expr: PyExpr) -> PyExpr: ... +def sin(expr: PyExpr) -> PyExpr: ... +def cos(expr: PyExpr) -> PyExpr: ... +def tan(expr: PyExpr) -> PyExpr: ... +def cot(expr: PyExpr) -> PyExpr: ... +def arcsin(expr: PyExpr) -> PyExpr: ... +def arccos(expr: PyExpr) -> PyExpr: ... +def arctan(expr: PyExpr) -> PyExpr: ... +def arctan2(expr: PyExpr, other: PyExpr) -> PyExpr: ... +def radians(expr: PyExpr) -> PyExpr: ... +def degrees(expr: PyExpr) -> PyExpr: ... +def arctanh(expr: PyExpr) -> PyExpr: ... +def arccosh(expr: PyExpr) -> PyExpr: ... +def arcsinh(expr: PyExpr) -> PyExpr: ... + # --- # expr.image namespace # --- diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index a2232d9a24..1ae7e90dac 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -328,7 +328,7 @@ def __abs__(self) -> Expression: def abs(self) -> Expression: """Absolute of a numeric expression (``expr.abs()``)""" - return Expression._from_pyexpr(abs(self._expr)) + return Expression._from_pyexpr(native.abs(self._expr)) def __add__(self, other: object) -> Expression: """Adds two numeric expressions or concatenates two string expressions (``e1 + e2``)""" @@ -577,17 +577,17 @@ def cast(self, dtype: DataType) -> Expression: def ceil(self) -> Expression: """The ceiling of a numeric expression (``expr.ceil()``)""" - expr = self._expr.ceil() + expr = native.ceil(self._expr) return Expression._from_pyexpr(expr) def floor(self) -> Expression: """The floor of a numeric expression (``expr.floor()``)""" - expr = self._expr.floor() + expr = native.floor(self._expr) return Expression._from_pyexpr(expr) def sign(self) -> Expression: """The sign of a numeric expression (``expr.sign()``)""" - expr = self._expr.sign() + expr = native.sign(self._expr) return Expression._from_pyexpr(expr) def round(self, decimals: int = 0) -> Expression: @@ -597,12 +597,12 @@ def round(self, decimals: int = 0) -> Expression: decimals: number of decimal places to round to. Defaults to 0. 
""" assert isinstance(decimals, int) - expr = self._expr.round(decimals) + expr = native.round(self._expr, decimals) return Expression._from_pyexpr(expr) def sqrt(self) -> Expression: """The square root of a numeric expression (``expr.sqrt()``)""" - expr = self._expr.sqrt() + expr = native.sqrt(self._expr) return Expression._from_pyexpr(expr) def cbrt(self) -> Expression: @@ -611,37 +611,37 @@ def cbrt(self) -> Expression: def sin(self) -> Expression: """The elementwise sine of a numeric expression (``expr.sin()``)""" - expr = self._expr.sin() + expr = native.sin(self._expr) return Expression._from_pyexpr(expr) def cos(self) -> Expression: """The elementwise cosine of a numeric expression (``expr.cos()``)""" - expr = self._expr.cos() + expr = native.cos(self._expr) return Expression._from_pyexpr(expr) def tan(self) -> Expression: """The elementwise tangent of a numeric expression (``expr.tan()``)""" - expr = self._expr.tan() + expr = native.tan(self._expr) return Expression._from_pyexpr(expr) def cot(self) -> Expression: """The elementwise cotangent of a numeric expression (``expr.cot()``)""" - expr = self._expr.cot() + expr = native.cot(self._expr) return Expression._from_pyexpr(expr) def arcsin(self) -> Expression: """The elementwise arc sine of a numeric expression (``expr.arcsin()``)""" - expr = self._expr.arcsin() + expr = native.arcsin(self._expr) return Expression._from_pyexpr(expr) def arccos(self) -> Expression: """The elementwise arc cosine of a numeric expression (``expr.arccos()``)""" - expr = self._expr.arccos() + expr = native.arccos(self._expr) return Expression._from_pyexpr(expr) def arctan(self) -> Expression: """The elementwise arc tangent of a numeric expression (``expr.arctan()``)""" - expr = self._expr.arctan() + expr = native.arctan(self._expr) return Expression._from_pyexpr(expr) def arctan2(self, other: Expression) -> Expression: @@ -652,41 +652,41 @@ def arctan2(self, other: Expression) -> Expression: * ``y >= 0``: ``(pi/2, pi]`` * ``y < 0``: ``(-pi, -pi/2)``""" expr = Expression._to_expression(other) - return Expression._from_pyexpr(self._expr.arctan2(expr._expr)) + return Expression._from_pyexpr(native.arctan2(self._expr, expr._expr)) def arctanh(self) -> Expression: """The elementwise inverse hyperbolic tangent of a numeric expression (``expr.arctanh()``)""" - expr = self._expr.arctanh() + expr = native.arctanh(self._expr) return Expression._from_pyexpr(expr) def arccosh(self) -> Expression: """The elementwise inverse hyperbolic cosine of a numeric expression (``expr.arccosh()``)""" - expr = self._expr.arccosh() + expr = native.arccosh(self._expr) return Expression._from_pyexpr(expr) def arcsinh(self) -> Expression: """The elementwise inverse hyperbolic sine of a numeric expression (``expr.arcsinh()``)""" - expr = self._expr.arcsinh() + expr = native.arcsinh(self._expr) return Expression._from_pyexpr(expr) def radians(self) -> Expression: """The elementwise radians of a numeric expression (``expr.radians()``)""" - expr = self._expr.radians() + expr = native.radians(self._expr) return Expression._from_pyexpr(expr) def degrees(self) -> Expression: """The elementwise degrees of a numeric expression (``expr.degrees()``)""" - expr = self._expr.degrees() + expr = native.degrees(self._expr) return Expression._from_pyexpr(expr) def log2(self) -> Expression: """The elementwise log base 2 of a numeric expression (``expr.log2()``)""" - expr = self._expr.log2() + expr = native.log2(self._expr) return Expression._from_pyexpr(expr) def log10(self) -> Expression: """The 
elementwise log base 10 of a numeric expression (``expr.log10()``)""" - expr = self._expr.log10() + expr = native.log10(self._expr) return Expression._from_pyexpr(expr) def log(self, base: float = math.e) -> Expression: # type: ignore @@ -695,17 +695,17 @@ def log(self, base: float = math.e) -> Expression: # type: ignore base: The base of the logarithm. Defaults to e. """ assert isinstance(base, (int, float)), f"base must be an int or float, but {type(base)} was provided." - expr = self._expr.log(float(base)) + expr = native.log(self._expr, float(base)) return Expression._from_pyexpr(expr) def ln(self) -> Expression: """The elementwise natural log of a numeric expression (``expr.ln()``)""" - expr = self._expr.ln() + expr = native.ln(self._expr) return Expression._from_pyexpr(expr) def exp(self) -> Expression: """The e^self of a numeric expression (``expr.exp()``)""" - expr = self._expr.exp() + expr = native.exp(self._expr) return Expression._from_pyexpr(expr) def bitwise_and(self, other: Expression) -> Expression: diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 4ff2375b68..0386d7c54c 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -1,5 +1,4 @@ pub mod map; -pub mod numeric; pub mod partitioning; pub mod scalar; pub mod sketch; @@ -17,8 +16,8 @@ pub use scalar::*; use serde::{Deserialize, Serialize}; use self::{ - map::MapExpr, numeric::NumericExpr, partitioning::PartitioningExpr, sketch::SketchExpr, - struct_::StructExpr, utf8::Utf8Expr, + map::MapExpr, partitioning::PartitioningExpr, sketch::SketchExpr, struct_::StructExpr, + utf8::Utf8Expr, }; use crate::{Expr, ExprRef, Operator}; @@ -27,7 +26,6 @@ use python::PythonUDF; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum FunctionExpr { - Numeric(NumericExpr), Utf8(Utf8Expr), Map(MapExpr), Sketch(SketchExpr), @@ -52,7 +50,6 @@ impl FunctionExpr { fn get_evaluator(&self) -> &dyn FunctionEvaluator { use FunctionExpr::*; match self { - Numeric(expr) => expr.get_evaluator(), Utf8(expr) => expr.get_evaluator(), Map(expr) => expr.get_evaluator(), Sketch(expr) => expr.get_evaluator(), diff --git a/src/daft-dsl/src/functions/numeric/abs.rs b/src/daft-dsl/src/functions/numeric/abs.rs deleted file mode 100644 index af8566960d..0000000000 --- a/src/daft-dsl/src/functions/numeric/abs.rs +++ /dev/null @@ -1,40 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct AbsEvaluator {} - -impl FunctionEvaluator for AbsEvaluator { - fn fn_name(&self) -> &'static str { - "abs" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - if !field.dtype.is_numeric() { - return Err(DaftError::TypeError(format!( - "Expected input to abs to be numeric, got {}", - field.dtype - ))); - } - Ok(field) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().abs() - } -} diff --git a/src/daft-dsl/src/functions/numeric/ceil.rs b/src/daft-dsl/src/functions/numeric/ceil.rs deleted file mode 100644 index 735ab91bc3..0000000000 --- 
a/src/daft-dsl/src/functions/numeric/ceil.rs +++ /dev/null @@ -1,40 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct CeilEvaluator {} - -impl FunctionEvaluator for CeilEvaluator { - fn fn_name(&self) -> &'static str { - "ceil" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - if !field.dtype.is_numeric() { - return Err(DaftError::TypeError(format!( - "Expected input to ceil to be numeric, got {}", - field.dtype - ))); - } - Ok(field) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().ceil() - } -} diff --git a/src/daft-dsl/src/functions/numeric/exp.rs b/src/daft-dsl/src/functions/numeric/exp.rs deleted file mode 100644 index bde9b90f6f..0000000000 --- a/src/daft-dsl/src/functions/numeric/exp.rs +++ /dev/null @@ -1,46 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use crate::{ - functions::{FunctionEvaluator, FunctionExpr}, - ExprRef, -}; - -pub(super) struct ExpEvaluator {} - -impl FunctionEvaluator for ExpEvaluator { - fn fn_name(&self) -> &'static str { - "exp" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - }; - let field = inputs.first().unwrap().to_field(schema)?; - let dtype = match field.dtype { - DataType::Float32 => DataType::Float32, - dt if dt.is_numeric() => DataType::Float64, - _ => { - return Err(DaftError::TypeError(format!( - "Expected input to compute exp to be numeric, got {}", - field.dtype - ))) - } - }; - Ok(Field::new(field.name, dtype)) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().exp() - } -} diff --git a/src/daft-dsl/src/functions/numeric/floor.rs b/src/daft-dsl/src/functions/numeric/floor.rs deleted file mode 100644 index a76de5fda9..0000000000 --- a/src/daft-dsl/src/functions/numeric/floor.rs +++ /dev/null @@ -1,40 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct FloorEvaluator {} - -impl FunctionEvaluator for FloorEvaluator { - fn fn_name(&self) -> &'static str { - "floor" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - if !field.dtype.is_numeric() { - return Err(DaftError::TypeError(format!( - "Expected input to floor to be numeric, got {}", - field.dtype - ))); - } - Ok(field) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - 
"Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().floor() - } -} diff --git a/src/daft-dsl/src/functions/numeric/log.rs b/src/daft-dsl/src/functions/numeric/log.rs deleted file mode 100644 index 9c6105c449..0000000000 --- a/src/daft-dsl/src/functions/numeric/log.rs +++ /dev/null @@ -1,68 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, NumericExpr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) enum LogFunction { - Log2, - Log10, - Log, - Ln, -} -pub(super) struct LogEvaluator(pub LogFunction); - -impl FunctionEvaluator for LogEvaluator { - fn fn_name(&self) -> &'static str { - match self.0 { - LogFunction::Log2 => "log2", - LogFunction::Log10 => "log10", - LogFunction::Log => "log", - LogFunction::Ln => "ln", - } - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - let dtype = match field.dtype { - DataType::Float32 => DataType::Float32, - dt if dt.is_numeric() => DataType::Float64, - _ => { - return Err(DaftError::TypeError(format!( - "Expected input to log to be numeric, got {}", - field.dtype - ))) - } - }; - Ok(Field::new(field.name, dtype)) - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let input = inputs.first().unwrap(); - match self.0 { - LogFunction::Log2 => input.log2(), - LogFunction::Log10 => input.log10(), - LogFunction::Log => { - let base = match expr { - FunctionExpr::Numeric(NumericExpr::Log(value)) => value, - _ => panic!("Expected Log Expr, got {expr}"), - }; - - input.log(base.0) - } - LogFunction::Ln => input.ln(), - } - } -} diff --git a/src/daft-dsl/src/functions/numeric/mod.rs b/src/daft-dsl/src/functions/numeric/mod.rs deleted file mode 100644 index 98ab6d3f32..0000000000 --- a/src/daft-dsl/src/functions/numeric/mod.rs +++ /dev/null @@ -1,283 +0,0 @@ -mod abs; -mod ceil; -mod exp; -mod floor; -mod log; -mod round; -mod sign; -mod sqrt; -mod trigonometry; - -use std::hash::Hash; - -use abs::AbsEvaluator; -use ceil::CeilEvaluator; -use common_hashable_float_wrapper::FloatWrapper; -use floor::FloorEvaluator; -use log::LogEvaluator; -use round::RoundEvaluator; -use serde::{Deserialize, Serialize}; -use sign::SignEvaluator; -use sqrt::SqrtEvaluator; -use trigonometry::Atan2Evaluator; - -use super::FunctionEvaluator; -use crate::{ - functions::numeric::{ - exp::ExpEvaluator, - trigonometry::{TrigonometricFunction, TrigonometryEvaluator}, - }, - Expr, ExprRef, -}; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub enum NumericExpr { - Abs, - Ceil, - Floor, - Sign, - Round(i32), - Sqrt, - Sin, - Cos, - Tan, - Cot, - ArcSin, - ArcCos, - ArcTan, - ArcTan2, - Radians, - Degrees, - Log2, - Log10, - Log(FloatWrapper), - Ln, - Exp, - ArcTanh, - ArcCosh, - ArcSinh, -} - -impl NumericExpr { - #[inline] - pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - match self { - NumericExpr::Abs => &AbsEvaluator {}, - NumericExpr::Ceil => &CeilEvaluator {}, - NumericExpr::Floor => &FloorEvaluator {}, - NumericExpr::Sign => &SignEvaluator {}, - NumericExpr::Round(_) => &RoundEvaluator {}, - NumericExpr::Sqrt => &SqrtEvaluator {}, - 
NumericExpr::Sin => &TrigonometryEvaluator(TrigonometricFunction::Sin), - NumericExpr::Cos => &TrigonometryEvaluator(TrigonometricFunction::Cos), - NumericExpr::Tan => &TrigonometryEvaluator(TrigonometricFunction::Tan), - NumericExpr::Cot => &TrigonometryEvaluator(TrigonometricFunction::Cot), - NumericExpr::ArcSin => &TrigonometryEvaluator(TrigonometricFunction::ArcSin), - NumericExpr::ArcCos => &TrigonometryEvaluator(TrigonometricFunction::ArcCos), - NumericExpr::ArcTan => &TrigonometryEvaluator(TrigonometricFunction::ArcTan), - NumericExpr::ArcTan2 => &Atan2Evaluator {}, - NumericExpr::Radians => &TrigonometryEvaluator(TrigonometricFunction::Radians), - NumericExpr::Degrees => &TrigonometryEvaluator(TrigonometricFunction::Degrees), - NumericExpr::Log2 => &LogEvaluator(log::LogFunction::Log2), - NumericExpr::Log10 => &LogEvaluator(log::LogFunction::Log10), - NumericExpr::Log(_) => &LogEvaluator(log::LogFunction::Log), - NumericExpr::Ln => &LogEvaluator(log::LogFunction::Ln), - NumericExpr::Exp => &ExpEvaluator {}, - NumericExpr::ArcTanh => &TrigonometryEvaluator(TrigonometricFunction::ArcTanh), - NumericExpr::ArcCosh => &TrigonometryEvaluator(TrigonometricFunction::ArcCosh), - NumericExpr::ArcSinh => &TrigonometryEvaluator(TrigonometricFunction::ArcSinh), - } - } -} - -pub fn abs(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Abs), - inputs: vec![input], - } - .into() -} - -pub fn ceil(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Ceil), - inputs: vec![input], - } - .into() -} - -pub fn floor(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Floor), - inputs: vec![input], - } - .into() -} - -pub fn sign(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Sign), - inputs: vec![input], - } - .into() -} - -pub fn round(input: ExprRef, decimal: i32) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Round(decimal)), - inputs: vec![input], - } - .into() -} - -pub fn sqrt(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Sqrt), - inputs: vec![input], - } - .into() -} - -pub fn sin(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Sin), - inputs: vec![input], - } - .into() -} - -pub fn cos(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Cos), - inputs: vec![input], - } - .into() -} - -pub fn tan(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Tan), - inputs: vec![input], - } - .into() -} - -pub fn cot(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Cot), - inputs: vec![input], - } - .into() -} - -pub fn arcsin(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcSin), - inputs: vec![input], - } - .into() -} - -pub fn arccos(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcCos), - inputs: vec![input], - } - .into() -} - -pub fn arctan(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcTan), - inputs: vec![input], - } - .into() -} - -pub fn arctan2(input: ExprRef, other: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcTan2), - inputs: 
vec![input, other], - } - .into() -} - -pub fn radians(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Radians), - inputs: vec![input], - } - .into() -} - -pub fn degrees(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Degrees), - inputs: vec![input], - } - .into() -} - -pub fn arctanh(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcTanh), - inputs: vec![input], - } - .into() -} - -pub fn arccosh(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcCosh), - inputs: vec![input], - } - .into() -} - -pub fn arcsinh(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcSinh), - inputs: vec![input], - } - .into() -} - -pub fn log2(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Log2), - inputs: vec![input], - } - .into() -} - -pub fn log10(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Log10), - inputs: vec![input], - } - .into() -} - -pub fn log(input: ExprRef, base: f64) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Log(FloatWrapper(base))), - inputs: vec![input], - } - .into() -} - -pub fn ln(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Ln), - inputs: vec![input], - } - .into() -} - -pub fn exp(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Exp), - inputs: vec![input], - } - .into() -} diff --git a/src/daft-dsl/src/functions/numeric/round.rs b/src/daft-dsl/src/functions/numeric/round.rs deleted file mode 100644 index aee8fa29c6..0000000000 --- a/src/daft-dsl/src/functions/numeric/round.rs +++ /dev/null @@ -1,44 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, NumericExpr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct RoundEvaluator {} - -impl FunctionEvaluator for RoundEvaluator { - fn fn_name(&self) -> &'static str { - "round" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - if !field.dtype.is_numeric() { - return Err(DaftError::TypeError(format!( - "Expected input to round to be numeric, got {}", - field.dtype - ))); - } - Ok(field) - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let decimal = match expr { - FunctionExpr::Numeric(NumericExpr::Round(index)) => index, - _ => panic!("Expected Round Expr, got {expr}"), - }; - inputs.first().unwrap().round(*decimal) - } -} diff --git a/src/daft-dsl/src/functions/numeric/sign.rs b/src/daft-dsl/src/functions/numeric/sign.rs deleted file mode 100644 index 20dafba799..0000000000 --- a/src/daft-dsl/src/functions/numeric/sign.rs +++ /dev/null @@ -1,40 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct SignEvaluator {} - -impl 
FunctionEvaluator for SignEvaluator { - fn fn_name(&self) -> &'static str { - "sign" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - if !field.dtype.is_numeric() { - return Err(DaftError::TypeError(format!( - "Expected input to sign to be numeric, got {}", - field.dtype - ))); - } - Ok(field) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().sign() - } -} diff --git a/src/daft-dsl/src/functions/numeric/sqrt.rs b/src/daft-dsl/src/functions/numeric/sqrt.rs deleted file mode 100644 index 248ce1de7e..0000000000 --- a/src/daft-dsl/src/functions/numeric/sqrt.rs +++ /dev/null @@ -1,37 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct SqrtEvaluator {} - -impl FunctionEvaluator for SqrtEvaluator { - fn fn_name(&self) -> &'static str { - "sqrt" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [first] => { - let field = first.to_field(schema)?; - let dtype = field.dtype.to_floating_representation()?; - Ok(Field::new(field.name, dtype)) - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [first] => first.sqrt(), - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/numeric/trigonometry.rs b/src/daft-dsl/src/functions/numeric/trigonometry.rs deleted file mode 100644 index 9779e802d4..0000000000 --- a/src/daft-dsl/src/functions/numeric/trigonometry.rs +++ /dev/null @@ -1,87 +0,0 @@ -use common_error::{DaftError, DaftResult}; -pub use daft_core::array::ops::trigonometry::TrigonometricFunction; -use daft_core::prelude::*; - -use crate::{ - functions::{FunctionEvaluator, FunctionExpr}, - ExprRef, -}; - -pub(super) struct TrigonometryEvaluator(pub TrigonometricFunction); - -impl FunctionEvaluator for TrigonometryEvaluator { - fn fn_name(&self) -> &'static str { - self.0.fn_name() - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - }; - let field = inputs.first().unwrap().to_field(schema)?; - let dtype = match field.dtype { - DataType::Float32 => DataType::Float32, - dt if dt.is_numeric() => DataType::Float64, - _ => { - return Err(DaftError::TypeError(format!( - "Expected input to trigonometry to be numeric, got {}", - field.dtype - ))) - } - }; - Ok(Field::new(field.name, dtype)) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().trigonometry(&self.0) - } -} - -pub(super) struct Atan2Evaluator {} - -impl FunctionEvaluator for Atan2Evaluator { - fn fn_name(&self) -> 
&'static str { - "atan2" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 2 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))); - } - let field1 = inputs.first().unwrap().to_field(schema)?; - let field2 = inputs.get(1).unwrap().to_field(schema)?; - let dtype = match (field1.dtype, field2.dtype) { - (DataType::Float32, DataType::Float32) => DataType::Float32, - (dt1, dt2) if dt1.is_numeric() && dt2.is_numeric() => DataType::Float64, - (dt1, dt2) => { - return Err(DaftError::TypeError(format!( - "Expected inputs to atan2 to be numeric, got {} and {}", - dt1, dt2 - ))) - } - }; - Ok(Field::new(field1.name, dtype)) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 2 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().atan2(inputs.get(1).unwrap()) - } -} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 7c5a1d7930..af56dc68d8 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -1,5 +1,4 @@ #![allow(non_snake_case)] - use std::{ collections::{hash_map::DefaultHasher, HashMap}, hash::{Hash, Hasher}, @@ -22,7 +21,7 @@ use pyo3::{ }; use serde::{Deserialize, Serialize}; -use crate::{functions, Expr, ExprRef, LiteralValue}; +use crate::{Expr, ExprRef, LiteralValue}; #[pyfunction] pub fn col(name: &str) -> PyResult { @@ -274,126 +273,6 @@ impl PyExpr { Ok(self.expr.clone().cast(&dtype.into()).into()) } - pub fn ceil(&self) -> PyResult { - use functions::numeric::ceil; - Ok(ceil(self.into()).into()) - } - - pub fn floor(&self) -> PyResult { - use functions::numeric::floor; - Ok(floor(self.into()).into()) - } - - pub fn sign(&self) -> PyResult { - use functions::numeric::sign; - Ok(sign(self.into()).into()) - } - - pub fn round(&self, decimal: i32) -> PyResult { - use functions::numeric::round; - if decimal < 0 { - return Err(PyValueError::new_err(format!( - "decimal can not be negative: {decimal}" - ))); - } - Ok(round(self.into(), decimal).into()) - } - - pub fn sqrt(&self) -> PyResult { - use functions::numeric::sqrt; - Ok(sqrt(self.into()).into()) - } - - pub fn sin(&self) -> PyResult { - use functions::numeric::sin; - Ok(sin(self.into()).into()) - } - - pub fn cos(&self) -> PyResult { - use functions::numeric::cos; - Ok(cos(self.into()).into()) - } - - pub fn tan(&self) -> PyResult { - use functions::numeric::tan; - Ok(tan(self.into()).into()) - } - - pub fn cot(&self) -> PyResult { - use functions::numeric::cot; - Ok(cot(self.into()).into()) - } - - pub fn arcsin(&self) -> PyResult { - use functions::numeric::arcsin; - Ok(arcsin(self.into()).into()) - } - - pub fn arccos(&self) -> PyResult { - use functions::numeric::arccos; - Ok(arccos(self.into()).into()) - } - - pub fn arctan(&self) -> PyResult { - use functions::numeric::arctan; - Ok(arctan(self.into()).into()) - } - - pub fn arctan2(&self, other: &Self) -> PyResult { - use functions::numeric::arctan2; - Ok(arctan2(self.into(), other.expr.clone()).into()) - } - - pub fn radians(&self) -> PyResult { - use functions::numeric::radians; - Ok(radians(self.into()).into()) - } - - pub fn degrees(&self) -> PyResult { - use functions::numeric::degrees; - Ok(degrees(self.into()).into()) - } - - pub fn arctanh(&self) -> PyResult { - use functions::numeric::arctanh; - Ok(arctanh(self.into()).into()) - } - - pub fn arccosh(&self) -> PyResult { - use 
functions::numeric::arccosh; - Ok(arccosh(self.into()).into()) - } - - pub fn arcsinh(&self) -> PyResult { - use functions::numeric::arcsinh; - Ok(arcsinh(self.into()).into()) - } - - pub fn log2(&self) -> PyResult { - use functions::numeric::log2; - Ok(log2(self.into()).into()) - } - - pub fn log10(&self) -> PyResult { - use functions::numeric::log10; - Ok(log10(self.into()).into()) - } - - pub fn log(&self, base: f64) -> PyResult { - use functions::numeric::log; - Ok(log(self.into(), base).into()) - } - - pub fn ln(&self) -> PyResult { - use functions::numeric::ln; - Ok(ln(self.into()).into()) - } - - pub fn exp(&self) -> PyResult { - use functions::numeric::exp; - Ok(exp(self.into()).into()) - } - pub fn if_else(&self, if_true: &Self, if_false: &Self) -> PyResult { Ok(self .expr @@ -460,10 +339,6 @@ impl PyExpr { Ok(self.expr.clone().agg_concat().into()) } - pub fn __abs__(&self) -> PyResult { - use functions::numeric::abs; - Ok(abs(self.into()).into()) - } pub fn __add__(&self, other: &Self) -> PyResult { Ok(crate::binary_op(crate::Operator::Plus, self.into(), other.expr.clone()).into()) } diff --git a/src/daft-functions/Cargo.toml b/src/daft-functions/Cargo.toml index b965e9f417..9be2a86dc2 100644 --- a/src/daft-functions/Cargo.toml +++ b/src/daft-functions/Cargo.toml @@ -2,6 +2,7 @@ arrow2 = {workspace = true} base64 = {workspace = true} common-error = {path = "../common/error", default-features = false} +common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"} common-io-config = {path = "../common/io-config", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} diff --git a/src/daft-functions/src/lib.rs b/src/daft-functions/src/lib.rs index 0976f17c21..0a8486864e 100644 --- a/src/daft-functions/src/lib.rs +++ b/src/daft-functions/src/lib.rs @@ -31,7 +31,6 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction_bound!(hash::python::hash, parent)?)?; parent.add_function(wrap_pyfunction_bound!(minhash::python::minhash, parent)?)?; - parent.add_function(wrap_pyfunction_bound!(numeric::cbrt::python::cbrt, parent)?)?; parent.add_function(wrap_pyfunction_bound!( to_struct::python::to_struct, parent @@ -46,6 +45,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { )?)?; parent.add_function(wrap_pyfunction_bound!(uri::python::url_download, parent)?)?; parent.add_function(wrap_pyfunction_bound!(uri::python::url_upload, parent)?)?; + numeric::register_modules(parent)?; image::register_modules(parent)?; float::register_modules(parent)?; temporal::register_modules(parent)?; diff --git a/src/daft-functions/src/numeric/abs.rs b/src/daft-functions/src/numeric/abs.rs new file mode 100644 index 0000000000..f054950e0f --- /dev/null +++ b/src/daft-functions/src/numeric/abs.rs @@ -0,0 +1,48 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +use super::{evaluate_single_numeric, to_field_single_numeric}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Abs {} + +#[typetag::serde] +impl ScalarUDF for Abs { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "abs" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_numeric(self, inputs, schema) + } + + fn 
evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::abs) + } +} + +pub fn abs(input: ExprRef) -> ExprRef { + ScalarFunction::new(Abs {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "abs")] +pub fn py_abs(expr: PyExpr) -> PyResult { + Ok(abs(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/cbrt.rs b/src/daft-functions/src/numeric/cbrt.rs index 7f2e635689..c9b4e9286f 100644 --- a/src/daft-functions/src/numeric/cbrt.rs +++ b/src/daft-functions/src/numeric/cbrt.rs @@ -1,13 +1,17 @@ -use common_error::{DaftError, DaftResult}; +use common_error::DaftResult; use daft_core::prelude::*; -use daft_dsl::{functions::ScalarUDF, ExprRef}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -struct CbrtFunction; +pub struct Cbrt; +use super::{evaluate_single_numeric, to_field_single_floating}; #[typetag::serde] -impl ScalarUDF for CbrtFunction { +impl ScalarUDF for Cbrt { fn as_any(&self) -> &dyn std::any::Any { self } @@ -17,41 +21,26 @@ impl ScalarUDF for CbrtFunction { } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { - match inputs { - [input] => { - let field = input.to_field(schema)?; - let dtype = field.dtype.to_floating_representation()?; - Ok(Field::new(field.name, dtype)) - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } + to_field_single_floating(self, inputs, schema) } fn evaluate(&self, inputs: &[Series]) -> DaftResult { - match inputs { - [input] => input.cbrt(), - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } + evaluate_single_numeric(inputs, Series::cbrt) } } -#[cfg(feature = "python")] -pub mod python { - use daft_dsl::{functions::ScalarFunction, python::PyExpr, ExprRef}; - use pyo3::{pyfunction, PyResult}; - - use super::CbrtFunction; +pub fn cbrt(input: ExprRef) -> ExprRef { + ScalarFunction::new(Cbrt {}, vec![input]).into() +} - #[pyfunction] - pub fn cbrt(expr: PyExpr) -> PyResult { - let scalar_function = ScalarFunction::new(CbrtFunction, vec![expr.into()]); - let expr = ExprRef::from(scalar_function); - Ok(expr.into()) - } +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "cbrt")] +pub fn py_cbrt(expr: PyExpr) -> PyResult { + Ok(cbrt(expr.into()).into()) } diff --git a/src/daft-functions/src/numeric/ceil.rs b/src/daft-functions/src/numeric/ceil.rs new file mode 100644 index 0000000000..26c37bec6b --- /dev/null +++ b/src/daft-functions/src/numeric/ceil.rs @@ -0,0 +1,49 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Ceil {} + +#[typetag::serde] +impl ScalarUDF for Ceil { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "ceil" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_numeric(self, inputs, schema) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + 
evaluate_single_numeric(inputs, Series::ceil) + } +} + +pub fn ceil(input: ExprRef) -> ExprRef { + ScalarFunction::new(Ceil {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +use super::{evaluate_single_numeric, to_field_single_numeric}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "ceil")] +pub fn py_ceil(expr: PyExpr) -> PyResult { + Ok(ceil(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/exp.rs b/src/daft-functions/src/numeric/exp.rs new file mode 100644 index 0000000000..abde081b46 --- /dev/null +++ b/src/daft-functions/src/numeric/exp.rs @@ -0,0 +1,67 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +use super::evaluate_single_numeric; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Exp {} + +#[typetag::serde] +impl ScalarUDF for Exp { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "exp" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + let field = inputs.first().unwrap().to_field(schema)?; + let dtype = match field.dtype { + DataType::Float32 => DataType::Float32, + dt if dt.is_numeric() => DataType::Float64, + _ => { + return Err(DaftError::TypeError(format!( + "Expected input to compute exp to be numeric, got {}", + field.dtype + ))) + } + }; + Ok(Field::new(field.name, dtype)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::exp) + } +} + +pub fn exp(input: ExprRef) -> ExprRef { + ScalarFunction::new(Exp {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "exp")] +pub fn py_exp(expr: PyExpr) -> PyResult { + Ok(exp(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/floor.rs b/src/daft-functions/src/numeric/floor.rs new file mode 100644 index 0000000000..36ec365e0f --- /dev/null +++ b/src/daft-functions/src/numeric/floor.rs @@ -0,0 +1,50 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Floor {} + +#[typetag::serde] +impl ScalarUDF for Floor { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "floor" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_numeric(self, inputs, schema) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::floor) + } +} + +pub fn floor(input: ExprRef) -> ExprRef { + ScalarFunction::new(Floor {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +use super::{evaluate_single_numeric, to_field_single_numeric}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "floor")] +pub fn py_floor(expr: PyExpr) -> PyResult { + Ok(floor(expr.into()).into()) +} diff 
--git a/src/daft-functions/src/numeric/log.rs b/src/daft-functions/src/numeric/log.rs new file mode 100644 index 0000000000..7aecb2de56 --- /dev/null +++ b/src/daft-functions/src/numeric/log.rs @@ -0,0 +1,142 @@ +use common_error::{DaftError, DaftResult}; +use common_hashable_float_wrapper::FloatWrapper; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +// super annoying, but using an enum with typetag::serde doesn't work with bincode because it uses Deserializer::deserialize_identifier +macro_rules! log { + ($name:ident, $variant:ident) => { + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] + pub struct $variant; + + #[typetag::serde] + impl ScalarUDF for $variant { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + stringify!($name) + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + let field = inputs.first().unwrap().to_field(schema)?; + let dtype = match field.dtype { + DataType::Float32 => DataType::Float32, + dt if dt.is_numeric() => DataType::Float64, + _ => { + return Err(DaftError::TypeError(format!( + "Expected input to log to be numeric, got {}", + field.dtype + ))) + } + }; + Ok(Field::new(field.name, dtype)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::$name) + } + } + + pub fn $name(input: ExprRef) -> ExprRef { + ScalarFunction::new($variant, vec![input]).into() + } + }; +} + +log!(log2, Log2); +log!(log10, Log10); +log!(ln, Ln); + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Log(FloatWrapper); + +#[typetag::serde] +impl ScalarUDF for Log { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "log" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + let field = inputs.first().unwrap().to_field(schema)?; + let dtype = match field.dtype { + DataType::Float32 => DataType::Float32, + dt if dt.is_numeric() => DataType::Float64, + _ => { + return Err(DaftError::TypeError(format!( + "Expected input to log to be numeric, got {}", + field.dtype + ))) + } + }; + Ok(Field::new(field.name, dtype)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, |x| x.log(self.0 .0)) + } +} + +pub fn log(input: ExprRef, base: f64) -> ExprRef { + ScalarFunction::new(Log(FloatWrapper(base)), vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +use super::evaluate_single_numeric; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "log2")] +pub fn py_log2(expr: PyExpr) -> PyResult { + Ok(log2(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "log10")] +pub fn py_log10(expr: PyExpr) -> PyResult { + Ok(log10(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "log")] +pub fn py_log(expr: PyExpr, base: f64) -> PyResult { + Ok(log(expr.into(), base).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "ln")] +pub fn py_ln(expr: PyExpr) -> 
PyResult { + Ok(ln(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/mod.rs b/src/daft-functions/src/numeric/mod.rs index f78822dd52..28d0e50cad 100644 --- a/src/daft-functions/src/numeric/mod.rs +++ b/src/daft-functions/src/numeric/mod.rs @@ -1 +1,104 @@ +pub mod abs; pub mod cbrt; +pub mod ceil; +pub mod exp; +pub mod floor; +pub mod log; +pub mod round; +pub mod sign; +pub mod sqrt; +pub mod trigonometry; + +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{functions::ScalarUDF, ExprRef}; +#[cfg(feature = "python")] +use pyo3::prelude::*; + +#[cfg(feature = "python")] +pub fn register_modules(parent: &Bound) -> PyResult<()> { + parent.add_function(wrap_pyfunction_bound!(abs::py_abs, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(cbrt::py_cbrt, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(ceil::py_ceil, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(exp::py_exp, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(floor::py_floor, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(log::py_log2, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(log::py_log10, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(log::py_log, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(log::py_ln, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(round::py_round, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(sign::py_sign, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(sqrt::py_sqrt, parent)?)?; + + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_sin, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_cos, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_tan, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_cot, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arcsin, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arccos, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arctan, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_radians, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_degrees, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arctanh, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arccosh, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arcsinh, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arctan2, parent)?)?; + + Ok(()) +} + +fn to_field_single_numeric( + f: &dyn ScalarUDF, + inputs: &[ExprRef], + schema: &Schema, +) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + let field = inputs.first().unwrap().to_field(schema)?; + if !field.dtype.is_numeric() { + return Err(DaftError::TypeError(format!( + "Expected input to {} to be numeric, got {}", + f.name(), + field.dtype + ))); + } + Ok(field) +} + +fn to_field_single_floating( + f: &dyn ScalarUDF, + inputs: &[ExprRef], + schema: &Schema, +) -> DaftResult { + match inputs { + [first] => { + let field = first.to_field(schema)?; + let dtype = field.dtype.to_floating_representation()?; + Ok(Field::new(field.name, dtype)) + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg for {}, got {}", + f.name(), + inputs.len() + ))), + } +} 
+fn evaluate_single_numeric DaftResult>( + inputs: &[Series], + func: F, +) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + func(inputs.first().unwrap()) +} diff --git a/src/daft-functions/src/numeric/round.rs b/src/daft-functions/src/numeric/round.rs new file mode 100644 index 0000000000..395b0ee696 --- /dev/null +++ b/src/daft-functions/src/numeric/round.rs @@ -0,0 +1,59 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Round { + decimal: i32, +} + +#[typetag::serde] +impl ScalarUDF for Round { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "round" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_numeric(self, inputs, schema) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, |s| s.round(self.decimal)) + } +} + +pub fn round(input: ExprRef, decimal: i32) -> ExprRef { + ScalarFunction::new(Round { decimal }, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +use super::{evaluate_single_numeric, to_field_single_numeric}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "round")] +pub fn py_round(expr: PyExpr, decimal: i32) -> PyResult { + use pyo3::exceptions::PyValueError; + + if decimal < 0 { + return Err(PyValueError::new_err(format!( + "decimal can not be negative: {decimal}" + ))); + } + Ok(round(expr.into(), decimal).into()) +} diff --git a/src/daft-functions/src/numeric/sign.rs b/src/daft-functions/src/numeric/sign.rs new file mode 100644 index 0000000000..a58b7f294d --- /dev/null +++ b/src/daft-functions/src/numeric/sign.rs @@ -0,0 +1,50 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Sign {} + +#[typetag::serde] +impl ScalarUDF for Sign { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "sign" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_numeric(self, inputs, schema) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::sign) + } +} + +pub fn sign(input: ExprRef) -> ExprRef { + ScalarFunction::new(Sign {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +use super::{evaluate_single_numeric, to_field_single_numeric}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "sign")] +pub fn py_sign(expr: PyExpr) -> PyResult { + Ok(sign(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/sqrt.rs b/src/daft-functions/src/numeric/sqrt.rs new file mode 100644 index 0000000000..11766e4f17 --- /dev/null +++ b/src/daft-functions/src/numeric/sqrt.rs @@ -0,0 +1,51 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use 
serde::{Deserialize, Serialize}; + +use super::{evaluate_single_numeric, to_field_single_floating}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Sqrt {} + +#[typetag::serde] +impl ScalarUDF for Sqrt { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "sqrt" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_floating(self, inputs, schema) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::sqrt) + } +} + +pub fn sqrt(input: ExprRef) -> ExprRef { + ScalarFunction::new(Sqrt {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "sqrt")] +pub fn py_sqrt(expr: PyExpr) -> PyResult { + Ok(sqrt(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/trigonometry.rs b/src/daft-functions/src/numeric/trigonometry.rs new file mode 100644 index 0000000000..9a47875596 --- /dev/null +++ b/src/daft-functions/src/numeric/trigonometry.rs @@ -0,0 +1,222 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + array::ops::trigonometry::TrigonometricFunction, + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +use super::evaluate_single_numeric; + +// super annoying, but using an enum with typetag::serde doesn't work with bincode because it uses Deserializer::deserialize_identifier +macro_rules! trigonometry { + ($name:ident, $variant:ident) => { + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] + pub struct $variant; + + #[typetag::serde] + impl ScalarUDF for $variant { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + TrigonometricFunction::$variant.fn_name() + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + let field = inputs.first().unwrap().to_field(schema)?; + let dtype = match field.dtype { + DataType::Float32 => DataType::Float32, + dt if dt.is_numeric() => DataType::Float64, + _ => { + return Err(DaftError::TypeError(format!( + "Expected input to trigonometry to be numeric, got {}", + field.dtype + ))) + } + }; + Ok(Field::new(field.name, dtype)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, |s| { + s.trigonometry(&TrigonometricFunction::$variant) + }) + } + } + + pub fn $name(input: ExprRef) -> ExprRef { + ScalarFunction::new($variant, vec![input]).into() + } + }; +} + +trigonometry!(sin, Sin); +trigonometry!(cos, Cos); +trigonometry!(tan, Tan); +trigonometry!(cot, Cot); +trigonometry!(arcsin, ArcSin); +trigonometry!(arccos, ArcCos); +trigonometry!(arctan, ArcTan); +trigonometry!(radians, Radians); +trigonometry!(degrees, Degrees); +trigonometry!(arctanh, ArcTanh); +trigonometry!(arccosh, ArcCosh); +trigonometry!(arcsinh, ArcSinh); + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Atan2 {} + +#[typetag::serde] +impl ScalarUDF for Atan2 { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "atan2" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + if inputs.len() != 2 { + 
return Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))); + } + let field1 = inputs.first().unwrap().to_field(schema)?; + let field2 = inputs.get(1).unwrap().to_field(schema)?; + let dtype = match (field1.dtype, field2.dtype) { + (DataType::Float32, DataType::Float32) => DataType::Float32, + (dt1, dt2) if dt1.is_numeric() && dt2.is_numeric() => DataType::Float64, + (dt1, dt2) => { + return Err(DaftError::TypeError(format!( + "Expected inputs to atan2 to be numeric, got {} and {}", + dt1, dt2 + ))) + } + }; + Ok(Field::new(field1.name, dtype)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [x, y] => x.atan2(y), + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} + +pub fn atan2(x: ExprRef, y: ExprRef) -> ExprRef { + ScalarFunction::new(Atan2 {}, vec![x, y]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "sin")] +pub fn py_sin(expr: PyExpr) -> PyResult { + Ok(sin(expr.into()).into()) +} +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "cos")] +pub fn py_cos(expr: PyExpr) -> PyResult { + Ok(cos(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "tan")] +pub fn py_tan(expr: PyExpr) -> PyResult { + Ok(tan(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "cot")] +pub fn py_cot(expr: PyExpr) -> PyResult { + Ok(cot(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arcsin")] +pub fn py_arcsin(expr: PyExpr) -> PyResult { + Ok(arcsin(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arccos")] +pub fn py_arccos(expr: PyExpr) -> PyResult { + Ok(arccos(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arctan")] +pub fn py_arctan(expr: PyExpr) -> PyResult { + Ok(arctan(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "radians")] +pub fn py_radians(expr: PyExpr) -> PyResult { + Ok(radians(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "degrees")] +pub fn py_degrees(expr: PyExpr) -> PyResult { + Ok(degrees(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arctanh")] +pub fn py_arctanh(expr: PyExpr) -> PyResult { + Ok(arctanh(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arccosh")] +pub fn py_arccosh(expr: PyExpr) -> PyResult { + Ok(arccosh(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arcsinh")] +pub fn py_arcsinh(expr: PyExpr) -> PyResult { + Ok(arcsinh(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arctan2")] +pub fn py_arctan2(x: PyExpr, y: PyExpr) -> PyResult { + Ok(atan2(x.into(), y.into()).into()) +} diff --git a/src/daft-sql/src/modules/numeric.rs b/src/daft-sql/src/modules/numeric.rs index 078878faef..197d958860 100644 --- a/src/daft-sql/src/modules/numeric.rs +++ b/src/daft-sql/src/modules/numeric.rs @@ -1,6 +1,17 @@ -use daft_dsl::{ - functions::{self, numeric::NumericExpr}, - ExprRef, LiteralValue, +use daft_dsl::{ExprRef, LiteralValue}; +use daft_functions::numeric::{ + abs::abs, + ceil::ceil, + exp::exp, + floor::floor, + log::{ln, log, log10, log2}, + round::round, + sign::sign, + sqrt::sqrt, + trigonometry::{ + arccos, arccosh, 
arcsin, arcsinh, arctan, arctanh, atan2, cos, cot, degrees, radians, sin, + tan, + }, }; use super::SQLModule; @@ -13,38 +24,62 @@ use crate::{ pub struct SQLModuleNumeric; -/// SQLModule for FunctionExpr::Numeric impl SQLModule for SQLModuleNumeric { fn register(parent: &mut SQLFunctions) { - use NumericExpr::*; - parent.add_fn("abs", Abs); - parent.add_fn("ceil", Ceil); - parent.add_fn("floor", Floor); - parent.add_fn("sign", Sign); - parent.add_fn("round", Round(0)); - parent.add_fn("sqrt", Sqrt); - parent.add_fn("sin", Sin); - parent.add_fn("cos", Cos); - parent.add_fn("tan", Tan); - parent.add_fn("cot", Cot); - parent.add_fn("asin", ArcSin); - parent.add_fn("acos", ArcCos); - parent.add_fn("atan", ArcTan); - parent.add_fn("atan2", ArcTan2); - parent.add_fn("radians", Radians); - parent.add_fn("degrees", Degrees); - parent.add_fn("log2", Log2); - parent.add_fn("log10", Log10); - // parent.add("log", f(Log(FloatWrapper(0.0)))); - parent.add_fn("ln", Ln); - parent.add_fn("exp", Exp); - parent.add_fn("atanh", ArcTanh); - parent.add_fn("acosh", ArcCosh); - parent.add_fn("asinh", ArcSinh); + parent.add_fn("abs", SQLNumericExpr::Abs); + parent.add_fn("ceil", SQLNumericExpr::Ceil); + parent.add_fn("floor", SQLNumericExpr::Floor); + parent.add_fn("sign", SQLNumericExpr::Sign); + parent.add_fn("round", SQLNumericExpr::Round); + parent.add_fn("sqrt", SQLNumericExpr::Sqrt); + parent.add_fn("sin", SQLNumericExpr::Sin); + parent.add_fn("cos", SQLNumericExpr::Cos); + parent.add_fn("tan", SQLNumericExpr::Tan); + parent.add_fn("cot", SQLNumericExpr::Cot); + parent.add_fn("asin", SQLNumericExpr::ArcSin); + parent.add_fn("acos", SQLNumericExpr::ArcCos); + parent.add_fn("atan", SQLNumericExpr::ArcTan); + parent.add_fn("atan2", SQLNumericExpr::ArcTan2); + parent.add_fn("radians", SQLNumericExpr::Radians); + parent.add_fn("degrees", SQLNumericExpr::Degrees); + parent.add_fn("log2", SQLNumericExpr::Log2); + parent.add_fn("log10", SQLNumericExpr::Log10); + parent.add_fn("log", SQLNumericExpr::Log); + parent.add_fn("ln", SQLNumericExpr::Ln); + parent.add_fn("exp", SQLNumericExpr::Exp); + parent.add_fn("atanh", SQLNumericExpr::ArcTanh); + parent.add_fn("acosh", SQLNumericExpr::ArcCosh); + parent.add_fn("asinh", SQLNumericExpr::ArcSinh); } } +enum SQLNumericExpr { + Abs, + Ceil, + Exp, + Floor, + Round, + Sign, + Sqrt, + Sin, + Cos, + Tan, + Cot, + ArcSin, + ArcCos, + ArcTan, + ArcTan2, + Radians, + Degrees, + Log, + Log2, + Log10, + Ln, + ArcTanh, + ArcCosh, + ArcSinh, +} -impl SQLFunction for NumericExpr { +impl SQLFunction for SQLNumericExpr { fn to_expr( &self, inputs: &[sqlparser::ast::FunctionArg], @@ -54,27 +89,26 @@ impl SQLFunction for NumericExpr { to_expr(self, inputs.as_slice()) } } -fn to_expr(expr: &NumericExpr, args: &[ExprRef]) -> SQLPlannerResult { - use functions::numeric::*; - use NumericExpr::*; + +fn to_expr(expr: &SQLNumericExpr, args: &[ExprRef]) -> SQLPlannerResult { match expr { - Abs => { + SQLNumericExpr::Abs => { ensure!(args.len() == 1, "abs takes exactly one argument"); Ok(abs(args[0].clone())) } - Ceil => { + SQLNumericExpr::Ceil => { ensure!(args.len() == 1, "ceil takes exactly one argument"); Ok(ceil(args[0].clone())) } - Floor => { + SQLNumericExpr::Floor => { ensure!(args.len() == 1, "floor takes exactly one argument"); Ok(floor(args[0].clone())) } - Sign => { + SQLNumericExpr::Sign => { ensure!(args.len() == 1, "sign takes exactly one argument"); Ok(sign(args[0].clone())) } - Round(_) => { + SQLNumericExpr::Round => { ensure!(args.len() == 2, "round takes exactly two 
arguments"); let precision = match args[1].as_ref().as_literal() { Some(LiteralValue::Int32(i)) => *i, @@ -84,63 +118,63 @@ fn to_expr(expr: &NumericExpr, args: &[ExprRef]) -> SQLPlannerResult { }; Ok(round(args[0].clone(), precision)) } - Sqrt => { + SQLNumericExpr::Sqrt => { ensure!(args.len() == 1, "sqrt takes exactly one argument"); Ok(sqrt(args[0].clone())) } - Sin => { + SQLNumericExpr::Sin => { ensure!(args.len() == 1, "sin takes exactly one argument"); Ok(sin(args[0].clone())) } - Cos => { + SQLNumericExpr::Cos => { ensure!(args.len() == 1, "cos takes exactly one argument"); Ok(cos(args[0].clone())) } - Tan => { + SQLNumericExpr::Tan => { ensure!(args.len() == 1, "tan takes exactly one argument"); Ok(tan(args[0].clone())) } - Cot => { + SQLNumericExpr::Cot => { ensure!(args.len() == 1, "cot takes exactly one argument"); Ok(cot(args[0].clone())) } - ArcSin => { + SQLNumericExpr::ArcSin => { ensure!(args.len() == 1, "asin takes exactly one argument"); Ok(arcsin(args[0].clone())) } - ArcCos => { + SQLNumericExpr::ArcCos => { ensure!(args.len() == 1, "acos takes exactly one argument"); Ok(arccos(args[0].clone())) } - ArcTan => { + SQLNumericExpr::ArcTan => { ensure!(args.len() == 1, "atan takes exactly one argument"); Ok(arctan(args[0].clone())) } - ArcTan2 => { + SQLNumericExpr::ArcTan2 => { ensure!(args.len() == 2, "atan2 takes exactly two arguments"); - Ok(arctan2(args[0].clone(), args[1].clone())) + Ok(atan2(args[0].clone(), args[1].clone())) } - Degrees => { + SQLNumericExpr::Degrees => { ensure!(args.len() == 1, "degrees takes exactly one argument"); Ok(degrees(args[0].clone())) } - Radians => { + SQLNumericExpr::Radians => { ensure!(args.len() == 1, "radians takes exactly one argument"); Ok(radians(args[0].clone())) } - Log2 => { + SQLNumericExpr::Log2 => { ensure!(args.len() == 1, "log2 takes exactly one argument"); Ok(log2(args[0].clone())) } - Log10 => { + SQLNumericExpr::Log10 => { ensure!(args.len() == 1, "log10 takes exactly one argument"); Ok(log10(args[0].clone())) } - Ln => { + SQLNumericExpr::Ln => { ensure!(args.len() == 1, "ln takes exactly one argument"); Ok(ln(args[0].clone())) } - Log(_) => { + SQLNumericExpr::Log => { ensure!(args.len() == 2, "log takes exactly two arguments"); let base = args[1] .as_literal() @@ -158,19 +192,19 @@ fn to_expr(expr: &NumericExpr, args: &[ExprRef]) -> SQLPlannerResult { Ok(log(args[0].clone(), base)) } - Exp => { + SQLNumericExpr::Exp => { ensure!(args.len() == 1, "exp takes exactly one argument"); Ok(exp(args[0].clone())) } - ArcTanh => { + SQLNumericExpr::ArcTanh => { ensure!(args.len() == 1, "atanh takes exactly one argument"); Ok(arctanh(args[0].clone())) } - ArcCosh => { + SQLNumericExpr::ArcCosh => { ensure!(args.len() == 1, "acosh takes exactly one argument"); Ok(arccosh(args[0].clone())) } - ArcSinh => { + SQLNumericExpr::ArcSinh => { ensure!(args.len() == 1, "asinh takes exactly one argument"); Ok(arcsinh(args[0].clone())) } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index ff92b38b85..cf82f72743 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -4,12 +4,10 @@ use common_error::DaftResult; use daft_core::prelude::*; use daft_dsl::{ col, - functions::{ - numeric::{ceil, floor}, - utf8::{ilike, like}, - }, + functions::utf8::{ilike, like}, has_agg, lit, literals_to_series, null_lit, Expr, ExprRef, LiteralValue, Operator, }; +use daft_functions::numeric::{ceil::ceil, floor::floor}; use daft_plan::{LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{ From 
ff4a0a75e915da7ede6c226b42573ee057ab7e0d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 17:34:35 +0000 Subject: [PATCH 13/35] Bump isbang/compose-action from 2.0.0 to 2.0.2 (#2887) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [isbang/compose-action](https://github.com/isbang/compose-action) from 2.0.0 to 2.0.2.
Release notes

Sourced from isbang/compose-action's releases.

v2.0.2

Release Summary

This release introduces new tests for the attach-dependencies feature, along with support for absolute paths in the compose-file input. Documentation for actions and workflows has been updated, and several dependencies related to GitHub actions and npm development have been bumped to ensure better performance and stability.

No breaking changes have been introduced.

What's Changed

Full Changelog: https://github.com/hoverkraft-tech/compose-action/compare/v2.0.1...v2.0.2

v2.0.1

What's Changed

Full Changelog: https://github.com/hoverkraft-tech/compose-action/compare/v2.0.0...v2.0.1

Commits
  • f1ca7fe chore(deps): bump hoverkraft-tech/ci-github-common
  • 89f869f docs: update actions and workflows documentation
  • 4d0cecc fix: support absolute path for compose-file input
  • ce6f83d chore(deps-dev): bump eslint-plugin-github
  • 7e2bd76 chore(deps): bump hoverkraft-tech/ci-github-nodejs
  • 31f7d2c chore(deps): bump hoverkraft-tech/ci-github-common
  • 8391f9b docs: update actions and workflows documentation
  • 0fce44d chore(deps): bump hoverkraft-tech/ci-github-common
  • 89eb95b docs: add basic usage exemple
  • f7ac24c ci: add test for attach-dependencies
  • Additional commits viewable in compare view

Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/nightlies-tests.yml | 2 +- .github/workflows/python-package.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/nightlies-tests.yml b/.github/workflows/nightlies-tests.yml index 98cc20a7f2..92532242f7 100644 --- a/.github/workflows/nightlies-tests.yml +++ b/.github/workflows/nightlies-tests.yml @@ -127,7 +127,7 @@ jobs: role-to-assume: ${{ secrets.ACTIONS_AWS_ROLE_ARN }} role-session-name: DaftPythonPackageGitHubWorkflow - name: Spin up IO services - uses: isbang/compose-action@v2.0.0 + uses: isbang/compose-action@v2.0.2 with: compose-file: ./tests/integration/io/docker-compose/docker-compose.yml down-flags: --volumes diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 57062f994e..2b185084c2 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -330,7 +330,7 @@ jobs: mkdir -p /tmp/daft-integration-testing/nginx chmod +rw /tmp/daft-integration-testing/nginx - name: Spin up IO services - uses: isbang/compose-action@v2.0.0 + uses: isbang/compose-action@v2.0.2 with: compose-file: ./tests/integration/io/docker-compose/docker-compose.yml down-flags: --volumes @@ -427,7 +427,7 @@ jobs: # workload_identity_provider: ${{ secrets.ACTIONS_GCP_WORKLOAD_IDENTITY_PROVIDER }} # service_account: ${{ secrets.ACTIONS_GCP_SERVICE_ACCOUNT }} - name: Spin up IO services - uses: isbang/compose-action@v2.0.0 + uses: isbang/compose-action@v2.0.2 with: compose-file: ./tests/integration/io/docker-compose/docker-compose.yml down-flags: --volumes From 26beb3fe84188c51b9222b644ddfee34592e2ff6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 17:35:55 +0000 Subject: [PATCH 14/35] Bump astral-sh/setup-uv from 2 to 3 (#2888) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [astral-sh/setup-uv](https://github.com/astral-sh/setup-uv) from 2 to 3.
Release notes

Sourced from astral-sh/setup-uv's releases.

v3.1.1 🌈 update known checksums for 0.4.15

Changes

🧰 Maintenance

📚 Documentation

v3.0.0 🌈 Set the cache-dependency-glob default to **/uv.lock

Changes

With this release cache-dependency-glob defaults to **/uv.lock. This is in line with what most users would expect and also mirrors the default behaviors for setup-python which use **/requirements.txt, **/Pipfile.lock or **/poetry.lock.

The previous default led to the cache being created only once and never invalidated or updated even when the dependencies changed.

This change only affects you if you are using enable-cache: true without specifying cache-dependency-glob. The only behavioral change you might see is a one-time cache miss.

Learn more about cache-dependency-glob in the README section.

🚨 Breaking changes

🧰 Maintenance

⬆️ Dependency updates

v2.1.2 🌈 update known checksums for 0.4.10

Changes

🧰 Maintenance

⬆️ Dependency updates

... (truncated)

Commits
  • abac0ce Use runner label selfhosted-ubuntu-arm64 (#96)
  • 03e245b chore: update known checksums for 0.4.15 (#95)
  • aeb4649 Set tool(-bin) dir and add to PATH (#87)
  • dbb680f chore: update known checksums for 0.4.14 (#94)
  • 8de9ba9 chore: update known checksums for 0.4.13 (#93)
  • 118bd87 Bump peter-evans/create-pull-request from 7.0.3 to 7.0.5 (#91)
  • adcf5c8 Fix a typo SHA265 → SHA256 (#90)
  • 0c326cb chore: update known checksums for 0.4.12 (#86)
  • 8205eab Update version in README from v2 to v3 (#85)
  • ce0062a Add support for semver version ranges (#82)
  • Additional commits viewable in compare view

Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/nightlies-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/nightlies-tests.yml b/.github/workflows/nightlies-tests.yml index 92532242f7..1dabd876e9 100644 --- a/.github/workflows/nightlies-tests.yml +++ b/.github/workflows/nightlies-tests.yml @@ -32,7 +32,7 @@ jobs: fetch-depth: 0 - name: Install uv - uses: astral-sh/setup-uv@v2 + uses: astral-sh/setup-uv@v3 - name: Set up Python ${{ env.PYTHON_VERSION }} uses: actions/setup-python@v5 with: @@ -103,7 +103,7 @@ jobs: if: runner.os == 'macos' - name: Install uv - uses: astral-sh/setup-uv@v2 + uses: astral-sh/setup-uv@v3 - name: Set up Python ${{ env.PYTHON_VERSION }} uses: actions/setup-python@v5 with: From 29be743ae682b88c5bd76cfe65bfce55f5945e51 Mon Sep 17 00:00:00 2001 From: Desmond Cheong Date: Mon, 23 Sep 2024 10:47:31 -0700 Subject: [PATCH 15/35] [BUG] Fix partitioning SQL scans on empty tables (#2885) When scanning an empty SQL table, if a user specifies `num_partitions > 1`, then we get `TypeError: unsupported operand type(s) for -: 'NoneType' and 'NoneType'` or `Failed to get partition bounds: {self._partition_col} is not a numeric or temporal type, and cannot be used for partitioning`. These are both the result of attempting to partition the scan using a column with no data. Despite the table having no rows, we attempt to fulfill the user's request for `num_partitions > 1`, only to run into errors because there are no min and max values available to compute partition range sizes from. Furthermore, in some cases SQL databases do not return type information when a table has no rows, so an empty integer column might be read as an empty string column, which cannot be used for partitioning. We fix this by capping the number of scan tasks at the number of rows in the scanned table. If there are no rows, we don't attempt partitioning, and we simply generate 0 scan tasks.
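For illustration (not part of the patch itself), here is a minimal sketch of the behavior after this fix, mirroring the new integration test; the connection URL and `empty_table` name are placeholders, and it assumes a reachable SQL database:

```python
import daft

# Placeholder connection URL and table name; any database supported by
# daft.read_sql should behave the same way.
DB_URL = "postgresql://username:password@localhost:5432/postgres"

df = daft.read_sql(
    "SELECT * FROM empty_table",
    DB_URL,
    partition_col="id",
    num_partitions=2,  # previously raised a TypeError when the table was empty
    schema={"id": daft.DataType.int64(), "string_col": daft.DataType.string()},
)

# The requested partition count is now capped at the row count, so an empty
# table yields a single empty partition instead of an error.
assert df.num_partitions() == 1
assert df.count_rows() == 0
```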
--- daft/sql/sql_scan.py | 3 +++ tests/conftest.py | 2 +- tests/integration/sql/conftest.py | 23 +++++++++++++++++++++++ tests/integration/sql/test_sql.py | 24 +++++++++++++++++++++++- 4 files changed, 50 insertions(+), 2 deletions(-) diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 40035c12f7..4f0f9a35c7 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -74,6 +74,8 @@ def multiline_display(self) -> list[str]: def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: total_rows, total_size, num_scan_tasks = self._get_size_estimates() + if num_scan_tasks == 0: + return iter(()) if num_scan_tasks == 1 or self._partition_col is None: return self._single_scan_task(pushdowns, total_rows, total_size) @@ -139,6 +141,7 @@ def _get_size_estimates(self) -> tuple[int, float, int]: if self._num_partitions is None else self._num_partitions ) + num_scan_tasks = min(num_scan_tasks, total_rows) return total_rows, total_size, num_scan_tasks def _get_num_rows(self) -> int: diff --git a/tests/conftest.py b/tests/conftest.py index c8c8b3b52b..d46f62f89e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -148,7 +148,7 @@ def assert_df_equals( sort_key_list: list[str] = [sort_key] if isinstance(sort_key, str) else sort_key for key in sort_key_list: assert key in daft_pd_df.columns, ( - f"DaFt Dataframe missing key: {key}\nNOTE: This doesn't necessarily mean your code is " + f"Daft Dataframe missing key: {key}\nNOTE: This doesn't necessarily mean your code is " "breaking, but our testing utilities require sorting on this key in order to compare your " "Dataframe against the expected Pandas Dataframe." ) diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index 63d58032a9..f5c01dccc6 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -28,6 +28,7 @@ "mysql+pymysql://username:password@localhost:3306/mysql", ] TEST_TABLE_NAME = "example" +EMPTY_TEST_TABLE_NAME = "empty_table" @pytest.fixture(scope="session", params=[{"num_rows": 200}]) @@ -54,6 +55,28 @@ def test_db(request: pytest.FixtureRequest, generated_data: pd.DataFrame) -> Gen yield db_url +@pytest.fixture(scope="session", params=URLS) +def empty_test_db(request: pytest.FixtureRequest) -> Generator[str, None, None]: + data = pd.DataFrame( + { + "id": pd.Series(dtype="int"), + "string_col": pd.Series(dtype="str"), + } + ) + db_url = request.param + engine = create_engine(db_url) + metadata = MetaData() + table = Table( + EMPTY_TEST_TABLE_NAME, + metadata, + Column("id", Integer), + Column("string_col", String(50)), + ) + metadata.create_all(engine) + data.to_sql(table.name, con=engine, if_exists="replace", index=False) + yield db_url + + @tenacity.retry(stop=tenacity.stop_after_delay(10), wait=tenacity.wait_fixed(5), reraise=True) def setup_database(db_url: str, data: pd.DataFrame) -> None: engine = create_engine(db_url) diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index 602866dcbd..ff02ebaac4 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -9,7 +9,7 @@ import daft from tests.conftest import assert_df_equals -from tests.integration.sql.conftest import TEST_TABLE_NAME +from tests.integration.sql.conftest import EMPTY_TEST_TABLE_NAME, TEST_TABLE_NAME @pytest.fixture(scope="session") @@ -65,6 +65,28 @@ def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col( assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id") 
+@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [0, 1, 2]) +@pytest.mark.parametrize("partition_col", ["id", "string_col"]) +def test_sql_partitioned_read_on_empty_table(empty_test_db, num_partitions, partition_col) -> None: + with daft.execution_config_ctx( + scan_tasks_min_size_bytes=0, + scan_tasks_max_size_bytes=0, + ): + df = daft.read_sql( + f"SELECT * FROM {EMPTY_TEST_TABLE_NAME}", + empty_test_db, + partition_col=partition_col, + num_partitions=num_partitions, + schema={"id": daft.DataType.int64(), "string_col": daft.DataType.string()}, + ) + assert df.num_partitions() == 1 + empty_pdf = pd.read_sql_query( + f"SELECT * FROM {EMPTY_TEST_TABLE_NAME}", empty_test_db, dtype={"id": "int64", "string_col": "str"} + ) + assert_df_equals(df.to_pandas(), empty_pdf, sort_key="id") + + @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2, 3, 4]) def test_sql_partitioned_read_with_non_uniformly_distributed_column(test_db, num_partitions, pdf) -> None: From fd42281351cb4477f03987cb6539bb46e4200dde Mon Sep 17 00:00:00 2001 From: Arik Mitschang Date: Mon, 23 Sep 2024 15:26:45 -0400 Subject: [PATCH 16/35] [FEAT] [SQL] Enable SQL query to run on callers scoped variables (#2864) This builds a catalog based on the python globals and locals visible to the caller at the point where the `sql` query function is called, in the case where a catalog is not supplied. Otherwise, the catalog is final and must contain necessary tables. resolves #2740 --- daft/daft/__init__.pyi | 1 + daft/sql/sql.py | 39 ++++++++++++++++++++++++++++- src/daft-sql/src/catalog.rs | 7 ++++++ src/daft-sql/src/python.rs | 5 ++++ tests/sql/test_sql.py | 50 +++++++++++++++++++++++++++++++++++++ 5 files changed, 101 insertions(+), 1 deletion(-) diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 7c8f04a7fc..2647b23c58 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1277,6 +1277,7 @@ class PyCatalog: @staticmethod def new() -> PyCatalog: ... def register_table(self, name: str, logical_plan_builder: LogicalPlanBuilder) -> None: ... + def copy_from(self, other: PyCatalog) -> None: ... class PySeries: @staticmethod diff --git a/daft/sql/sql.py b/daft/sql/sql.py index a5334a28b9..987a9baeb0 100644 --- a/daft/sql/sql.py +++ b/daft/sql/sql.py @@ -1,11 +1,15 @@ # isort: dont-add-import: from __future__ import annotations +import inspect +from typing import Optional, overload + from daft.api_annotations import PublicAPI from daft.context import get_context from daft.daft import PyCatalog as _PyCatalog from daft.daft import sql as _sql from daft.daft import sql_expr as _sql_expr from daft.dataframe import DataFrame +from daft.exceptions import DaftCoreException from daft.expressions import Expression from daft.logical.builder import LogicalPlanBuilder @@ -28,24 +32,57 @@ def __init__(self, tables: dict) -> None: def __str__(self) -> str: return str(self._catalog) + def _copy_from(self, other: "SQLCatalog") -> None: + self._catalog.copy_from(other._catalog) + @PublicAPI def sql_expr(sql: str) -> Expression: return Expression._from_pyexpr(_sql_expr(sql)) +@overload +def sql(sql: str) -> DataFrame: ... + + +@overload +def sql(sql: str, catalog: SQLCatalog, register_globals: bool = ...) -> DataFrame: ... + + @PublicAPI -def sql(sql: str, catalog: SQLCatalog) -> DataFrame: +def sql(sql: str, catalog: Optional[SQLCatalog] = None, register_globals: bool = True) -> DataFrame: """Create a DataFrame from an SQL query. 
EXPERIMENTAL: This features is early in development and will change. Args: sql (str): SQL query to execute + catalog (SQLCatalog, optional): Catalog of tables to use in the query. + Defaults to None, in which case a catalog will be built from variables + in the callers scope. + register_globals (bool, optional): Whether to incorporate global + variables into the supplied catalog, in which case a copy of the + catalog will be made and the original not modified. Defaults to True. Returns: DataFrame: Dataframe containing the results of the query """ + if register_globals: + try: + # Caller is back from func, analytics, annotation + caller_frame = inspect.currentframe().f_back.f_back.f_back # type: ignore + caller_vars = {**caller_frame.f_globals, **caller_frame.f_locals} # type: ignore + except AttributeError as exc: + # some interpreters might not implement currentframe; all reasonable + # errors above should be AttributeError + raise DaftCoreException("Cannot get caller environment, please provide a catalog") from exc + catalog_ = SQLCatalog({k: v for k, v in caller_vars.items() if isinstance(v, DataFrame)}) + if catalog is not None: + catalog_._copy_from(catalog) + catalog = catalog_ + elif catalog is None: + raise DaftCoreException("Must supply a catalog if register_globals is False") + planning_config = get_context().daft_planning_config _py_catalog = catalog._catalog diff --git a/src/daft-sql/src/catalog.rs b/src/daft-sql/src/catalog.rs index 9064d517db..4da8ca6c8a 100644 --- a/src/daft-sql/src/catalog.rs +++ b/src/daft-sql/src/catalog.rs @@ -25,6 +25,13 @@ impl SQLCatalog { pub fn get_table(&self, name: &str) -> Option { self.tables.get(name).cloned() } + + /// Copy from another catalog, using tables from other in case of conflict + pub fn copy_from(&mut self, other: &SQLCatalog) { + for (name, plan) in other.tables.iter() { + self.tables.insert(name.clone(), plan.clone()); + } + } } impl Default for SQLCatalog { diff --git a/src/daft-sql/src/python.rs b/src/daft-sql/src/python.rs index 7dd5556878..216aaba3d8 100644 --- a/src/daft-sql/src/python.rs +++ b/src/daft-sql/src/python.rs @@ -45,6 +45,11 @@ impl PyCatalog { self.catalog.register_table(name, plan); } + /// Copy from another catalog, using tables from other in case of conflict + pub fn copy_from(&mut self, other: &PyCatalog) { + self.catalog.copy_from(&other.catalog); + } + /// __str__ to print the catalog's tables fn __str__(&self) -> String { format!("{:?}", self.catalog) diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index dd7ac1fc54..12384e34e0 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -4,6 +4,7 @@ import pytest import daft +from daft.exceptions import DaftCoreException from daft.sql.sql import SQLCatalog from tests.assets import TPCH_QUERIES @@ -149,3 +150,52 @@ def test_sql_count_star(): actual = df2.collect().to_pydict() expected = df.agg(daft.col("b").count()).collect().to_pydict() assert actual == expected + + +GLOBAL_DF = daft.from_pydict({"n": [1, 2, 3]}) + + +def test_sql_function_sees_caller_tables(): + # sees the globals + df = daft.sql("SELECT * FROM GLOBAL_DF") + assert df.collect().to_pydict() == GLOBAL_DF.collect().to_pydict() + # sees the locals + df_copy = daft.sql("SELECT * FROM df") + assert df.collect().to_pydict() == df_copy.collect().to_pydict() + + +def test_sql_function_locals_shadow_globals(): + GLOBAL_DF = None # noqa: F841 + with pytest.raises(Exception, match="Table not found"): + daft.sql("SELECT * FROM GLOBAL_DF") + + +def 
test_sql_function_globals_are_added_to_catalog(): + df = daft.from_pydict({"n": [1], "x": [2]}) + res = daft.sql("SELECT * FROM GLOBAL_DF g JOIN df d USING (n)", catalog=SQLCatalog({"df": df})) + joined = GLOBAL_DF.join(df, on="n") + assert res.collect().to_pydict() == joined.collect().to_pydict() + + +def test_sql_function_catalog_is_final(): + df = daft.from_pydict({"a": [1]}) + # sanity check to ensure validity of below test + assert df.collect().to_pydict() != GLOBAL_DF.collect().to_pydict() + res = daft.sql("SELECT * FROM GLOBAL_DF", catalog=SQLCatalog({"GLOBAL_DF": df})) + assert res.collect().to_pydict() == df.collect().to_pydict() + + +def test_sql_function_register_globals(): + with pytest.raises(Exception, match="Table not found"): + daft.sql("SELECT * FROM GLOBAL_DF", SQLCatalog({}), register_globals=False) + + +def test_sql_function_requires_catalog_or_globals(): + with pytest.raises(Exception, match="Must supply a catalog"): + daft.sql("SELECT * FROM GLOBAL_DF", register_globals=False) + + +def test_sql_function_raises_when_cant_get_frame(monkeypatch): + monkeypatch.setattr("inspect.currentframe", lambda: None) + with pytest.raises(DaftCoreException, match="Cannot get caller environment"): + daft.sql("SELECT * FROM df") From d5b9a95a27fac14a916f3bbfffd8db0fdaaa8068 Mon Sep 17 00:00:00 2001 From: michaelvay <51515221+michaelvay@users.noreply.github.com> Date: Mon, 23 Sep 2024 23:51:13 +0300 Subject: [PATCH 17/35] [FEAT] Add Sparse Tensor logical type (#2722) Closes #2494 It's still WIP, wanted to get some feedback on the draft and see if I am in the right path. Still having issues with: - When casting From COOSparseTensorArray to TensorArray keeping the values dynamically type. What I currently do is assuming the type in order to be able to iterate over the values and in insert the non zero values in the relevant indices. --- daft/daft/__init__.pyi | 4 + daft/datatype.py | 30 ++ daft/table/table.py | 2 + .../src/array/growable/logical_growable.rs | 5 + src/daft-core/src/array/growable/mod.rs | 8 + src/daft-core/src/array/ops/cast.rs | 380 +++++++++++++++++- src/daft-core/src/array/ops/mod.rs | 1 + src/daft-core/src/array/ops/repr.rs | 62 ++- src/daft-core/src/array/ops/sort.rs | 15 +- src/daft-core/src/array/ops/sparse_tensor.rs | 111 +++++ src/daft-core/src/array/ops/take.rs | 2 + src/daft-core/src/array/ops/tensor.rs | 83 ++++ src/daft-core/src/array/prelude.rs | 4 +- src/daft-core/src/datatypes/logical.rs | 6 +- src/daft-core/src/datatypes/matching.rs | 2 + src/daft-core/src/datatypes/mod.rs | 2 + src/daft-core/src/datatypes/prelude.rs | 4 +- .../src/series/array_impl/binary_ops.rs | 5 +- .../src/series/array_impl/logical_array.rs | 2 + src/daft-core/src/series/ops/downcast.rs | 13 +- src/daft-core/src/series/serdes.rs | 21 +- src/daft-schema/src/dtype.rs | 38 +- src/daft-schema/src/python/datatype.rs | 23 ++ src/daft-stats/src/column_stats/mod.rs | 2 +- src/daft-table/src/repr_html.rs | 8 + tests/benchmarks/conftest.py | 19 - tests/io/test_parquet_roundtrip.py | 18 + tests/series/test_sparse_tensor.py | 60 +++ 28 files changed, 892 insertions(+), 38 deletions(-) create mode 100644 src/daft-core/src/array/ops/sparse_tensor.rs create mode 100644 tests/series/test_sparse_tensor.py diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 2647b23c58..1a5dc99f0f 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -992,6 +992,8 @@ class PyDataType: @staticmethod def tensor(dtype: PyDataType, shape: tuple[int, ...] | None = None) -> PyDataType: ... 
@staticmethod + def sparse_tensor(dtype: PyDataType, shape: tuple[int, ...] | None = None) -> PyDataType: ... + @staticmethod def python() -> PyDataType: ... def to_arrow(self, cast_tensor_type_for_ray: builtins.bool | None = None) -> pa.DataType: ... def is_numeric(self) -> builtins.bool: ... @@ -1000,6 +1002,8 @@ class PyDataType: def is_list(self) -> builtins.bool: ... def is_tensor(self) -> builtins.bool: ... def is_fixed_shape_tensor(self) -> builtins.bool: ... + def is_sparse_tensor(self) -> builtins.bool: ... + def is_fixed_shape_sparse_tensor(self) -> builtins.bool: ... def is_map(self) -> builtins.bool: ... def is_logical(self) -> builtins.bool: ... def is_boolean(self) -> builtins.bool: ... diff --git a/daft/datatype.py b/daft/datatype.py index 6d2eb1fe6b..7d13a590e1 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -322,6 +322,30 @@ def tensor( raise ValueError("Tensor shape must be a non-empty tuple of ints, but got: ", shape) return cls._from_pydatatype(PyDataType.tensor(dtype._dtype, shape)) + @classmethod + def sparse_tensor( + cls, + dtype: DataType, + shape: tuple[int, ...] | None = None, + ) -> DataType: + """Create a SparseTensor DataType: SparseTensor arrays implemented as 'COO Sparse Tensor' representation of n-dimensional arrays of data of the provided ``dtype`` as elements, each of the provided + ``shape``. + + If a ``shape`` is given, each ndarray in the column will have this shape. + + If ``shape`` is not given, the ndarrays in the column can have different shapes. This is much more flexible, + but will result in a less compact representation and may be make some operations less efficient. + + Args: + dtype: The type of the data contained within the tensor elements. + shape: The shape of each SparseTensor in the column. This is ``None`` by default, which allows the shapes of + each tensor element to vary. 
+ """ + if shape is not None: + if not isinstance(shape, tuple) or not shape or any(not isinstance(n, int) for n in shape): + raise ValueError("SparseTensor shape must be a non-empty tuple of ints, but got: ", shape) + return cls._from_pydatatype(PyDataType.sparse_tensor(dtype._dtype, shape)) + @classmethod def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType: """Maps a PyArrow DataType to a Daft DataType""" @@ -455,6 +479,12 @@ def _is_tensor_type(self) -> builtins.bool: def _is_fixed_shape_tensor_type(self) -> builtins.bool: return self._dtype.is_fixed_shape_tensor() + def _is_sparse_tensor_type(self) -> builtins.bool: + return self._dtype.is_sparse_tensor() + + def _is_fixed_shape_sparse_tensor_type(self) -> builtins.bool: + return self._dtype.is_fixed_shape_sparse_tensor() + def _is_image_type(self) -> builtins.bool: return self._dtype.is_image() diff --git a/daft/table/table.py b/daft/table/table.py index 707ea6ec98..f90f63b8d3 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -93,6 +93,8 @@ def from_arrow(arrow_table: pa.Table) -> Table: if field.dtype == DataType.python() or field.dtype._is_tensor_type() or field.dtype._is_fixed_shape_tensor_type() + or field.dtype._is_sparse_tensor_type() + or field.dtype._is_fixed_shape_sparse_tensor_type() ] if non_native_fields: # If there are any contained Arrow types that are not natively supported, go through Table.from_pydict() diff --git a/src/daft-core/src/array/growable/logical_growable.rs b/src/daft-core/src/array/growable/logical_growable.rs index 9e0a0b4425..be95087443 100644 --- a/src/daft-core/src/array/growable/logical_growable.rs +++ b/src/daft-core/src/array/growable/logical_growable.rs @@ -77,6 +77,11 @@ impl_logical_growable!(LogicalTimeGrowable, TimeType); impl_logical_growable!(LogicalEmbeddingGrowable, EmbeddingType); impl_logical_growable!(LogicalFixedShapeImageGrowable, FixedShapeImageType); impl_logical_growable!(LogicalFixedShapeTensorGrowable, FixedShapeTensorType); +impl_logical_growable!(LogicalSparseTensorGrowable, SparseTensorType); +impl_logical_growable!( + LogicalFixedShapeSparseTensorGrowable, + FixedShapeSparseTensorType +); impl_logical_growable!(LogicalImageGrowable, ImageType); impl_logical_growable!(LogicalDecimal128Growable, Decimal128Type); impl_logical_growable!(LogicalTensorGrowable, TensorType); diff --git a/src/daft-core/src/array/growable/mod.rs b/src/daft-core/src/array/growable/mod.rs index 7db6948af6..76c8c7a2e1 100644 --- a/src/daft-core/src/array/growable/mod.rs +++ b/src/daft-core/src/array/growable/mod.rs @@ -202,6 +202,14 @@ impl_growable_array!( FixedShapeTensorArray, logical_growable::LogicalFixedShapeTensorGrowable<'a> ); +impl_growable_array!( + SparseTensorArray, + logical_growable::LogicalSparseTensorGrowable<'a> +); +impl_growable_array!( + FixedShapeSparseTensorArray, + logical_growable::LogicalFixedShapeSparseTensorGrowable<'a> +); impl_growable_array!(ImageArray, logical_growable::LogicalImageGrowable<'a>); impl_growable_array!(TensorArray, logical_growable::LogicalTensorGrowable<'a>); impl_growable_array!( diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index 191354eb68..3f14b8f2f5 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -19,7 +19,6 @@ use indexmap::IndexMap; use { crate::array::pseudo_arrow::PseudoArrowArray, crate::datatypes::PythonArray, - crate::with_match_numeric_daft_types, common_arrow_ffi as ffi, ndarray::IntoDimension, num_traits::{NumCast, ToPrimitive}, @@ 
-33,20 +32,21 @@ use crate::{ array::{ growable::make_growable, image_array::ImageArraySidecarData, - ops::{from_arrow::FromArrow, full::FullNull}, + ops::{from_arrow::FromArrow, full::FullNull, DaftCompare}, DataArray, FixedSizeListArray, ListArray, StructArray, }, datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, LogicalArray, MapArray, TensorArray, TimeArray, - TimestampArray, + FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, LogicalArray, MapArray, + SparseTensorArray, TensorArray, TimeArray, TimestampArray, }, DaftArrayType, DaftArrowBackedType, DaftLogicalType, DataType, Field, ImageMode, Int32Array, Int64Array, NullArray, TimeUnit, UInt64Array, Utf8Array, }, series::{IntoSeries, Series}, utils::display::display_time64, + with_match_numeric_daft_types, }; impl DataArray @@ -1392,6 +1392,83 @@ impl TensorArray { ); Ok(tensor_array.into_series()) } + DataType::SparseTensor(inner_dtype) => { + let shape_iterator = self.shape_array().into_iter(); + let data_iterator = self.data_array().into_iter(); + let validity = self.data_array().validity(); + let shape_and_data_iter = shape_iterator.zip(data_iterator); + let zero_series = Int64Array::from(("item", [0].as_slice())).into_series(); + let mut non_zero_values = Vec::new(); + let mut non_zero_indices = Vec::new(); + let mut offsets = Vec::::new(); + for (i, (shape_series, data_series)) in shape_and_data_iter.enumerate() { + let is_valid = validity.map_or(true, |v| v.get_bit(i)); + if !is_valid { + // Handle invalid row by populating dummy data. + offsets.push(1); + non_zero_values.push(Series::empty("dummy", inner_dtype.as_ref())); + non_zero_indices.push(Series::empty("dummy", &DataType::UInt64)); + continue; + } + let shape_series = shape_series.unwrap(); + let data_series = data_series.unwrap(); + let shape_array = shape_series.u64().unwrap(); + assert!( + data_series.len() + == shape_array.into_iter().flatten().product::() as usize + ); + let non_zero_mask = data_series.not_equal(&zero_series)?; + let data = data_series.filter(&non_zero_mask)?; + let indices = UInt64Array::arange("item", 0, data_series.len() as i64, 1)? + .into_series() + .filter(&non_zero_mask)?; + offsets.push(data.len()); + non_zero_values.push(data); + non_zero_indices.push(indices); + } + + let offsets: Offsets = + Offsets::try_from_iter(non_zero_values.iter().map(|s| s.len()))?; + let non_zero_values_series = + Series::concat(&non_zero_values.iter().collect::>())?; + let non_zero_indices_series = + Series::concat(&non_zero_indices.iter().collect::>())?; + let offsets_cloned = offsets.clone(); + let data_list_arr = ListArray::new( + Field::new( + "values", + DataType::List(Box::new(non_zero_values_series.data_type().clone())), + ), + non_zero_values_series, + offsets.into(), + validity.cloned(), + ); + let indices_list_arr = ListArray::new( + Field::new( + "indices", + DataType::List(Box::new(non_zero_indices_series.data_type().clone())), + ), + non_zero_indices_series, + offsets_cloned.into(), + validity.cloned(), + ); + // Shapes must be all valid to reproduce dense tensor. 
+ let all_valid_shape_array = self.shape_array().with_validity(None)?; + let sparse_struct_array = StructArray::new( + Field::new(self.name(), dtype.to_physical()), + vec![ + data_list_arr.into_series(), + indices_list_arr.into_series(), + all_valid_shape_array.into_series(), + ], + validity.cloned(), + ); + Ok(SparseTensorArray::new( + Field::new(sparse_struct_array.name(), dtype.clone()), + sparse_struct_array, + ) + .into_series()) + } DataType::Image(mode) => { let sa = self.shape_array(); if !(0..self.len()).map(|i| sa.get(i)).all(|s| { @@ -1511,6 +1588,216 @@ impl TensorArray { } } +fn cast_sparse_to_dense_for_inner_dtype( + inner_dtype: &DataType, + n_values: usize, + non_zero_indices_array: &ListArray, + non_zero_values_array: &ListArray, + offsets: &Offsets, +) -> DaftResult> { + let item: Box = with_match_numeric_daft_types!(inner_dtype, |$T| { + let mut values = vec![0 as <$T as DaftNumericType>::Native; n_values]; + let validity = non_zero_values_array.validity(); + for i in 0..non_zero_values_array.len() { + let is_valid = validity.map_or(true, |v| v.get_bit(i)); + if !is_valid { + continue; + } + let index_series: Series = non_zero_indices_array.get(i).unwrap(); + let index_array = index_series.u64().unwrap().as_arrow(); + let values_series: Series = non_zero_values_array.get(i).unwrap(); + let values_array = values_series.downcast::<<$T as DaftDataType>::ArrayType>() + .unwrap() + .as_arrow(); + for (idx, val) in index_array.into_iter().zip(values_array.into_iter()) { + let list_start_offset = offsets.start_end(i).0; + values[list_start_offset + *idx.unwrap() as usize] = *val.unwrap(); + } + } + Box::new(arrow2::array::PrimitiveArray::from_vec(values)) + }); + Ok(item) +} + +impl SparseTensorArray { + pub fn cast(&self, dtype: &DataType) -> DaftResult { + match dtype { + DataType::Tensor(inner_dtype) => { + let non_zero_values_array = self.values_array(); + let non_zero_indices_array = self.indices_array(); + let shape_array = self.shape_array(); + let sizes_vec: Vec = shape_array + .into_iter() + .map(|shape| { + shape.map_or(0, |shape| { + let shape = shape.u64().unwrap().as_arrow(); + shape.values().clone().into_iter().product::() as usize + }) + }) + .collect(); + let offsets: Offsets = Offsets::try_from_iter(sizes_vec.iter().cloned())?; + let n_values = sizes_vec.iter().sum::(); + let validity = non_zero_indices_array.validity(); + let item = cast_sparse_to_dense_for_inner_dtype( + inner_dtype, + n_values, + non_zero_indices_array, + non_zero_values_array, + &offsets, + )?; + let list_arr = ListArray::new( + Field::new( + "data", + DataType::List(Box::new(inner_dtype.as_ref().clone())), + ), + Series::try_from(("item", item))?, + offsets.into(), + validity.cloned(), + ) + .into_series(); + let physical_type = dtype.to_physical(); + let struct_array = StructArray::new( + Field::new(self.name(), physical_type), + vec![list_arr, shape_array.clone().into_series()], + validity.cloned(), + ); + Ok( + TensorArray::new(Field::new(self.name(), dtype.clone()), struct_array) + .into_series(), + ) + } + DataType::FixedShapeSparseTensor(inner_dtype, shape) => { + let sa = self.shape_array(); + let va = self.values_array(); + let ia = self.indices_array(); + if !(0..self.len()).map(|i| sa.get(i)).all(|s| { + s.map_or(true, |s| { + s.u64() + .unwrap() + .as_arrow() + .iter() + .eq(shape.iter().map(Some)) + }) + }) { + return Err(DaftError::TypeError(format!( + "Can not cast SparseTensor array to FixedShapeSparseTensor array with type {:?}: Tensor array has shapes different than 
{:?};", + dtype, + shape, + ))); + }; + let values_array = + va.cast(&DataType::List(Box::new(inner_dtype.as_ref().clone())))?; + let struct_array = StructArray::new( + Field::new(self.name(), dtype.to_physical()), + vec![values_array, ia.clone().into_series()], + va.validity().cloned(), + ); + let sparse_tensor_array = FixedShapeSparseTensorArray::new( + Field::new(self.name(), dtype.clone()), + struct_array, + ); + Ok(sparse_tensor_array.into_series()) + } + _ => self.physical.cast(dtype), + } + } +} + +impl FixedShapeSparseTensorArray { + pub fn cast(&self, dtype: &DataType) -> DaftResult { + match (dtype, self.data_type()) { + ( + DataType::SparseTensor(_), + DataType::FixedShapeSparseTensor(inner_dtype, tensor_shape), + ) => { + let ndim = tensor_shape.len(); + let shapes = tensor_shape + .iter() + .cycle() + .copied() + .take(ndim * self.len()) + .collect(); + let shape_offsets = (0..=ndim * self.len()) + .step_by(ndim) + .map(|v| v as i64) + .collect::>(); + + let validity = self.physical.validity(); + + let va = self.values_array(); + let ia = self.indices_array(); + + let values_arr = + va.cast(&DataType::List(Box::new(inner_dtype.as_ref().clone())))?; + + // List -> Struct + let shape_offsets = arrow2::offset::OffsetsBuffer::try_from(shape_offsets)?; + let shapes_array = ListArray::new( + Field::new("shape", DataType::List(Box::new(DataType::UInt64))), + Series::try_from(( + "shape", + Box::new(arrow2::array::PrimitiveArray::from_vec(shapes)) + as Box, + ))?, + shape_offsets, + validity.cloned(), + ); + let physical_type = dtype.to_physical(); + let struct_array = StructArray::new( + Field::new(self.name(), physical_type), + vec![ + values_arr, + ia.clone().into_series(), + shapes_array.into_series(), + ], + validity.cloned(), + ); + Ok( + SparseTensorArray::new(Field::new(self.name(), dtype.clone()), struct_array) + .into_series(), + ) + } + ( + DataType::FixedShapeTensor(_, target_tensor_shape), + DataType::FixedShapeSparseTensor(inner_dtype, tensor_shape), + ) => { + let non_zero_values_array = self.values_array(); + let non_zero_indices_array = self.indices_array(); + let size = tensor_shape.iter().product::() as usize; + let target_size = target_tensor_shape.iter().product::() as usize; + if size != target_size { + return Err(DaftError::TypeError(format!( + "Can not cast FixedShapeSparseTensor array to FixedShapeTensor array with type {:?}: FixedShapeSparseTensor array has shapes different than {:?};", + dtype, + tensor_shape, + ))); + } + let n_values = size * non_zero_values_array.len(); + let item = cast_sparse_to_dense_for_inner_dtype( + inner_dtype, + n_values, + non_zero_indices_array, + non_zero_values_array, + &Offsets::try_from_iter(repeat(target_size).take(self.len()))?, + )?; + let validity = non_zero_values_array.validity(); + let physical = FixedSizeListArray::new( + Field::new( + self.name(), + DataType::FixedSizeList(Box::new(inner_dtype.as_ref().clone()), size), + ), + Series::try_from(("item", item))?, + validity.cloned(), + ); + let fixed_shape_tensor_array = + FixedShapeTensorArray::new(Field::new(self.name(), dtype.clone()), physical); + Ok(fixed_shape_tensor_array.into_series()) + } + (_, _) => self.physical.cast(dtype), + } + } +} + impl FixedShapeTensorArray { pub fn cast(&self, dtype: &DataType) -> DaftResult { match (dtype, self.data_type()) { @@ -1586,7 +1873,72 @@ impl FixedShapeTensorArray { .into_series(), ) } - // NOTE(Clark): Casting to FixedShapeImage is supported by the physical array cast. 
+ ( + DataType::FixedShapeSparseTensor(_, _), + DataType::FixedShapeTensor(inner_dtype, tensor_shape), + ) => { + let physical_arr = &self.physical; + let validity = self.physical.validity(); + let zero_series = Int64Array::from(("item", [0].as_slice())).into_series(); + let mut non_zero_values = Vec::new(); + let mut non_zero_indices = Vec::new(); + let mut offsets = Vec::::new(); + for (i, data_series) in physical_arr.into_iter().enumerate() { + let is_valid = validity.map_or(true, |v| v.get_bit(i)); + if !is_valid { + // Handle invalid row by populating dummy data. + offsets.push(1); + non_zero_values.push(Series::empty("dummy", inner_dtype.as_ref())); + non_zero_indices.push(Series::empty("dummy", &DataType::UInt64)); + continue; + } + let data_series = data_series.unwrap(); + assert!(data_series.len() == tensor_shape.iter().product::() as usize); + let non_zero_mask = data_series.not_equal(&zero_series)?; + let data = data_series.filter(&non_zero_mask)?; + let indices = UInt64Array::arange("item", 0, data_series.len() as i64, 1)? + .into_series() + .filter(&non_zero_mask)?; + offsets.push(data.len()); + non_zero_values.push(data); + non_zero_indices.push(indices); + } + let offsets: Offsets = + Offsets::try_from_iter(non_zero_values.iter().map(|s| s.len()))?; + let non_zero_values_series = + Series::concat(&non_zero_values.iter().collect::>())?; + let non_zero_indices_series = + Series::concat(&non_zero_indices.iter().collect::>())?; + let offsets_cloned = offsets.clone(); + let data_list_arr = ListArray::new( + Field::new( + "values", + DataType::List(Box::new(non_zero_values_series.data_type().clone())), + ), + non_zero_values_series, + offsets.into(), + validity.cloned(), + ); + let indices_list_arr = ListArray::new( + Field::new( + "indices", + DataType::List(Box::new(non_zero_indices_series.data_type().clone())), + ), + non_zero_indices_series, + offsets_cloned.into(), + validity.cloned(), + ); + let sparse_struct_array = StructArray::new( + Field::new(self.name(), dtype.to_physical()), + vec![data_list_arr.into_series(), indices_list_arr.into_series()], + validity.cloned(), + ); + Ok(FixedShapeSparseTensorArray::new( + Field::new(sparse_struct_array.name(), dtype.clone()), + sparse_struct_array, + ) + .into_series()) + } (_, _) => self.physical.cast(dtype), } } @@ -1801,6 +2153,24 @@ impl StructArray { .into_series(), ) } + (DataType::Struct(..), DataType::SparseTensor(..)) => { + let casted_struct_array = + self.cast(&dtype.to_physical())?.struct_().unwrap().clone(); + Ok(SparseTensorArray::new( + Field::new(self.name(), dtype.clone()), + casted_struct_array, + ) + .into_series()) + } + (DataType::Struct(..), DataType::FixedShapeSparseTensor(..)) => { + let casted_struct_array = + self.cast(&dtype.to_physical())?.struct_().unwrap().clone(); + Ok(FixedShapeSparseTensorArray::new( + Field::new(self.name(), dtype.clone()), + casted_struct_array, + ) + .into_series()) + } (DataType::Struct(..), DataType::Image(..)) => { let casted_struct_array = self.cast(&dtype.to_physical())?.struct_().unwrap().clone(); diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index 97596e9dbc..d3a940f376 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -47,6 +47,7 @@ mod shift; mod sign; mod sketch_percentile; mod sort; +pub(crate) mod sparse_tensor; mod sqrt; mod struct_; mod sum; diff --git a/src/daft-core/src/array/ops/repr.rs b/src/daft-core/src/array/ops/repr.rs index b7467d70e2..8d60f697c7 100644 --- 
a/src/daft-core/src/array/ops/repr.rs +++ b/src/daft-core/src/array/ops/repr.rs @@ -6,7 +6,8 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, + FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, MapArray, + SparseTensorArray, TensorArray, TimeArray, TimestampArray, }, BinaryArray, BooleanArray, DaftNumericType, DataType, ExtensionArray, FixedSizeBinaryArray, NullArray, UInt64Array, Utf8Array, @@ -269,6 +270,47 @@ impl ImageArray { } } +impl SparseTensorArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + // Shapes are always valid, use values array validity + let is_valid = self + .values_array() + .validity() + .map_or(true, |v| v.get_bit(idx)); + let shape_element = if is_valid { + self.shape_array().get(idx) + } else { + None + }; + match shape_element { + Some(shape) => Ok(format!( + "", + shape + .downcast::() + .unwrap() + .into_iter() + .map(|dim| match dim { + None => "None".to_string(), + Some(dim) => dim.to_string(), + }) + .collect::>() + .join(", ") + )), + None => Ok("None".to_string()), + } + } +} + +impl FixedShapeSparseTensorArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + if self.physical.is_valid(idx) { + Ok("".to_string()) + } else { + Ok("None".to_string()) + } + } +} + impl FixedShapeImageArray { pub fn str_value(&self, idx: usize) -> DaftResult { if self.physical.is_valid(idx) { @@ -432,3 +474,21 @@ impl TensorArray { .replace('\n', "
") } } + +impl SparseTensorArray { + pub fn html_value(&self, idx: usize) -> String { + let str_value = self.str_value(idx).unwrap(); + html_escape::encode_text(&str_value) + .into_owned() + .replace('\n', "
") + } +} + +impl FixedShapeSparseTensorArray { + pub fn html_value(&self, idx: usize) -> String { + let str_value = self.str_value(idx).unwrap(); + html_escape::encode_text(&str_value) + .into_owned() + .replace('\n', "
") + } +} diff --git a/src/daft-core/src/array/ops/sort.rs b/src/daft-core/src/array/ops/sort.rs index ed7e818578..d28c920419 100644 --- a/src/daft-core/src/array/ops/sort.rs +++ b/src/daft-core/src/array/ops/sort.rs @@ -12,7 +12,8 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, + FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, MapArray, + SparseTensorArray, TensorArray, TimeArray, TimestampArray, }, BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, FixedSizeBinaryArray, Float32Array, Float64Array, NullArray, Utf8Array, @@ -678,6 +679,18 @@ impl TensorArray { } } +impl SparseTensorArray { + pub fn sort(&self, _descending: bool) -> DaftResult { + todo!("impl sort for SparseTensorArray") + } +} + +impl FixedShapeSparseTensorArray { + pub fn sort(&self, _descending: bool) -> DaftResult { + todo!("impl sort for FixedShapeSparseTensorArray") + } +} + impl FixedShapeTensorArray { pub fn sort(&self, _descending: bool) -> DaftResult { todo!("impl sort for FixedShapeTensorArray") diff --git a/src/daft-core/src/array/ops/sparse_tensor.rs b/src/daft-core/src/array/ops/sparse_tensor.rs new file mode 100644 index 0000000000..696a5996b8 --- /dev/null +++ b/src/daft-core/src/array/ops/sparse_tensor.rs @@ -0,0 +1,111 @@ +use crate::{ + array::ListArray, + datatypes::logical::{FixedShapeSparseTensorArray, SparseTensorArray}, +}; + +impl SparseTensorArray { + pub fn values_array(&self) -> &ListArray { + const VALUES_IDX: usize = 0; + let array = self.physical.children.get(VALUES_IDX).unwrap(); + array.list().unwrap() + } + + pub fn indices_array(&self) -> &ListArray { + const INDICES_IDX: usize = 1; + let array = self.physical.children.get(INDICES_IDX).unwrap(); + array.list().unwrap() + } + + pub fn shape_array(&self) -> &ListArray { + const SHAPE_IDX: usize = 2; + let array = self.physical.children.get(SHAPE_IDX).unwrap(); + array.list().unwrap() + } +} + +impl FixedShapeSparseTensorArray { + pub fn values_array(&self) -> &ListArray { + const VALUES_IDX: usize = 0; + let array = self.physical.children.get(VALUES_IDX).unwrap(); + array.list().unwrap() + } + + pub fn indices_array(&self) -> &ListArray { + const INDICES_IDX: usize = 1; + let array = self.physical.children.get(INDICES_IDX).unwrap(); + array.list().unwrap() + } +} + +#[cfg(test)] +mod tests { + use std::vec; + + use common_error::DaftResult; + + use crate::{array::prelude::*, datatypes::prelude::*, series::IntoSeries}; + + #[test] + fn test_sparse_tensor_to_fixed_shape_sparse_tensor_roundtrip() -> DaftResult<()> { + let raw_validity = vec![true, false, true]; + let validity = arrow2::bitmap::Bitmap::from(raw_validity.as_slice()); + + let values_array = ListArray::new( + Field::new("values", DataType::List(Box::new(DataType::Int64))), + Int64Array::from(( + "item", + Box::new(arrow2::array::Int64Array::from_iter( + [Some(1), Some(2), Some(0), Some(3)].iter(), + )), + )) + .into_series(), + arrow2::offset::OffsetsBuffer::::try_from(vec![0, 2, 3, 4])?, + Some(validity.clone()), + ) + .into_series(); + let indices_array = ListArray::new( + Field::new("indices", DataType::List(Box::new(DataType::UInt64))), + UInt64Array::from(( + "item", + Box::new(arrow2::array::UInt64Array::from_iter( + [Some(1), Some(2), Some(0), Some(2)].iter(), + )), + )) + .into_series(), + arrow2::offset::OffsetsBuffer::::try_from(vec![0, 2, 3, 4])?, + Some(validity.clone()), 
+ ) + .into_series(); + + let shapes_array = ListArray::new( + Field::new("shape", DataType::List(Box::new(DataType::UInt64))), + UInt64Array::from(( + "item", + Box::new(arrow2::array::UInt64Array::from_iter( + [Some(3), Some(3), Some(3)].iter(), + )), + )) + .into_series(), + arrow2::offset::OffsetsBuffer::::try_from(vec![0, 1, 2, 3])?, + Some(validity.clone()), + ) + .into_series(); + let dtype = DataType::SparseTensor(Box::new(DataType::Int64)); + let struct_array = StructArray::new( + Field::new("tensor", dtype.to_physical()), + vec![values_array, indices_array, shapes_array], + Some(validity.clone()), + ); + let sparse_tensor_array = + SparseTensorArray::new(Field::new(struct_array.name(), dtype.clone()), struct_array); + let fixed_shape_sparse_tensor_dtype = + DataType::FixedShapeSparseTensor(Box::new(DataType::Int64), vec![3]); + let fixed_shape_sparse_tensor_array = + sparse_tensor_array.cast(&fixed_shape_sparse_tensor_dtype)?; + let roundtrip_tensor = fixed_shape_sparse_tensor_array.cast(&dtype)?; + assert!(roundtrip_tensor + .to_arrow() + .eq(&sparse_tensor_array.to_arrow())); + Ok(()) + } +} diff --git a/src/daft-core/src/array/ops/take.rs b/src/daft-core/src/array/ops/take.rs index c45ad3ea56..301a311594 100644 --- a/src/daft-core/src/array/ops/take.rs +++ b/src/daft-core/src/array/ops/take.rs @@ -69,6 +69,8 @@ impl_logicalarray_take!(EmbeddingArray); impl_logicalarray_take!(ImageArray); impl_logicalarray_take!(FixedShapeImageArray); impl_logicalarray_take!(TensorArray); +impl_logicalarray_take!(SparseTensorArray); +impl_logicalarray_take!(FixedShapeSparseTensorArray); impl_logicalarray_take!(FixedShapeTensorArray); impl_logicalarray_take!(MapArray); diff --git a/src/daft-core/src/array/ops/tensor.rs b/src/daft-core/src/array/ops/tensor.rs index 36dde9dbfd..c1cd0f13ec 100644 --- a/src/daft-core/src/array/ops/tensor.rs +++ b/src/daft-core/src/array/ops/tensor.rs @@ -13,3 +13,86 @@ impl TensorArray { array.list().unwrap() } } + +#[cfg(test)] +mod tests { + use std::vec; + + use common_error::DaftResult; + + use crate::{array::prelude::*, datatypes::prelude::*, series::IntoSeries}; + + #[test] + fn test_tensor_to_sparse_roundtrip() -> DaftResult<()> { + let raw_validity = vec![true, false, true]; + let validity = arrow2::bitmap::Bitmap::from(raw_validity.as_slice()); + + let list_array = ListArray::new( + Field::new("data", DataType::List(Box::new(DataType::Int64))), + Int64Array::from(( + "item", + Box::new(arrow2::array::Int64Array::from_iter( + [ + Some(0), + Some(1), + Some(2), + Some(100), + Some(101), + Some(102), + Some(0), + Some(0), + Some(3), + ] + .iter(), + )), + )) + .into_series(), + arrow2::offset::OffsetsBuffer::::try_from(vec![0, 3, 6, 9])?, + Some(validity.clone()), + ) + .into_series(); + let shapes_array = ListArray::new( + Field::new("shape", DataType::List(Box::new(DataType::UInt64))), + UInt64Array::from(( + "item", + Box::new(arrow2::array::UInt64Array::from_iter( + [Some(3), Some(3), Some(3)].iter(), + )), + )) + .into_series(), + arrow2::offset::OffsetsBuffer::::try_from(vec![0, 1, 2, 3])?, + Some(validity.clone()), + ) + .into_series(); + let dtype = DataType::Tensor(Box::new(DataType::Int64)); + let struct_array = StructArray::new( + Field::new("tensor", dtype.to_physical()), + vec![list_array, shapes_array], + Some(validity.clone()), + ); + let tensor_array = + TensorArray::new(Field::new(struct_array.name(), dtype.clone()), struct_array); + let sparse_tensor_dtype = DataType::SparseTensor(Box::new(DataType::Int64)); + let sparse_tensor_array = 
tensor_array.cast(&sparse_tensor_dtype)?; + let roundtrip_tensor = sparse_tensor_array.cast(&dtype)?; + assert!(tensor_array.to_arrow().eq(&roundtrip_tensor.to_arrow())); + Ok(()) + } + + #[test] + fn test_fixed_shape_tensor_to_fixed_shape_sparse_roundtrip() -> DaftResult<()> { + let raw_validity = vec![true, false, true]; + let validity = arrow2::bitmap::Bitmap::from(raw_validity.as_slice()); + let field = Field::new("foo", DataType::FixedSizeList(Box::new(DataType::Int64), 3)); + let flat_child = Int64Array::from(("foo", (0..9).collect::>())); + let arr = FixedSizeListArray::new(field, flat_child.into_series(), Some(validity.clone())); + let dtype = DataType::FixedShapeTensor(Box::new(DataType::Int64), vec![3]); + let tensor_array = FixedShapeTensorArray::new(Field::new("data", dtype.clone()), arr); + let sparse_tensor_dtype = + DataType::FixedShapeSparseTensor(Box::new(DataType::Int64), vec![3]); + let sparse_tensor_array = tensor_array.cast(&sparse_tensor_dtype)?; + let roundtrip_tensor = sparse_tensor_array.cast(&dtype)?; + assert!(tensor_array.to_arrow().eq(&roundtrip_tensor.to_arrow())); + Ok(()) + } +} diff --git a/src/daft-core/src/array/prelude.rs b/src/daft-core/src/array/prelude.rs index fdcc8fda72..7f3bb2ff18 100644 --- a/src/daft-core/src/array/prelude.rs +++ b/src/daft-core/src/array/prelude.rs @@ -2,8 +2,8 @@ pub use super::{DataArray, FixedSizeListArray, ListArray, StructArray}; // Import logical array types pub use crate::datatypes::logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, LogicalArray, MapArray, TensorArray, TimeArray, - TimestampArray, + FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, LogicalArray, MapArray, + SparseTensorArray, TensorArray, TimeArray, TimestampArray, }; pub use crate::{ array::ops::{ diff --git a/src/daft-core/src/datatypes/logical.rs b/src/daft-core/src/datatypes/logical.rs index d61963ac69..df48b30524 100644 --- a/src/daft-core/src/datatypes/logical.rs +++ b/src/daft-core/src/datatypes/logical.rs @@ -4,8 +4,8 @@ use common_error::DaftResult; use super::{ DaftArrayType, DaftDataType, DataArray, DataType, Decimal128Type, DurationType, EmbeddingType, - FixedShapeImageType, FixedShapeTensorType, FixedSizeListArray, ImageType, MapType, TensorType, - TimeType, TimestampType, + FixedShapeImageType, FixedShapeSparseTensorType, FixedShapeTensorType, FixedSizeListArray, + ImageType, MapType, SparseTensorType, TensorType, TimeType, TimestampType, }; use crate::{ array::{ListArray, StructArray}, @@ -172,6 +172,8 @@ pub type TimestampArray = LogicalArray; pub type TensorArray = LogicalArray; pub type EmbeddingArray = LogicalArray; pub type FixedShapeTensorArray = LogicalArray; +pub type SparseTensorArray = LogicalArray; +pub type FixedShapeSparseTensorArray = LogicalArray; pub type FixedShapeImageArray = LogicalArray; pub type MapArray = LogicalArray; diff --git a/src/daft-core/src/datatypes/matching.rs b/src/daft-core/src/datatypes/matching.rs index fb85062715..b8b8e1660f 100644 --- a/src/daft-core/src/datatypes/matching.rs +++ b/src/daft-core/src/datatypes/matching.rs @@ -40,6 +40,8 @@ macro_rules! with_match_daft_types {( FixedShapeImage(..) => __with_ty__! { FixedShapeImageType }, Tensor(..) => __with_ty__! { TensorType }, FixedShapeTensor(..) => __with_ty__! { FixedShapeTensorType }, + SparseTensor(..) => __with_ty__! { SparseTensorType }, + FixedShapeSparseTensor(..) => __with_ty__! { FixedShapeSparseTensorType }, Decimal128(..) => __with_ty__! 
{ Decimal128Type }, // Float16 => unimplemented!("Array for Float16 DataType not implemented"), Unknown => unimplemented!("Array for Unknown DataType not implemented"), diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 171f8fc957..174098ada9 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -203,6 +203,8 @@ impl_daft_logical_data_array_datatype!(TimeType, Unknown, Int64Type); impl_daft_logical_data_array_datatype!(DurationType, Unknown, Int64Type); impl_daft_logical_data_array_datatype!(ImageType, Unknown, StructType); impl_daft_logical_data_array_datatype!(TensorType, Unknown, StructType); +impl_daft_logical_data_array_datatype!(SparseTensorType, Unknown, StructType); +impl_daft_logical_data_array_datatype!(FixedShapeSparseTensorType, Unknown, StructType); impl_daft_logical_fixed_size_list_datatype!(EmbeddingType, Unknown); impl_daft_logical_fixed_size_list_datatype!(FixedShapeImageType, Unknown); impl_daft_logical_fixed_size_list_datatype!(FixedShapeTensorType, Unknown); diff --git a/src/daft-core/src/datatypes/prelude.rs b/src/daft-core/src/datatypes/prelude.rs index 6574bbe7d9..9b8fd4ce30 100644 --- a/src/daft-core/src/datatypes/prelude.rs +++ b/src/daft-core/src/datatypes/prelude.rs @@ -24,6 +24,6 @@ pub use super::{ }; pub use crate::datatypes::{ logical::DaftImageryType, DateType, Decimal128Type, DurationType, EmbeddingType, - FixedShapeImageType, FixedShapeTensorType, ImageType, MapType, TensorType, TimeType, - TimestampType, + FixedShapeImageType, FixedShapeSparseTensorType, FixedShapeTensorType, ImageType, MapType, + SparseTensorType, TensorType, TimeType, TimestampType, }; diff --git a/src/daft-core/src/series/array_impl/binary_ops.rs b/src/daft-core/src/series/array_impl/binary_ops.rs index 15b2d9e273..16e1a80b8d 100644 --- a/src/daft-core/src/series/array_impl/binary_ops.rs +++ b/src/daft-core/src/series/array_impl/binary_ops.rs @@ -11,7 +11,8 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, + FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, MapArray, + SparseTensorArray, TensorArray, TimeArray, TimestampArray, }, BinaryArray, BooleanArray, DataType, ExtensionArray, Field, FixedSizeBinaryArray, Float32Array, Float64Array, InferDataType, Int128Array, Int16Array, Int32Array, Int64Array, @@ -399,3 +400,5 @@ impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} +impl SeriesBinaryOps for ArrayWrapper {} +impl SeriesBinaryOps for ArrayWrapper {} diff --git a/src/daft-core/src/series/array_impl/logical_array.rs b/src/daft-core/src/series/array_impl/logical_array.rs index aa978cf693..bec7f069f7 100644 --- a/src/daft-core/src/series/array_impl/logical_array.rs +++ b/src/daft-core/src/series/array_impl/logical_array.rs @@ -226,4 +226,6 @@ impl_series_like_for_logical_array!(FixedShapeImageArray); impl_series_like_for_logical_array!(TensorArray); impl_series_like_for_logical_array!(EmbeddingArray); impl_series_like_for_logical_array!(FixedShapeTensorArray); +impl_series_like_for_logical_array!(SparseTensorArray); +impl_series_like_for_logical_array!(FixedShapeSparseTensorArray); impl_series_like_for_logical_array!(MapArray); diff --git a/src/daft-core/src/series/ops/downcast.rs b/src/daft-core/src/series/ops/downcast.rs index 
3dd4bfe530..146f1afa8c 100644 --- a/src/daft-core/src/series/ops/downcast.rs +++ b/src/daft-core/src/series/ops/downcast.rs @@ -1,5 +1,8 @@ use common_error::DaftResult; -use logical::{EmbeddingArray, FixedShapeTensorArray, TensorArray}; +use logical::{ + EmbeddingArray, FixedShapeSparseTensorArray, FixedShapeTensorArray, SparseTensorArray, + TensorArray, +}; use self::logical::{DurationArray, ImageArray, MapArray}; use crate::{ @@ -153,4 +156,12 @@ impl Series { pub fn fixed_shape_tensor(&self) -> DaftResult<&FixedShapeTensorArray> { self.downcast() } + + pub fn sparse_tensor(&self) -> DaftResult<&SparseTensorArray> { + self.downcast() + } + + pub fn fixed_shape_sparse_tensor(&self) -> DaftResult<&FixedShapeSparseTensorArray> { + self.downcast() + } } diff --git a/src/daft-core/src/series/serdes.rs b/src/daft-core/src/series/serdes.rs index ed1bb21927..bf7e42a1e0 100644 --- a/src/daft-core/src/series/serdes.rs +++ b/src/daft-core/src/series/serdes.rs @@ -11,7 +11,8 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, + FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, MapArray, + SparseTensorArray, TensorArray, TimeArray, TimestampArray, }, *, }, @@ -297,6 +298,24 @@ impl<'d> serde::Deserialize<'d> for Series { ) .into_series()) } + DataType::SparseTensor(..) => { + type PType = <::PhysicalType as DaftDataType>::ArrayType; + let physical = map.next_value::()?; + Ok(SparseTensorArray::new( + field, + physical.downcast::().unwrap().clone(), + ) + .into_series()) + } + DataType::FixedShapeSparseTensor(..) => { + type PType = <::PhysicalType as DaftDataType>::ArrayType; + let physical = map.next_value::()?; + Ok(FixedShapeSparseTensorArray::new( + field, + physical.downcast::().unwrap().clone(), + ) + .into_series()) + } DataType::Tensor(..) => { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index 449302c185..48d9414aab 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -135,6 +135,14 @@ pub enum DataType { #[display("FixedShapeTensor[{_0}; {_1:?}]")] FixedShapeTensor(Box, Vec), + /// A logical type for sparse tensors with variable shapes. + #[display("SparseTensor({_0})")] + SparseTensor(Box), + + /// A logical type for sparse tensors with the same shape. + #[display("FixedShapeSparseTensor[{_0}; {_1:?}]")] + FixedShapeSparseTensor(Box, Vec), + #[cfg(feature = "python")] Python, @@ -249,7 +257,9 @@ impl DataType { | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) - | DataType::FixedShapeTensor(..) => { + | DataType::FixedShapeTensor(..) + | DataType::SparseTensor(..) + | DataType::FixedShapeSparseTensor(..) 
=> { let physical = Box::new(self.to_physical()); let logical_extension = DataType::Extension( DAFT_SUPER_EXTENSION_NAME.into(), @@ -302,6 +312,15 @@ impl DataType { Box::new(*dtype.clone()), usize::try_from(shape.iter().product::()).unwrap(), ), + SparseTensor(dtype) => Struct(vec![ + Field::new("values", List(Box::new(*dtype.clone()))), + Field::new("indices", List(Box::new(DataType::UInt64))), + Field::new("shape", List(Box::new(DataType::UInt64))), + ]), + FixedShapeSparseTensor(dtype, _) => Struct(vec![ + Field::new("values", List(Box::new(*dtype.clone()))), + Field::new("indices", List(Box::new(DataType::UInt64))), + ]), _ => { assert!(self.is_physical()); self.clone() @@ -316,6 +335,8 @@ impl DataType { | DataType::List(dtype) | DataType::FixedSizeList(dtype, _) | DataType::FixedShapeTensor(dtype, _) + | DataType::SparseTensor(dtype) + | DataType::FixedShapeSparseTensor(dtype, _) | DataType::Tensor(dtype) => Some(dtype), _ => None, } @@ -351,7 +372,8 @@ impl DataType { match self { DataType::FixedSizeList(dtype, ..) | DataType::Embedding(dtype, ..) - | DataType::FixedShapeTensor(dtype, ..) => dtype.is_numeric(), + | DataType::FixedShapeTensor(dtype, ..) + | DataType::FixedShapeSparseTensor(dtype, ..) => dtype.is_numeric(), _ => false, } } @@ -404,11 +426,21 @@ impl DataType { matches!(self, DataType::Tensor(..)) } + #[inline] + pub fn is_sparse_tensor(&self) -> bool { + matches!(self, DataType::SparseTensor(..)) + } + #[inline] pub fn is_fixed_shape_tensor(&self) -> bool { matches!(self, DataType::FixedShapeTensor(..)) } + #[inline] + pub fn is_fixed_shape_sparse_tensor(&self) -> bool { + matches!(self, DataType::FixedShapeSparseTensor(..)) + } + #[inline] pub fn is_image(&self) -> bool { matches!(self, DataType::Image(..)) @@ -542,6 +574,8 @@ impl DataType { | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::FixedShapeTensor(..) + | DataType::SparseTensor(..) + | DataType::FixedShapeSparseTensor(..) | DataType::Map(..) ) } diff --git a/src/daft-schema/src/python/datatype.rs b/src/daft-schema/src/python/datatype.rs index 39f65d314f..9128eceec2 100644 --- a/src/daft-schema/src/python/datatype.rs +++ b/src/daft-schema/src/python/datatype.rs @@ -293,6 +293,21 @@ impl PyDataType { } } + #[staticmethod] + pub fn sparse_tensor(dtype: Self, shape: Option>) -> PyResult { + if !dtype.dtype.is_numeric() { + return Err(PyValueError::new_err(format!( + "The data type for a tensor column must be numeric, but got: {}", + dtype.dtype + ))); + } + let dtype = Box::new(dtype.dtype); + match shape { + Some(shape) => Ok(DataType::FixedShapeSparseTensor(dtype, shape).into()), + None => Ok(DataType::SparseTensor(dtype).into()), + } + } + #[staticmethod] pub fn python() -> PyResult { Ok(DataType::Python.into()) @@ -347,6 +362,14 @@ impl PyDataType { Ok(self.dtype.is_fixed_shape_tensor()) } + pub fn is_sparse_tensor(&self) -> PyResult { + Ok(self.dtype.is_sparse_tensor()) + } + + pub fn is_fixed_shape_sparse_tensor(&self) -> PyResult { + Ok(self.dtype.is_fixed_shape_sparse_tensor()) + } + pub fn is_map(&self) -> PyResult { Ok(self.dtype.is_map()) } diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs index fb78fe3feb..e8dc82f2f8 100644 --- a/src/daft-stats/src/column_stats/mod.rs +++ b/src/daft-stats/src/column_stats/mod.rs @@ -71,7 +71,7 @@ impl ColumnRangeStatistics { // UNSUPPORTED TYPES: // Types that don't support comparisons and can't be used as ColumnRangeStatistics - DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) 
| DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map(..) | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false, + DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map(..) | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false, #[cfg(feature = "python")] DataType::Python => false, } diff --git a/src/daft-table/src/repr_html.rs b/src/daft-table/src/repr_html.rs index 4f81ec11cb..79ecaf063a 100644 --- a/src/daft-table/src/repr_html.rs +++ b/src/daft-table/src/repr_html.rs @@ -126,6 +126,14 @@ pub fn html_value(s: &Series, idx: usize) -> String { let arr = s.fixed_shape_tensor().unwrap(); arr.html_value(idx) } + DataType::SparseTensor(_) => { + let arr = s.sparse_tensor().unwrap(); + arr.html_value(idx) + } + DataType::FixedShapeSparseTensor(_, _) => { + let arr = s.fixed_shape_sparse_tensor().unwrap(); + arr.html_value(idx) + } #[cfg(feature = "python")] DataType::Python => { let arr = s.python().unwrap(); diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py index af20021470..48e4cd43d9 100644 --- a/tests/benchmarks/conftest.py +++ b/tests/benchmarks/conftest.py @@ -4,28 +4,9 @@ import tempfile from collections import defaultdict -import _pytest import memray import pytest - -# Monkeypatch to use dash delimiters when showing parameter lists. -# https://github.com/pytest-dev/pytest/blob/31d0b51039fc295dfb14bfc5d2baddebe11bb746/src/_pytest/python.py#L1190 -# Related: https://github.com/pytest-dev/pytest/issues/3617 -# This allows us to perform pytest selection via the `-k` CLI flag -def id(self): - return "-".join(self._idlist) - - -setattr(_pytest.python.CallSpec2, "id", property(id)) - - -def pytest_make_parametrize_id(config, val, argname): - if isinstance(val, int): - val = f"{val:_}" - return f"{argname}:{val}" - - memray_stats = defaultdict(dict) diff --git a/tests/io/test_parquet_roundtrip.py b/tests/io/test_parquet_roundtrip.py index 8448343abc..6904805831 100644 --- a/tests/io/test_parquet_roundtrip.py +++ b/tests/io/test_parquet_roundtrip.py @@ -123,6 +123,24 @@ def test_roundtrip_tensor_types(tmp_path): assert before.to_arrow() == after.to_arrow() +@pytest.mark.parametrize("fixed_shape", [True, False]) +def test_roundtrip_sparse_tensor_types(tmp_path, fixed_shape): + if fixed_shape: + expected_dtype = DataType.sparse_tensor(DataType.int64(), (2, 2)) + data = [np.array([[0, 0], [1, 0]]), None, np.array([[0, 0], [0, 0]]), np.array([[0, 1], [0, 0]])] + else: + expected_dtype = DataType.sparse_tensor(DataType.int64()) + data = [np.array([[0, 0], [1, 0]]), None, np.array([[0, 0]]), np.array([[0, 1, 0], [0, 0, 1]])] + before = daft.from_pydict({"foo": Series.from_pylist(data)}) + before = before.with_column("foo", before["foo"].cast(expected_dtype)) + before = before.concat(before) + before.write_parquet(str(tmp_path)) + after = daft.read_parquet(str(tmp_path)) + assert before.schema()["foo"].dtype == expected_dtype + assert after.schema()["foo"].dtype == expected_dtype + assert before.to_arrow() == after.to_arrow() + + # TODO: reading/writing: # 1. Embedding type # 2. 
Image type diff --git a/tests/series/test_sparse_tensor.py b/tests/series/test_sparse_tensor.py new file mode 100644 index 0000000000..a690d0945f --- /dev/null +++ b/tests/series/test_sparse_tensor.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import numpy as np +import pyarrow as pa +import pytest + +from daft.datatype import DataType +from daft.series import Series +from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES +from tests.utils import ANSI_ESCAPE + +ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) + + +@pytest.mark.parametrize("dtype", ARROW_INT_TYPES + ARROW_FLOAT_TYPES) +def test_sparse_tensor_roundtrip(dtype): + np_dtype = dtype.to_pandas_dtype() + data = [ + np.array([[0, 1, 0, 0], [0, 0, 0, 0]], dtype=np_dtype), + None, + np.array([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=np_dtype), + np.array([[0, 0, 0, 0], [0, 0, 4, 0]], dtype=np_dtype), + ] + s = Series.from_pylist(data, pyobj="allow") + + tensor_dtype = DataType.tensor(DataType.from_arrow_type(dtype)) + + t = s.cast(tensor_dtype) + assert t.datatype() == tensor_dtype + + # Test sparse tensor roundtrip. + sparse_tensor_dtype = DataType.sparse_tensor(dtype=DataType.from_arrow_type(dtype)) + sparse_tensor_series = t.cast(sparse_tensor_dtype) + assert sparse_tensor_series.datatype() == sparse_tensor_dtype + back = sparse_tensor_series.cast(tensor_dtype) + out = back.to_pylist() + np.testing.assert_equal(out, data) + + +def test_sparse_tensor_repr(): + arr = np.arange(np.prod((2, 2)), dtype=np.int64).reshape((2, 2)) + arrs = [arr, arr, None] + s = Series.from_pylist(arrs, pyobj="allow") + s = s.cast(DataType.sparse_tensor(dtype=DataType.from_arrow_type(pa.int64()))) + out_repr = ANSI_ESCAPE.sub("", repr(s)) + assert ( + out_repr.replace("\r", "") + == """╭─────────────────────────────╮ +│ list_series │ +│ --- │ +│ SparseTensor(Int64) │ +╞═════════════════════════════╡ +│ │ +├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ │ +├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ None │ +╰─────────────────────────────╯ +""" + ) From 7b69666f0a2fe8dc0da031b9e9bc92f4c6df29f1 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Mon, 23 Sep 2024 16:20:41 -0700 Subject: [PATCH 18/35] [CHORE] ignore vendored crates for codecov (#2895) --- codecov.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/codecov.yml b/codecov.yml index c790b26836..0bb7bf4bf7 100644 --- a/codecov.yml +++ b/codecov.yml @@ -4,6 +4,8 @@ ignore: - tutorials - tools - daft/pickle +- src/arrow2 +- src/parquet2 comment: layout: reach, diff, flags, files From 5be0508a1a55151e0be438cb10cb5030eacc3294 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 24 Sep 2024 09:36:19 -0500 Subject: [PATCH 19/35] [FEAT]: add sql support for "DATE " and "DATETIME " (#2870) note: this uses standard format of `%Y-%m-%d` for DATE, and `%Y-%m-%d %H:%M:%S %z` for DATETIME. Any custom formats can still use the existing `to_date` and `to_datetime` functions. 
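As a rough usage sketch (the table registration below mirrors the existing SQL tests and is illustrative, not part of this patch):

```python
import daft
from daft.sql.sql import SQLCatalog

df = daft.from_pydict({"x": [1, 2, 3]})
catalog = SQLCatalog({"tbl1": df})

# DATE literals are parsed with `%Y-%m-%d`; DATETIME literals use
# `%Y-%m-%d %H:%M:%S %z`, so a timezone offset is expected in the string.
out = daft.sql("SELECT x, DATE '2021-08-01' AS d FROM tbl1", catalog).collect()
```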
--- src/daft-sql/src/lib.rs | 2 ++ src/daft-sql/src/planner.rs | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 40ae6c57df..310c256e27 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -310,6 +310,8 @@ mod tests { #[case::to_date("select to_date(utf8, 'YYYY-MM-DD') as to_date from tbl1")] #[case::like("select utf8 like 'a' as like from tbl1")] #[case::ilike("select utf8 ilike 'a' as ilike from tbl1")] + #[case::datestring("select DATE '2021-08-01' as dt from tbl1")] + #[case::datetime("select DATETIME '2021-08-01 00:00:00' as dt from tbl1")] // #[case::to_datetime("select to_datetime(utf8, 'YYYY-MM-DD') as to_datetime from tbl1")] fn test_compiles_funcs(mut planner: SQLPlanner, #[case] query: &str) -> SQLPlannerResult<()> { let plan = planner.plan_sql(query); diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index cf82f72743..4e2bea8f60 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -4,7 +4,7 @@ use common_error::DaftResult; use daft_core::prelude::*; use daft_dsl::{ col, - functions::utf8::{ilike, like}, + functions::utf8::{ilike, like, to_date, to_datetime}, has_agg, lit, literals_to_series, null_lit, Expr, ExprRef, LiteralValue, Operator, }; use daft_functions::numeric::{ceil::ceil, floor::floor}; @@ -626,8 +626,18 @@ impl SQLPlanner { SQLExpr::Collate { .. } => unsupported_sql_err!("COLLATE"), SQLExpr::Nested(_) => unsupported_sql_err!("NESTED"), SQLExpr::IntroducedString { .. } => unsupported_sql_err!("INTRODUCED STRING"), - - SQLExpr::TypedString { .. } => unsupported_sql_err!("TYPED STRING"), + SQLExpr::TypedString { data_type, value } => match data_type { + sqlparser::ast::DataType::Date => Ok(to_date(lit(value.as_str()), "%Y-%m-%d")), + sqlparser::ast::DataType::Timestamp(None, TimezoneInfo::None) + | sqlparser::ast::DataType::Datetime(None) => Ok(to_datetime( + lit(value.as_str()), + "%Y-%m-%d %H:%M:%S %z", + None, + )), + dtype => { + unsupported_sql_err!("TypedString with data type {:?}", dtype) + } + }, SQLExpr::MapAccess { .. 
} => unsupported_sql_err!("MAP ACCESS"), SQLExpr::Function(func) => self.plan_function(func), SQLExpr::Case { From 8307b6b4219b63f5e4ba0b0e9f604104d980df55 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 24 Sep 2024 12:51:51 -0500 Subject: [PATCH 20/35] [CHORE]: Move daft.sql.sql module to daft.sql (#2907) see title --- daft/__init__.py | 2 +- daft/sql/__init__.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) create mode 100644 daft/sql/__init__.py diff --git a/daft/__init__.py b/daft/__init__.py index 153e6b8eb2..8d9c0598d8 100644 --- a/daft/__init__.py +++ b/daft/__init__.py @@ -88,7 +88,7 @@ def refresh_logger() -> None: read_lance, ) from daft.series import Series -from daft.sql.sql import sql, sql_expr +from daft.sql import sql, sql_expr from daft.udf import udf from daft.viz import register_viz_hook diff --git a/daft/sql/__init__.py b/daft/sql/__init__.py new file mode 100644 index 0000000000..fc7724545f --- /dev/null +++ b/daft/sql/__init__.py @@ -0,0 +1,7 @@ +from .sql import SQLCatalog, sql, sql_expr + +__all__ = [ + "SQLCatalog", + "sql", + "sql_expr", +] From c66e384acb50d7ab7d1271ae02ad4bd92aeb9971 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 24 Sep 2024 13:10:47 -0500 Subject: [PATCH 21/35] [CHORE]: bump sqlparser version (#2886) --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- src/daft-sql/src/planner.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8dc27f4c89..ad43bff338 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5040,9 +5040,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.49.0" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a404d0e14905361b918cb8afdb73605e25c1d5029312bd9785142dcb3aa49e" +checksum = "5fe11944a61da0da3f592e19a45ebe5ab92dc14a779907ff1f08fbb797bfefc7" dependencies = [ "log", ] diff --git a/Cargo.toml b/Cargo.toml index 55403d40bb..53f9b894b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -169,7 +169,7 @@ rstest = "0.18.2" serde_json = "1.0.116" sketches-ddsketch = {version = "0.2.2", features = ["use_serde"]} snafu = {version = "0.7.4", features = ["futures"]} -sqlparser = "0.49.0" +sqlparser = "0.51.0" sysinfo = "0.30.12" test-log = "0.2.16" thiserror = "1.0.63" diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 4e2bea8f60..af111b1532 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -871,7 +871,7 @@ impl SQLPlanner { // --------------------------------- // struct // --------------------------------- - SQLDataType::Struct(fields) => { + SQLDataType::Struct(fields, _) => { let fields = fields .iter() .enumerate() From e3dd6710c90e64ba62c8d2c91844b7aa71c437b8 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 24 Sep 2024 13:41:11 -0500 Subject: [PATCH 22/35] [FEAT]: add partitioning_* functions to sql (#2869) --- src/daft-sql/src/modules/partitioning.rs | 99 +++++++++++++++++++++++- tests/sql/test_partitioning_exprs.py | 45 +++++++++++ 2 files changed, 140 insertions(+), 4 deletions(-) create mode 100644 tests/sql/test_partitioning_exprs.py diff --git a/src/daft-sql/src/modules/partitioning.rs b/src/daft-sql/src/modules/partitioning.rs index b357ac810f..589c298e2f 100644 --- a/src/daft-sql/src/modules/partitioning.rs +++ b/src/daft-sql/src/modules/partitioning.rs @@ -1,11 +1,102 @@ +use daft_dsl::functions::partitioning::{self, PartitioningExpr}; + use super::SQLModule; -use crate::functions::SQLFunctions; +use crate::{ + ensure, + 
functions::{SQLFunction, SQLFunctions}, +}; pub struct SQLModulePartitioning; impl SQLModule for SQLModulePartitioning { - fn register(_parent: &mut SQLFunctions) { - // use FunctionExpr::Partitioning as f; - // TODO + fn register(parent: &mut SQLFunctions) { + parent.add_fn("partitioning_years", PartitioningExpr::Years); + parent.add_fn("partitioning_months", PartitioningExpr::Months); + parent.add_fn("partitioning_days", PartitioningExpr::Days); + parent.add_fn("partitioning_hours", PartitioningExpr::Hours); + parent.add_fn( + "partitioning_iceberg_bucket", + PartitioningExpr::IcebergBucket(0), + ); + parent.add_fn( + "partitioning_iceberg_truncate", + PartitioningExpr::IcebergTruncate(0), + ); + } +} + +impl SQLFunction for PartitioningExpr { + fn to_expr( + &self, + args: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match self { + PartitioningExpr::Years => { + partitioning_helper(args, planner, "years", partitioning::years) + } + PartitioningExpr::Months => { + partitioning_helper(args, planner, "months", partitioning::months) + } + PartitioningExpr::Days => { + partitioning_helper(args, planner, "days", partitioning::days) + } + PartitioningExpr::Hours => { + partitioning_helper(args, planner, "hours", partitioning::hours) + } + PartitioningExpr::IcebergBucket(_) => { + ensure!(args.len() == 2, "iceberg_bucket takes exactly 2 arguments"); + let input = planner.plan_function_arg(&args[0])?; + let n = planner + .plan_function_arg(&args[1])? + .as_literal() + .and_then(|l| l.as_i64()) + .ok_or_else(|| { + crate::error::PlannerError::unsupported_sql( + "Expected integer literal".to_string(), + ) + }) + .and_then(|n| { + if n > i32::MAX as i64 { + Err(crate::error::PlannerError::unsupported_sql( + "Integer literal too large".to_string(), + )) + } else { + Ok(n as i32) + } + })?; + + Ok(partitioning::iceberg_bucket(input, n)) + } + PartitioningExpr::IcebergTruncate(_) => { + ensure!( + args.len() == 2, + "iceberg_truncate takes exactly 2 arguments" + ); + let input = planner.plan_function_arg(&args[0])?; + let w = planner + .plan_function_arg(&args[1])? 
+ .as_literal() + .and_then(|l| l.as_i64()) + .ok_or_else(|| { + crate::error::PlannerError::unsupported_sql( + "Expected integer literal".to_string(), + ) + })?; + + Ok(partitioning::iceberg_truncate(input, w)) + } + } } } + +fn partitioning_helper daft_dsl::ExprRef>( + args: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + method_name: &str, + f: F, +) -> crate::error::SQLPlannerResult { + ensure!(args.len() == 1, "{} takes exactly 1 argument", method_name); + let args = planner.plan_function_arg(&args[0])?; + Ok(f(args)) +} diff --git a/tests/sql/test_partitioning_exprs.py b/tests/sql/test_partitioning_exprs.py new file mode 100644 index 0000000000..04bd3d1447 --- /dev/null +++ b/tests/sql/test_partitioning_exprs.py @@ -0,0 +1,45 @@ +from datetime import datetime + +import daft +from daft.sql.sql import SQLCatalog + + +def test_partitioning_exprs(): + df = daft.from_pydict( + { + "id": [1, 2, 3, 4, 5], + "date": [ + datetime(2024, 1, 1), + datetime(2024, 2, 1), + datetime(2024, 3, 1), + datetime(2024, 4, 1), + datetime(2024, 5, 1), + ], + } + ) + catalog = SQLCatalog({"test": df}) + expected = ( + df.select( + daft.col("date").partitioning.days().alias("date_days"), + daft.col("date").partitioning.hours().alias("date_hours"), + daft.col("date").partitioning.months().alias("date_months"), + daft.col("date").partitioning.years().alias("date_years"), + daft.col("id").partitioning.iceberg_bucket(10).alias("id_bucket"), + daft.col("id").partitioning.iceberg_truncate(10).alias("id_truncate"), + ) + .collect() + .to_pydict() + ) + sql = """ + SELECT + partitioning_days(date) AS date_days, + partitioning_hours(date) AS date_hours, + partitioning_months(date) AS date_months, + partitioning_years(date) AS date_years, + partitioning_iceberg_bucket(id, 10) AS id_bucket, + partitioning_iceberg_truncate(id, 10) AS id_truncate + FROM test + """ + actual = daft.sql(sql, catalog).collect().to_pydict() + + assert actual == expected From 02b30be54e23adb51c12aee4e2719a719239fc10 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 24 Sep 2024 13:54:40 -0700 Subject: [PATCH 23/35] [BUG] Fix display for decimal types (#2909) # Overview - fixed decimal printout - was being printed out like `{}.{}` - changed to `Decimal(precision={}, scale={})` instead --- src/daft-schema/src/dtype.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index 48d9414aab..c7418b5549 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -51,7 +51,7 @@ pub enum DataType { /// Fixed-precision decimal type. /// TODO: allow negative scale once Arrow2 allows it: https://github.com/jorgecarleitao/arrow2/issues/1518 - #[display("{_0}.{_1}")] + #[display("Decimal(precision={_0}, scale={_1})")] Decimal128(usize, usize), /// A [`i64`] representing a timestamp measured in [`TimeUnit`] with an optional timezone. From b519944fc6fea6ed92fc52196ffad99b9c6bd90a Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Tue, 24 Sep 2024 15:10:10 -0700 Subject: [PATCH 24/35] [FEAT] Delta Lake partitioned writing (#2884) Some things that I will cover in follow-up PRs: - split `table_io.py` up into multiple files - fix partitioned writes to conform to hive style (binary encoding, string escaping, etc) This should not actually be blocking since partition values in the delta log do not actually need to be encoded (Spark does not do so), just stringified. 
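As a quick sketch of the new API surface (the path and column names are illustrative; see the tests below for the exact cases covered):

```python
import daft

df = daft.from_pydict({"int": [1, 1, 2], "string": ["foo", "bar", "bar"]})

# One partition directory per distinct `int` value; partition values are
# tracked (stringified) in the Delta log rather than being hive-encoded.
df.write_deltalake("some/table/path", partition_cols=["int"])
```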
Just don't read it as a hive table lol --- Cargo.lock | 12 ++ Cargo.toml | 1 + daft/daft/__init__.pyi | 1 + daft/dataframe/dataframe.py | 43 +++-- daft/execution/execution_step.py | 2 + daft/execution/physical_plan.py | 2 + daft/execution/rust_physical_plan_shim.py | 2 + daft/iceberg/iceberg_write.py | 12 +- daft/logical/builder.py | 2 + daft/table/partitioning.py | 52 +++--- daft/table/table_io.py | 207 +++++++++++++--------- src/daft-plan/Cargo.toml | 1 + src/daft-plan/src/builder.rs | 4 + src/daft-plan/src/logical_ops/sink.rs | 3 +- src/daft-plan/src/sink_info.rs | 88 +++------ src/daft-scheduler/src/scheduler.rs | 1 + tests/io/delta_lake/conftest.py | 2 +- tests/io/delta_lake/test_table_write.py | 95 ++++++++++ 18 files changed, 340 insertions(+), 190 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ad43bff338..27e54099d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2073,6 +2073,7 @@ dependencies = [ "daft-scan", "daft-schema", "daft-table", + "derivative", "indexmap 2.5.0", "itertools 0.11.0", "log", @@ -2239,6 +2240,17 @@ dependencies = [ "serde", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive_more" version = "1.0.0" diff --git a/Cargo.toml b/Cargo.toml index 53f9b894b9..39d1d17ccb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -149,6 +149,7 @@ bytes = "1.6.0" chrono = "0.4.38" chrono-tz = "0.8.4" comfy-table = "7.1.1" +derivative = "2.2.0" dyn-clone = "1" futures = "0.3.30" html-escape = "0.2.13" diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 1a5dc99f0f..d071a3d85e 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1720,6 +1720,7 @@ class LogicalPlanBuilder: mode: str, version: int, large_dtypes: bool, + partition_cols: list[str] | None = None, io_config: IOConfig | None = None, ) -> LogicalPlanBuilder: ... def lance_write( diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index a4e48caba2..6211423e94 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -766,6 +766,7 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> def write_deltalake( self, table: Union[str, pathlib.Path, "DataCatalogTable", "deltalake.DeltaTable"], + partition_cols: Optional[List[str]] = None, mode: Literal["append", "overwrite", "error", "ignore"] = "append", schema_mode: Optional[Literal["merge", "overwrite"]] = None, name: Optional[str] = None, @@ -783,6 +784,7 @@ def write_deltalake( Args: table (Union[str, pathlib.Path, DataCatalogTable, deltalake.DeltaTable]): Destination `Delta Lake Table `__ or table URI to write dataframe to. + partition_cols (List[str], optional): How to subpartition each partition further. If table exists, expected to match table's existing partitioning scheme, otherwise creates the table with specified partition columns. Defaults to None. mode (str, optional): Operation mode of the write. `append` will add new data, `overwrite` will replace table with new data, `error` will raise an error if table already exists, and `ignore` will not write anything if table already exists. Defaults to "append". schema_mode (str, optional): Schema mode of the write. If set to `overwrite`, allows replacing the schema of the table when doing `mode=overwrite`. Schema mode `merge` is currently not supported. 
name (str, optional): User-provided identifier for this table. @@ -802,10 +804,7 @@ def write_deltalake( import deltalake import pyarrow as pa from deltalake.schema import _convert_pa_schema_to_delta - from deltalake.writer import ( - try_get_deltatable, - write_deltalake_pyarrow, - ) + from deltalake.writer import AddAction, try_get_deltatable, write_deltalake_pyarrow from packaging.version import parse from daft import from_pydict @@ -861,6 +860,13 @@ def write_deltalake( delta_schema = _convert_pa_schema_to_delta(pyarrow_schema, **large_dtypes_kwargs(large_dtypes)) if table: + if partition_cols and partition_cols != table.metadata().partition_columns: + raise ValueError( + f"Expected partition columns to match that of the existing table ({table.metadata().partition_columns}), but received: {partition_cols}" + ) + else: + partition_cols = table.metadata().partition_columns + table.update_incremental() table_schema = table.schema().to_pyarrow(as_large_types=large_dtypes) @@ -884,42 +890,45 @@ def write_deltalake( else: version = 0 + if partition_cols is not None: + for c in partition_cols: + if self.schema()[c].dtype == DataType.binary(): + raise NotImplementedError("Binary partition columns are not yet supported for Delta Lake writes") + builder = self._builder.write_deltalake( table_uri, mode, version, large_dtypes, io_config=io_config, + partition_cols=partition_cols, ) write_df = DataFrame(builder) write_df.collect() write_result = write_df.to_pydict() - assert "data_file" in write_result - data_files = write_result["data_file"] - add_action = [] + assert "add_action" in write_result + add_actions: List[AddAction] = write_result["add_action"] operations = [] paths = [] rows = [] sizes = [] - for data_file in data_files: - stats = json.loads(data_file.stats) + for add_action in add_actions: + stats = json.loads(add_action.stats) operations.append("ADD") - paths.append(data_file.path) + paths.append(add_action.path) rows.append(stats["numRecords"]) - sizes.append(data_file.size) - - add_action.append(data_file) + sizes.append(add_action.size) if table is None: write_deltalake_pyarrow( table_uri, delta_schema, - add_action, + add_actions, mode, - [], + partition_cols or [], name, description, configuration, @@ -936,7 +945,9 @@ def write_deltalake( rows.append(old_actions_dict["num_records"][i]) sizes.append(old_actions_dict["size_bytes"][i]) - table._table.create_write_transaction(add_action, mode, [], delta_schema, None, custom_metadata) + table._table.create_write_transaction( + add_actions, mode, partition_cols or [], delta_schema, None, custom_metadata + ) table.update_incremental() with_operations = from_pydict( diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 7693a0a84c..daa9afa289 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -429,6 +429,7 @@ class WriteDeltaLake(SingleOutputInstruction): base_path: str large_dtypes: bool version: int + partition_cols: list[str] | None io_config: IOConfig | None def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: @@ -456,6 +457,7 @@ def _handle_file_write(self, input: MicroPartition) -> MicroPartition: large_dtypes=self.large_dtypes, base_path=self.base_path, version=self.version, + partition_cols=self.partition_cols, io_config=self.io_config, ) diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 273ee0dc49..220731b6b5 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -147,6 +147,7 @@ 
def deltalake_write( base_path: str, large_dtypes: bool, version: int, + partition_cols: list[str] | None, io_config: IOConfig | None, ) -> InProgressPhysicalPlan[PartitionT]: """Write the results of `child_plan` into pyiceberg data files described by `write_info`.""" @@ -157,6 +158,7 @@ def deltalake_write( base_path=base_path, large_dtypes=large_dtypes, version=version, + partition_cols=partition_cols, io_config=io_config, ), ) diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 351d27f3bb..a19a3e2ad8 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -363,6 +363,7 @@ def write_deltalake( path: str, large_dtypes: bool, version: int, + partition_cols: list[str] | None, io_config: IOConfig | None, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: return physical_plan.deltalake_write( @@ -370,6 +371,7 @@ def write_deltalake( path, large_dtypes, version, + partition_cols, io_config, ) diff --git a/daft/iceberg/iceberg_write.py b/daft/iceberg/iceberg_write.py index 8bc4d1431b..0de4c950d8 100644 --- a/daft/iceberg/iceberg_write.py +++ b/daft/iceberg/iceberg_write.py @@ -5,7 +5,7 @@ from daft import Expression, col from daft.table import MicroPartition -from daft.table.partitioning import PartitionedTable, partition_strings_to_path, partition_values_to_string +from daft.table.partitioning import PartitionedTable, partition_strings_to_path if TYPE_CHECKING: import pyarrow as pa @@ -222,13 +222,15 @@ def partitioned_table_to_iceberg_iter( partition_values = partitioned.partition_values() if partition_values: - partition_strings = partition_values_to_string(partition_values, partition_null_fallback="null").to_pylist() - partition_values_list = partition_values.to_pylist() + partition_strings = partitioned.partition_values_str() + assert partition_strings is not None - for table, part_vals, part_strs in zip(partitioned.partitions(), partition_values_list, partition_strings): + for table, part_vals, part_strs in zip( + partitioned.partitions(), partition_values.to_pylist(), partition_strings.to_pylist() + ): iceberg_part_vals = {k: to_partition_representation(v) for k, v in part_vals.items()} part_record = IcebergRecord(**iceberg_part_vals) - part_path = partition_strings_to_path(root_path, part_strs) + part_path = partition_strings_to_path(root_path, part_strs, partition_null_fallback="null") arrow_table = coerce_pyarrow_table_to_schema(table.to_arrow(), schema) diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 31b347295d..7f8ed96cf2 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -304,6 +304,7 @@ def write_deltalake( version: int, large_dtypes: bool, io_config: IOConfig, + partition_cols: list[str] | None = None, ) -> LogicalPlanBuilder: columns_name = self.schema().column_names() builder = self._builder.delta_write( @@ -312,6 +313,7 @@ def write_deltalake( mode, version, large_dtypes, + partition_cols, io_config, ) return LogicalPlanBuilder(builder) diff --git a/daft/table/partitioning.py b/daft/table/partitioning.py index fac841d346..70a590cb45 100644 --- a/daft/table/partitioning.py +++ b/daft/table/partitioning.py @@ -1,34 +1,20 @@ from typing import Dict, List, Optional +from daft import Series from daft.expressions import ExpressionsProjection -from daft.series import Series from .micropartition import MicroPartition -def partition_strings_to_path(root_path: str, parts: Dict[str, str]): - postfix = "/".join(f"{key}={value}" for key, value 
in parts.items()) +def partition_strings_to_path( + root_path: str, parts: Dict[str, str], partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__" +) -> str: + keys = parts.keys() + values = [partition_null_fallback if value is None else value for value in parts.values()] + postfix = "/".join(f"{k}={v}" for k, v in zip(keys, values)) return f"{root_path}/{postfix}" -def partition_values_to_string( - partition_values: MicroPartition, partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__" -) -> MicroPartition: - """Convert partition values to human-readable string representation, filling nulls with `partition_null_fallback`.""" - default_part = Series.from_pylist([partition_null_fallback]) - pkey_names = partition_values.column_names() - - partition_strings = {} - - for c in pkey_names: - column = partition_values.get_column(c) - string_names = column._to_str_values() - null_filled = column.is_null().if_else(default_part, string_names) - partition_strings[c] = null_filled.to_pylist() - - return MicroPartition.from_pydict(partition_strings) - - class PartitionedTable: def __init__(self, table: MicroPartition, partition_keys: Optional[ExpressionsProjection]): self.table = table @@ -63,3 +49,27 @@ def partition_values(self) -> Optional[MicroPartition]: if self._partition_values is None: self._create_partitions() return self._partition_values + + def partition_values_str(self) -> Optional[MicroPartition]: + """ + Returns the partition values converted to human-readable strings, keeping null values as null. + + If the table is not partitioned, returns None. + """ + null_part = Series.from_pylist([None]) + partition_values = self.partition_values() + + if partition_values is None: + return None + else: + pkey_names = partition_values.column_names() + + partition_strings = {} + + for c in pkey_names: + column = partition_values.get_column(c) + string_names = column._to_str_values() + null_filled = column.is_null().if_else(null_part, string_names) + partition_strings[c] = null_filled + + return MicroPartition.from_pydict(partition_strings) diff --git a/daft/table/table_io.py b/daft/table/table_io.py index ee366946b2..ba07fab8a4 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -5,7 +5,6 @@ import pathlib import random import time -from functools import partial from typing import IO, TYPE_CHECKING, Any, Iterator, Union from uuid import uuid4 @@ -24,7 +23,7 @@ StorageConfig, ) from daft.dependencies import pa, pacsv, pads, pajson, pq -from daft.expressions import ExpressionsProjection +from daft.expressions import ExpressionsProjection, col from daft.filesystem import ( _resolve_paths_and_filesystem, canonicalize_protocol, @@ -40,13 +39,14 @@ from daft.sql.sql_connection import SQLConnection from .micropartition import MicroPartition -from .partitioning import PartitionedTable, partition_strings_to_path, partition_values_to_string +from .partitioning import PartitionedTable, partition_strings_to_path FileInput = Union[pathlib.Path, str, IO[bytes]] if TYPE_CHECKING: from collections.abc import Callable, Generator + from deltalake.writer import AddAction from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties @@ -404,9 +404,10 @@ def partitioned_table_to_hive_iter(partitioned: PartitionedTable, root_path: str partition_values = partitioned.partition_values() if partition_values: - partition_strings = 
partition_values_to_string(partition_values).to_pylist() + partition_strings = partitioned.partition_values_str() + assert partition_strings is not None - for part_table, part_strs in zip(partitioned.partitions(), partition_strings): + for part_table, part_strs in zip(partitioned.partitions(), partition_strings.to_pylist()): part_path = partition_strings_to_path(root_path, part_strs) arrow_table = part_table.to_arrow() @@ -595,112 +596,160 @@ def write_iceberg( return visitors.to_metadata() +def partitioned_table_to_deltalake_iter( + partitioned: PartitionedTable, large_dtypes: bool +) -> Iterator[tuple[pa.Table, str, dict[str, str | None]]]: + """ + Iterates over partitions, yielding each partition as an Arrow table, along with their respective paths and partition values. + """ + from deltalake.schema import _convert_pa_schema_to_delta + + from daft.io._deltalake import large_dtypes_kwargs + + partition_values = partitioned.partition_values() + + if partition_values: + partition_keys = partition_values.column_names() + partition_strings = partitioned.partition_values_str() + assert partition_strings is not None + + for part_table, part_strs in zip(partitioned.partitions(), partition_strings.to_pylist()): + part_path = partition_strings_to_path("", part_strs) + arrow_table = part_table.to_arrow() + + # Remove partition keys from the table since they are already encoded as keys + arrow_table_no_pkeys = arrow_table.drop_columns(partition_keys) + + converted_schema = _convert_pa_schema_to_delta( + arrow_table_no_pkeys.schema, **large_dtypes_kwargs(large_dtypes) + ) + converted_arrow_table = arrow_table_no_pkeys.cast(converted_schema) + + yield converted_arrow_table, part_path, part_strs + else: + arrow_table = partitioned.table.to_arrow() + arrow_batch = _convert_pa_schema_to_delta(arrow_table.schema, **large_dtypes_kwargs(large_dtypes)) + converted_arrow_table = arrow_table.cast(arrow_batch) + + yield converted_arrow_table, "/", {} + + +class DeltaLakeWriteVisitors: + class FileVisitor: + def __init__(self, parent: DeltaLakeWriteVisitors, partition_values: dict[str, str | None]): + self.parent = parent + self.partition_values = partition_values + + def __call__(self, written_file): + import json + from datetime import datetime + + import deltalake + from deltalake.writer import AddAction, DeltaJSONEncoder, get_file_stats_from_metadata + from packaging.version import parse + + from daft.utils import get_arrow_version + + # added to get_file_stats_from_metadata in deltalake v0.17.4: non-optional "num_indexed_cols" and "columns_to_collect_stats" arguments + # https://github.com/delta-io/delta-rs/blob/353e08be0202c45334dcdceee65a8679f35de710/python/deltalake/writer.py#L725 + if parse(deltalake.__version__) < parse("0.17.4"): + file_stats_args = {} + else: + file_stats_args = {"num_indexed_cols": -1, "columns_to_collect_stats": None} + + stats = get_file_stats_from_metadata(written_file.metadata, **file_stats_args) + + # PyArrow added support for written_file.size in 9.0.0 + if get_arrow_version() >= (9, 0, 0): + size = written_file.size + elif self.parent.fs is not None: + size = self.parent.fs.get_file_info([written_file.path])[0].size + else: + size = 0 + + self.parent.add_actions.append( + AddAction( + written_file.path, + size, + self.partition_values, + int(datetime.now().timestamp() * 1000), + True, + json.dumps(stats, cls=DeltaJSONEncoder), + ) + ) + + def __init__(self, fs: pa.fs.FileSystem): + self.add_actions: list[AddAction] = [] + self.fs = fs + + def visitor(self, partition_values: 
dict[str, str | None]) -> DeltaLakeWriteVisitors.FileVisitor: + return self.FileVisitor(self, partition_values) + + def to_metadata(self) -> MicroPartition: + return MicroPartition.from_pydict({"add_action": self.add_actions}) + + def write_deltalake( - mp: MicroPartition, + table: MicroPartition, large_dtypes: bool, base_path: str, version: int, + partition_cols: list[str] | None = None, io_config: IOConfig | None = None, ): - import json - from datetime import datetime - - import deltalake - from deltalake.schema import convert_pyarrow_table - from deltalake.writer import ( - AddAction, - DeltaJSONEncoder, - DeltaStorageHandler, - get_partitions_from_path, - ) - from packaging.version import parse + from deltalake.writer import DeltaStorageHandler from pyarrow.fs import PyFileSystem - from daft.io._deltalake import large_dtypes_kwargs from daft.io.object_store_options import io_config_to_storage_options - from daft.utils import get_arrow_version protocol = get_protocol_from_path(base_path) canonicalized_protocol = canonicalize_protocol(protocol) - data_files: list[AddAction] = [] - - # added to get_file_stats_from_metadata in deltalake v0.17.4: non-optional "num_indexed_cols" and "columns_to_collect_stats" arguments - # https://github.com/delta-io/delta-rs/blob/353e08be0202c45334dcdceee65a8679f35de710/python/deltalake/writer.py#L725 - if parse(deltalake.__version__) < parse("0.17.4"): - get_file_stats_from_metadata = deltalake.writer.get_file_stats_from_metadata - else: - get_file_stats_from_metadata = partial( - deltalake.writer.get_file_stats_from_metadata, num_indexed_cols=-1, columns_to_collect_stats=None - ) - - def file_visitor(written_file: Any) -> None: - path, partition_values = get_partitions_from_path(written_file.path) - stats = get_file_stats_from_metadata(written_file.metadata) - - # PyArrow added support for written_file.size in 9.0.0 - if get_arrow_version() >= (9, 0, 0): - size = written_file.size - elif fs is not None: - size = fs.get_file_info([path])[0].size - else: - size = 0 - - data_files.append( - AddAction( - path, - size, - partition_values, - int(datetime.now().timestamp() * 1000), - True, - json.dumps(stats, cls=DeltaJSONEncoder), - ) - ) - is_local_fs = canonicalized_protocol == "file" io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config storage_options = io_config_to_storage_options(io_config, base_path) fs = PyFileSystem(DeltaStorageHandler(base_path, storage_options)) - arrow_table = mp.to_arrow() - arrow_batch = convert_pyarrow_table(arrow_table, **large_dtypes_kwargs(large_dtypes)) - execution_config = get_context().daft_execution_config target_row_group_size = execution_config.parquet_target_row_group_size inflation_factor = execution_config.parquet_inflation_factor target_file_size = execution_config.parquet_target_filesize - size_bytes = arrow_table.nbytes + format = pads.ParquetFileFormat() + opts = format.make_write_options(use_compliant_nested_type=False) - target_num_files = max(math.ceil(size_bytes / target_file_size / inflation_factor), 1) - num_rows = len(arrow_table) + partition_keys = ExpressionsProjection([col(c) for c in partition_cols]) if partition_cols is not None else None + partitioned = PartitionedTable(table, partition_keys) + visitors = DeltaLakeWriteVisitors(fs) - rows_per_file = max(math.ceil(num_rows / target_num_files), 1) + for part_table, part_path, part_values in partitioned_table_to_deltalake_iter(partitioned, large_dtypes): + size_bytes = part_table.nbytes - target_row_groups 
= max(math.ceil(size_bytes / target_row_group_size / inflation_factor), 1) - rows_per_row_group = max(min(math.ceil(num_rows / target_row_groups), rows_per_file), 1) + target_num_files = max(math.ceil(size_bytes / target_file_size / inflation_factor), 1) + num_rows = len(part_table) - format = pads.ParquetFileFormat() + rows_per_file = max(math.ceil(num_rows / target_num_files), 1) - opts = format.make_write_options(use_compliant_nested_type=False) + target_row_groups = max(math.ceil(size_bytes / target_row_group_size / inflation_factor), 1) + rows_per_row_group = max(min(math.ceil(num_rows / target_row_groups), rows_per_file), 1) - _write_tabular_arrow_table( - arrow_table=arrow_batch, - schema=None, - full_path="/", - format=format, - opts=opts, - fs=fs, - rows_per_file=rows_per_file, - rows_per_row_group=rows_per_row_group, - create_dir=is_local_fs, - file_visitor=file_visitor, - version=version, - ) + _write_tabular_arrow_table( + arrow_table=part_table, + schema=None, + full_path=part_path, + format=format, + opts=opts, + fs=fs, + rows_per_file=rows_per_file, + rows_per_row_group=rows_per_row_group, + create_dir=is_local_fs, + file_visitor=visitors.visitor(part_values), + version=version, + ) - return MicroPartition.from_pydict({"data_file": Series.from_pylist(data_files, name="data_file", pyobj="force")}) + return visitors.to_metadata() def write_lance(mp: MicroPartition, base_path: str, mode: str, io_config: IOConfig | None, kwargs: dict | None): diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index d2cd422dba..a8306394f1 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -27,6 +27,7 @@ daft-functions = {path = "../daft-functions", default-features = false} daft-scan = {path = "../daft-scan", default-features = false} daft-schema = {path = "../daft-schema", default-features = false} daft-table = {path = "../daft-table", default-features = false} +derivative = {workspace = true} indexmap = {workspace = true} itertools = {workspace = true} log = {workspace = true} diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 0098d72405..98740a6a53 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -432,6 +432,7 @@ impl LogicalPlanBuilder { mode: String, version: i32, large_dtypes: bool, + partition_cols: Option>, io_config: Option, ) -> DaftResult { use crate::sink_info::DeltaLakeCatalogInfo; @@ -441,6 +442,7 @@ impl LogicalPlanBuilder { mode, version, large_dtypes, + partition_cols, io_config, }), catalog_columns: columns_name, @@ -752,6 +754,7 @@ impl PyLogicalPlanBuilder { mode: String, version: i32, large_dtypes: bool, + partition_cols: Option>, io_config: Option, ) -> PyResult { Ok(self @@ -762,6 +765,7 @@ impl PyLogicalPlanBuilder { mode, version, large_dtypes, + partition_cols, io_config.map(|cfg| cfg.config), )? 
.into()) diff --git a/src/daft-plan/src/logical_ops/sink.rs b/src/daft-plan/src/logical_ops/sink.rs index 2a23292c44..d84c654c84 100644 --- a/src/daft-plan/src/logical_ops/sink.rs +++ b/src/daft-plan/src/logical_ops/sink.rs @@ -21,6 +21,7 @@ impl Sink { pub(crate) fn try_new(input: Arc, sink_info: Arc) -> DaftResult { let schema = input.schema(); + // replace partition columns with resolved columns let sink_info = match sink_info.as_ref() { SinkInfo::OutputFileInfo(OutputFileInfo { root_dir, @@ -67,7 +68,7 @@ impl Sink { Field::new("data_file", DataType::Python), ] } - CatalogType::DeltaLake(_) => vec![Field::new("data_file", DataType::Python)], + CatalogType::DeltaLake(_) => vec![Field::new("add_action", DataType::Python)], CatalogType::Lance(_) => vec![Field::new("fragments", DataType::Python)], } } diff --git a/src/daft-plan/src/sink_info.rs b/src/daft-plan/src/sink_info.rs index b66217d8d2..02c8e05273 100644 --- a/src/daft-plan/src/sink_info.rs +++ b/src/daft-plan/src/sink_info.rs @@ -5,6 +5,7 @@ use common_io_config::IOConfig; #[cfg(feature = "python")] use common_py_serde::{deserialize_py_object, serialize_py_object}; use daft_dsl::ExprRef; +use derivative::Derivative; use itertools::Itertools; #[cfg(feature = "python")] use pyo3::PyObject; @@ -43,7 +44,8 @@ pub enum CatalogType { } #[cfg(feature = "python")] -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Derivative, Debug, Clone, Serialize, Deserialize)] +#[derivative(PartialEq, Eq, Hash)] pub struct IcebergCatalogInfo { pub table_name: String, pub table_location: String, @@ -51,41 +53,27 @@ pub struct IcebergCatalogInfo { serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" )] + #[derivative(PartialEq = "ignore")] + #[derivative(Hash = "ignore")] pub partition_spec: PyObject, #[serde( serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" )] + #[derivative(PartialEq = "ignore")] + #[derivative(Hash = "ignore")] pub iceberg_schema: PyObject, #[serde( serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" )] + #[derivative(PartialEq = "ignore")] + #[derivative(Hash = "ignore")] pub iceberg_properties: PyObject, pub io_config: Option, } -#[cfg(feature = "python")] -impl PartialEq for IcebergCatalogInfo { - fn eq(&self, other: &Self) -> bool { - self.table_name == other.table_name - && self.table_location == other.table_location - && self.io_config == other.io_config - } -} -#[cfg(feature = "python")] -impl Eq for IcebergCatalogInfo {} - -#[cfg(feature = "python")] -impl Hash for IcebergCatalogInfo { - fn hash(&self, state: &mut H) { - self.table_name.hash(state); - self.table_location.hash(state); - self.io_config.hash(state); - } -} - #[cfg(feature = "python")] impl IcebergCatalogInfo { pub fn multiline_display(&self) -> Vec { @@ -101,40 +89,16 @@ impl IcebergCatalogInfo { } #[cfg(feature = "python")] -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct DeltaLakeCatalogInfo { pub path: String, pub mode: String, pub version: i32, pub large_dtypes: bool, + pub partition_cols: Option>, pub io_config: Option, } -#[cfg(feature = "python")] -impl PartialEq for DeltaLakeCatalogInfo { - fn eq(&self, other: &Self) -> bool { - self.path == other.path - && self.mode == other.mode - && self.version == other.version - && self.large_dtypes == other.large_dtypes - && self.io_config == other.io_config - } -} - -#[cfg(feature = "python")] -impl Eq for DeltaLakeCatalogInfo {} - 
-#[cfg(feature = "python")] -impl Hash for DeltaLakeCatalogInfo { - fn hash(&self, state: &mut H) { - self.path.hash(state); - self.mode.hash(state); - self.version.hash(state); - self.large_dtypes.hash(state); - self.io_config.hash(state); - } -} - #[cfg(feature = "python")] impl DeltaLakeCatalogInfo { pub fn multiline_display(&self) -> Vec { @@ -143,6 +107,12 @@ impl DeltaLakeCatalogInfo { res.push(format!("Mode = {}", self.mode)); res.push(format!("Version = {}", self.version)); res.push(format!("Large Dtypes = {}", self.large_dtypes)); + if let Some(ref partition_cols) = self.partition_cols { + res.push(format!( + "Partition cols = {}", + partition_cols.iter().map(|e| e.to_string()).join(", ") + )); + } match &self.io_config { None => res.push("IOConfig = None".to_string()), Some(io_config) => res.push(format!("IOConfig = {}", io_config)), @@ -152,7 +122,8 @@ impl DeltaLakeCatalogInfo { } #[cfg(feature = "python")] -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Derivative, Debug, Clone, Serialize, Deserialize)] +#[derivative(PartialEq, Eq, Hash)] pub struct LanceCatalogInfo { pub path: String, pub mode: String, @@ -161,28 +132,11 @@ pub struct LanceCatalogInfo { serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" )] + #[derivative(PartialEq = "ignore")] + #[derivative(Hash = "ignore")] pub kwargs: PyObject, } -#[cfg(feature = "python")] -impl PartialEq for LanceCatalogInfo { - fn eq(&self, other: &Self) -> bool { - self.path == other.path && self.mode == other.mode && self.io_config == other.io_config - } -} - -#[cfg(feature = "python")] -impl Eq for LanceCatalogInfo {} - -#[cfg(feature = "python")] -impl Hash for LanceCatalogInfo { - fn hash(&self, state: &mut H) { - self.path.hash(state); - self.mode.hash(state); - self.io_config.hash(state); - } -} - #[cfg(feature = "python")] impl LanceCatalogInfo { pub fn multiline_display(&self) -> Vec { diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index 2eb66781d9..709dd8ff4d 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -207,6 +207,7 @@ fn deltalake_write( &delta_lake_info.path, delta_lake_info.large_dtypes, delta_lake_info.version, + delta_lake_info.partition_cols.clone(), delta_lake_info .io_config .as_ref() diff --git a/tests/io/delta_lake/conftest.py b/tests/io/delta_lake/conftest.py index 052d5a6d74..363351eb95 100644 --- a/tests/io/delta_lake/conftest.py +++ b/tests/io/delta_lake/conftest.py @@ -34,7 +34,7 @@ def num_partitions(request) -> int: pytest.param((lambda i: i, "a"), id="int_partitioned"), pytest.param((lambda i: i * 1.5, "b"), id="float_partitioned"), pytest.param((lambda i: f"foo_{i}", "c"), id="string_partitioned"), - pytest.param((lambda i: f"foo_{i}".encode(), "d"), id="string_partitioned"), + pytest.param((lambda i: f"foo_{i}".encode(), "d"), id="binary_partitioned"), pytest.param( (lambda i: datetime.datetime(2024, 2, i + 1), "f"), id="timestamp_partitioned", diff --git a/tests/io/delta_lake/test_table_write.py b/tests/io/delta_lake/test_table_write.py index 6dbcf539fa..6519e85d0f 100644 --- a/tests/io/delta_lake/test_table_write.py +++ b/tests/io/delta_lake/test_table_write.py @@ -1,6 +1,9 @@ from __future__ import annotations +import datetime +import decimal import sys +from pathlib import Path import pyarrow as pa import pytest @@ -180,3 +183,95 @@ def test_deltalake_write_ignore(tmp_path): expected_schema = Schema.from_pyarrow_schema(read_delta.schema().to_pyarrow()) assert df1.schema() 
== expected_schema assert read_delta.to_pyarrow_table() == df1.to_arrow() + + +def check_equal_both_daft_and_delta_rs(df: daft.DataFrame, path: Path, sort_order: list[tuple[str, str]]): + deltalake = pytest.importorskip("deltalake") + + arrow_df = df.to_arrow().sort_by(sort_order) + + read_daft = daft.read_deltalake(str(path)) + assert read_daft.schema() == df.schema() + assert read_daft.to_arrow().sort_by(sort_order) == arrow_df + + read_delta = deltalake.DeltaTable(str(path)) + expected_schema = Schema.from_pyarrow_schema(read_delta.schema().to_pyarrow()) + assert df.schema() == expected_schema + assert read_delta.to_pyarrow_table().cast(expected_schema.to_pyarrow_schema()).sort_by(sort_order) == arrow_df + + +@pytest.mark.parametrize( + "partition_cols,num_partitions", + [ + (["int"], 3), + (["float"], 3), + (["str"], 3), + pytest.param(["bin"], 3, marks=pytest.mark.xfail(reason="Binary partitioning is not yet supported")), + (["bool"], 3), + (["datetime"], 3), + (["date"], 3), + (["decimal"], 3), + (["int", "float"], 4), + ], +) +def test_deltalake_write_partitioned(tmp_path, partition_cols, num_partitions): + path = tmp_path / "some_table" + df = daft.from_pydict( + { + "int": [1, 1, 2, None], + "float": [1.1, 2.2, 2.2, None], + "str": ["foo", "foo", "bar", None], + "bin": [b"foo", b"foo", b"bar", None], + "bool": [True, True, False, None], + "datetime": [ + datetime.datetime(2024, 2, 10), + datetime.datetime(2024, 2, 10), + datetime.datetime(2024, 2, 11), + None, + ], + "date": [datetime.date(2024, 2, 10), datetime.date(2024, 2, 10), datetime.date(2024, 2, 11), None], + "decimal": pa.array( + [decimal.Decimal("1111.111"), decimal.Decimal("1111.111"), decimal.Decimal("2222.222"), None], + type=pa.decimal128(7, 3), + ), + } + ) + result = df.write_deltalake(str(path), partition_cols=partition_cols) + result = result.to_pydict() + assert len(result["operation"]) == num_partitions + assert all(op == "ADD" for op in result["operation"]) + assert sum(result["rows"]) == len(df) + + sort_order = [("int", "ascending"), ("float", "ascending")] + check_equal_both_daft_and_delta_rs(df, path, sort_order) + + +def test_deltalake_write_partitioned_empty(tmp_path): + path = tmp_path / "some_table" + + df = daft.from_arrow(pa.schema([("int", pa.int64()), ("string", pa.string())]).empty_table()) + + df.write_deltalake(str(path), partition_cols=["int"]) + + check_equal_both_daft_and_delta_rs(df, path, [("int", "ascending")]) + + +def test_deltalake_write_partitioned_existing_table(tmp_path): + path = tmp_path / "some_table" + + df1 = daft.from_pydict({"int": [1], "string": ["foo"]}) + result = df1.write_deltalake(str(path), partition_cols=["int"]) + result = result.to_pydict() + assert result["operation"] == ["ADD"] + assert result["rows"] == [1] + + df2 = daft.from_pydict({"int": [1, 2], "string": ["bar", "bar"]}) + with pytest.raises(ValueError): + df2.write_deltalake(str(path), partition_cols=["string"]) + + result = df2.write_deltalake(str(path)) + result = result.to_pydict() + assert result["operation"] == ["ADD", "ADD"] + assert result["rows"] == [1, 1] + + check_equal_both_daft_and_delta_rs(df1.concat(df2), path, [("int", "ascending"), ("string", "ascending")]) From 260a2c78b8bbeb589cb627db139a63210c862b25 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 24 Sep 2024 15:52:17 -0700 Subject: [PATCH 25/35] [FEAT] UTF8 to binary coercion flag (#2893) # Overview - add string-to-binary coercion flag - cleaned up some logic and instances of non-idiomatic Rust code --- Cargo.lock | 7 + Cargo.toml 
| 1 + daft/daft/__init__.pyi | 3 +- daft/table/table.py | 4 +- .../src/io/parquet/read/schema/convert.rs | 32 ++-- src/arrow2/src/io/parquet/read/schema/mod.rs | 54 ++++++- src/daft-micropartition/src/micropartition.rs | 13 +- src/daft-parquet/Cargo.toml | 1 + src/daft-parquet/src/file.rs | 11 +- src/daft-parquet/src/lib.rs | 3 + src/daft-parquet/src/python.rs | 16 +- src/daft-parquet/src/read.rs | 143 ++++++++++++++++-- src/daft-parquet/src/stream_reader.rs | 4 +- src/daft-scan/src/glob.rs | 7 +- .../assets/parquet-data/invalid_utf8.parquet | Bin 0 -> 517 bytes tests/table/table_io/test_parquet.py | 21 +++ 16 files changed, 262 insertions(+), 58 deletions(-) create mode 100644 tests/assets/parquet-data/invalid_utf8.parquet diff --git a/Cargo.lock b/Cargo.lock index 27e54099d0..b666624144 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2031,6 +2031,7 @@ dependencies = [ "itertools 0.11.0", "log", "parquet2", + "path_macro", "pyo3", "rayon", "serde", @@ -3934,6 +3935,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "path_macro" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6e819bbd49d5939f682638fa54826bf1650abddcd65d000923de8ad63cc7d15" + [[package]] name = "pem" version = "3.0.4" diff --git a/Cargo.toml b/Cargo.toml index 39d1d17ccb..a0f4cc19dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -162,6 +162,7 @@ jaq-std = "1.2.0" num-derive = "0.3.3" num-traits = "0.2" once_cell = "1.19.0" +path_macro = "1.0.0" pretty_assertions = "1.4.0" rand = "^0.8" rayon = "1.10.0" diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index d071a3d85e..e2cb5e1eaa 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1,7 +1,7 @@ import builtins import datetime from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Iterator +from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal from daft.dataframe.display import MermaidOptions from daft.execution import physical_plan @@ -871,6 +871,7 @@ def read_parquet_into_pyarrow( io_config: IOConfig | None = None, multithreaded_io: bool | None = None, coerce_int96_timestamp_unit: PyTimeUnit | None = None, + string_encoding: Literal["utf-8"] | Literal["raw"] = "utf-8", file_timeout_ms: int | None = None, ): ... 
def read_parquet_into_pyarrow_bulk( diff --git a/daft/table/table.py b/daft/table/table.py index f90f63b8d3..9ab769b337 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from daft.arrow_utils import ensure_table from daft.daft import ( @@ -526,6 +526,7 @@ def read_parquet_into_pyarrow( io_config: IOConfig | None = None, multithreaded_io: bool | None = None, coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns(), + string_encoding: Literal["utf-8"] | Literal["raw"] = "utf-8", file_timeout_ms: int | None = 900_000, # 15 minutes ) -> pa.Table: fields, metadata, columns, num_rows_read = _read_parquet_into_pyarrow( @@ -537,6 +538,7 @@ def read_parquet_into_pyarrow( io_config=io_config, multithreaded_io=multithreaded_io, coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit, + string_encoding=string_encoding, file_timeout_ms=file_timeout_ms, ) schema = pa.schema(fields, metadata=metadata) diff --git a/src/arrow2/src/io/parquet/read/schema/convert.rs b/src/arrow2/src/io/parquet/read/schema/convert.rs index cf9838a2a2..1fdeccc7a1 100644 --- a/src/arrow2/src/io/parquet/read/schema/convert.rs +++ b/src/arrow2/src/io/parquet/read/schema/convert.rs @@ -7,24 +7,28 @@ use parquet2::schema::{ Repetition, }; -use crate::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; -use crate::io::parquet::read::schema::SchemaInferenceOptions; +use super::StringEncoding; +use crate::{ + datatypes::{DataType, Field, IntervalUnit, TimeUnit}, + io::parquet::read::schema::SchemaInferenceOptions, +}; /// Converts [`ParquetType`]s to a [`Field`], ignoring parquet fields that do not contain /// any physical column. 
#[allow(dead_code)] pub fn parquet_to_arrow_schema(fields: &[ParquetType]) -> Vec { - parquet_to_arrow_schema_with_options(fields, &None) + parquet_to_arrow_schema_with_options(fields, None) } /// Like [`parquet_to_arrow_schema`] but with configurable options which affect the behavior of schema inference pub fn parquet_to_arrow_schema_with_options( fields: &[ParquetType], - options: &Option, + options: Option, ) -> Vec { + let options = options.unwrap_or_default(); fields .iter() - .filter_map(|f| to_field(f, options.as_ref().unwrap_or(&Default::default()))) + .filter_map(|f| to_field(f, &options)) .collect::>() } @@ -145,9 +149,13 @@ fn from_int64( fn from_byte_array( logical_type: &Option, converted_type: &Option, + options: &SchemaInferenceOptions, ) -> DataType { match (logical_type, converted_type) { - (Some(PrimitiveLogicalType::String), _) => DataType::Utf8, + (Some(PrimitiveLogicalType::String), _) => match options.string_encoding { + StringEncoding::Utf8 => DataType::Utf8, + StringEncoding::Raw => DataType::Binary, + }, (Some(PrimitiveLogicalType::Json), _) => DataType::Binary, (Some(PrimitiveLogicalType::Bson), _) => DataType::Binary, (Some(PrimitiveLogicalType::Enum), _) => DataType::Binary, @@ -219,9 +227,11 @@ fn to_primitive_type_inner( PhysicalType::Int96 => DataType::Timestamp(options.int96_coerce_to_timeunit, None), PhysicalType::Float => DataType::Float32, PhysicalType::Double => DataType::Float64, - PhysicalType::ByteArray => { - from_byte_array(&primitive_type.logical_type, &primitive_type.converted_type) - } + PhysicalType::ByteArray => from_byte_array( + &primitive_type.logical_type, + &primitive_type.converted_type, + options, + ), PhysicalType::FixedLenByteArray(length) => from_fixed_len_byte_array( length, primitive_type.logical_type, @@ -440,7 +450,6 @@ mod tests { use parquet2::metadata::SchemaDescriptor; use super::*; - use crate::error::Result; #[test] @@ -1123,8 +1132,9 @@ mod tests { let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; let fields = parquet_to_arrow_schema_with_options( parquet_schema.fields(), - &Some(SchemaInferenceOptions { + Some(SchemaInferenceOptions { int96_coerce_to_timeunit: tu, + ..Default::default() }), ); assert_eq!(arrow_fields, fields); diff --git a/src/arrow2/src/io/parquet/read/schema/mod.rs b/src/arrow2/src/io/parquet/read/schema/mod.rs index 293473c233..adb27b2fd9 100644 --- a/src/arrow2/src/io/parquet/read/schema/mod.rs +++ b/src/arrow2/src/io/parquet/read/schema/mod.rs @@ -1,21 +1,55 @@ //! APIs to handle Parquet <-> Arrow schemas. -use crate::datatypes::{Schema, TimeUnit}; -use crate::error::Result; +use std::str::FromStr; + +use crate::{ + datatypes::{Schema, TimeUnit}, + error::{Error, Result}, +}; mod convert; mod metadata; pub use convert::parquet_to_arrow_schema_with_options; -pub use metadata::{apply_schema_to_fields, read_schema_from_metadata}; -pub use parquet2::metadata::{FileMetaData, KeyValue, SchemaDescriptor}; -pub use parquet2::schema::types::ParquetType; - pub(crate) use convert::*; +pub use metadata::{apply_schema_to_fields, read_schema_from_metadata}; +pub use parquet2::{ + metadata::{FileMetaData, KeyValue, SchemaDescriptor}, + schema::types::ParquetType, +}; +use serde::{Deserialize, Serialize}; use self::metadata::parse_key_value_metadata; +/// String encoding options. +/// +/// Each variant of this enum maps to a different interpretation of the underlying binary data: +/// 1. `StringEncoding::Utf8` assumes the underlying binary data is UTF-8 encoded. +/// 2. 
`StringEncoding::Raw` makes no assumptions about the encoding of the underlying binary data. This variant will change the logical type of the column to `DataType::Binary` in the final schema. +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum StringEncoding { + Raw, + #[default] + Utf8, +} + +impl FromStr for StringEncoding { + type Err = Error; + + fn from_str(value: &str) -> Result { + match value { + "utf-8" => Ok(Self::Utf8), + "raw" => Ok(Self::Raw), + encoding => Err(Error::InvalidArgumentError(format!( + "Unrecognized encoding: {}", + encoding, + ))), + } + } +} + /// Options when inferring schemas from Parquet +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SchemaInferenceOptions { /// When inferring schemas from the Parquet INT96 timestamp type, this is the corresponding TimeUnit /// in the inferred Arrow Timestamp type. @@ -25,12 +59,16 @@ pub struct SchemaInferenceOptions { /// (e.g. TimeUnit::Milliseconds) will result in loss of precision, but support a larger range of dates /// without overflowing when parsing the data. pub int96_coerce_to_timeunit: TimeUnit, + + /// The string encoding to assume when inferring the schema from Parquet data. + pub string_encoding: StringEncoding, } impl Default for SchemaInferenceOptions { fn default() -> Self { SchemaInferenceOptions { int96_coerce_to_timeunit: TimeUnit::Nanosecond, + string_encoding: StringEncoding::default(), } } } @@ -42,13 +80,13 @@ impl Default for SchemaInferenceOptions { /// This function errors iff the key `"ARROW:schema"` exists but is not correctly encoded, /// indicating that that the file's arrow metadata was incorrectly written. pub fn infer_schema(file_metadata: &FileMetaData) -> Result { - infer_schema_with_options(file_metadata, &None) + infer_schema_with_options(file_metadata, None) } /// Like [`infer_schema`] but with configurable options which affects the behavior of inference pub fn infer_schema_with_options( file_metadata: &FileMetaData, - options: &Option, + options: Option, ) -> Result { let mut metadata = parse_key_value_metadata(file_metadata.key_value_metadata()); let fields = parquet_to_arrow_schema_with_options(file_metadata.schema().fields(), options); diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index cc8439583c..c3059626fa 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -596,9 +596,9 @@ impl MicroPartition { ( _, _, - FileFormatConfig::Parquet(ParquetSourceConfig { + &FileFormatConfig::Parquet(ParquetSourceConfig { coerce_int96_timestamp_unit, - field_id_mapping, + ref field_id_mapping, chunk_size, .. 
}), @@ -646,12 +646,13 @@ impl MicroPartition { if scan_task.sources.len() == 1 { 1 } else { 128 }, // Hardcoded for to 128 bulk reads cfg.multithreaded_io, &ParquetSchemaInferenceOptions { - coerce_int96_timestamp_unit: *coerce_int96_timestamp_unit, + coerce_int96_timestamp_unit, + ..Default::default() }, Some(schema.clone()), field_id_mapping.clone(), parquet_metadata, - *chunk_size, + chunk_size, ) .context(DaftCoreComputeSnafu) } @@ -1162,7 +1163,7 @@ pub(crate) fn read_parquet_into_micropartition>( let schemas = metadata .iter() .map(|m| { - let schema = infer_schema_with_options(m, &Some((*schema_infer_options).into()))?; + let schema = infer_schema_with_options(m, Some((*schema_infer_options).into()))?; let daft_schema = daft_core::prelude::Schema::try_from(&schema)?; DaftResult::Ok(Arc::new(daft_schema)) }) @@ -1186,7 +1187,7 @@ pub(crate) fn read_parquet_into_micropartition>( let schemas = metadata .iter() .map(|m| { - let schema = infer_schema_with_options(m, &Some((*schema_infer_options).into()))?; + let schema = infer_schema_with_options(m, Some((*schema_infer_options).into()))?; let daft_schema = daft_core::prelude::Schema::try_from(&schema)?; DaftResult::Ok(Arc::new(daft_schema)) }) diff --git a/src/daft-parquet/Cargo.toml b/src/daft-parquet/Cargo.toml index 3e1a4876b8..1ec75b3cb1 100644 --- a/src/daft-parquet/Cargo.toml +++ b/src/daft-parquet/Cargo.toml @@ -26,6 +26,7 @@ tokio-util = {workspace = true} [dev-dependencies] bincode = {workspace = true} +path_macro = {workspace = true} [features] python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-io/python", "daft-table/python", "daft-stats/python", "daft-dsl/python", "common-arrow-ffi/python"] diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs index 3b84579b6d..8ba9f4ed25 100644 --- a/src/daft-parquet/src/file.rs +++ b/src/daft-parquet/src/file.rs @@ -259,11 +259,12 @@ impl ParquetReaderBuilder { } pub fn build(self) -> super::Result { - let mut arrow_schema = - infer_schema_with_options(&self.metadata, &Some(self.schema_inference_options.into())) - .context(UnableToParseSchemaFromMetadataSnafu:: { - path: self.uri.clone(), - })?; + let options = self.schema_inference_options.into(); + let mut arrow_schema = infer_schema_with_options(&self.metadata, Some(options)).context( + UnableToParseSchemaFromMetadataSnafu { + path: self.uri.clone(), + }, + )?; if let Some(names_to_keep) = self.selected_columns { arrow_schema diff --git a/src/daft-parquet/src/lib.rs b/src/daft-parquet/src/lib.rs index d1057e95f7..e907fec22e 100644 --- a/src/daft-parquet/src/lib.rs +++ b/src/daft-parquet/src/lib.rs @@ -19,6 +19,9 @@ pub use python::register_modules; #[derive(Debug, Snafu)] pub enum Error { + #[snafu(display("{source}"))] + Arrow2Error { source: arrow2::error::Error }, + #[snafu(display("{source}"))] DaftIOError { source: daft_io::Error }, diff --git a/src/daft-parquet/src/python.rs b/src/daft-parquet/src/python.rs index 930eb7e91b..2d965053c2 100644 --- a/src/daft-parquet/src/python.rs +++ b/src/daft-parquet/src/python.rs @@ -10,7 +10,9 @@ pub mod pylib { use daft_table::python::PyTable; use pyo3::{pyfunction, types::PyModule, Bound, PyResult, Python}; - use crate::read::{ArrowChunk, ParquetSchemaInferenceOptions}; + use crate::read::{ + ArrowChunk, ParquetSchemaInferenceOptions, ParquetSchemaInferenceOptionsBuilder, + }; #[allow(clippy::too_many_arguments)] #[pyfunction] pub fn read_parquet( @@ -90,6 +92,7 @@ pub mod pylib { pub fn read_parquet_into_pyarrow( py: Python, uri: &str, + 
string_encoding: String, columns: Option>, start_offset: Option, num_rows: Option, @@ -99,14 +102,16 @@ pub mod pylib { coerce_int96_timestamp_unit: Option, file_timeout_ms: Option, ) -> PyResult { - let read_parquet_result = py.allow_threads(|| { + let (schema, all_arrays, num_rows) = py.allow_threads(|| { let io_client = get_io_client( multithreaded_io.unwrap_or(true), io_config.unwrap_or_default().config.into(), )?; - let schema_infer_options = ParquetSchemaInferenceOptions::new( - coerce_int96_timestamp_unit.map(|tu| tu.timeunit), - ); + let schema_infer_options = ParquetSchemaInferenceOptionsBuilder { + coerce_int96_timestamp_unit, + string_encoding, + } + .build()?; crate::read::read_parquet_into_pyarrow( uri, @@ -121,7 +126,6 @@ pub mod pylib { file_timeout_ms, ) })?; - let (schema, all_arrays, num_rows) = read_parquet_result; let pyarrow = py.import_bound(pyo3::intern!(py, "pyarrow"))?; convert_pyarrow_parquet_read_result_into_py(py, schema, all_arrays, num_rows, &pyarrow) } diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs index 38974e2b01..eed16ae5b9 100644 --- a/src/daft-parquet/src/read.rs +++ b/src/daft-parquet/src/read.rs @@ -4,9 +4,16 @@ use std::{ time::Duration, }; -use arrow2::{bitmap::Bitmap, io::parquet::read::schema::infer_schema_with_options}; +use arrow2::{ + bitmap::Bitmap, + io::parquet::read::schema::{ + infer_schema_with_options, SchemaInferenceOptions, StringEncoding, + }, +}; use common_error::DaftResult; use daft_core::prelude::*; +#[cfg(feature = "python")] +use daft_core::python::PyTimeUnit; use daft_dsl::{optimization::get_required_columns, ExprRef}; use daft_io::{get_runtime, parse_url, IOClient, IOStatsRef, SourceType}; use daft_table::Table; @@ -22,18 +29,57 @@ use snafu::ResultExt; use crate::{file::ParquetReaderBuilder, JoinSnafu}; -#[derive(Clone, Copy, Serialize, Deserialize)] +#[cfg(feature = "python")] +#[derive(Clone)] +pub struct ParquetSchemaInferenceOptionsBuilder { + pub coerce_int96_timestamp_unit: Option, + pub string_encoding: String, +} + +#[cfg(feature = "python")] +impl ParquetSchemaInferenceOptionsBuilder { + pub fn build(self) -> crate::Result { + self.try_into() + } +} + +#[cfg(feature = "python")] +impl TryFrom for ParquetSchemaInferenceOptions { + type Error = crate::Error; + + fn try_from(value: ParquetSchemaInferenceOptionsBuilder) -> crate::Result { + Ok(ParquetSchemaInferenceOptions { + coerce_int96_timestamp_unit: value + .coerce_int96_timestamp_unit + .map_or(TimeUnit::Nanoseconds, From::from), + string_encoding: value.string_encoding.parse().context(crate::Arrow2Snafu)?, + }) + } +} + +#[cfg(feature = "python")] +impl Default for ParquetSchemaInferenceOptionsBuilder { + fn default() -> Self { + Self { + coerce_int96_timestamp_unit: Some(PyTimeUnit::nanoseconds().unwrap()), + string_encoding: "utf-8".into(), + } + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub struct ParquetSchemaInferenceOptions { pub coerce_int96_timestamp_unit: TimeUnit, + pub string_encoding: StringEncoding, } impl ParquetSchemaInferenceOptions { pub fn new(coerce_int96_timestamp_unit: Option) -> Self { - let default: ParquetSchemaInferenceOptions = Default::default(); let coerce_int96_timestamp_unit = - coerce_int96_timestamp_unit.unwrap_or(default.coerce_int96_timestamp_unit); + coerce_int96_timestamp_unit.unwrap_or(TimeUnit::Nanoseconds); ParquetSchemaInferenceOptions { coerce_int96_timestamp_unit, + ..Default::default() } } } @@ -42,16 +88,16 @@ impl Default for ParquetSchemaInferenceOptions { fn default() -> 
Self { ParquetSchemaInferenceOptions { coerce_int96_timestamp_unit: TimeUnit::Nanoseconds, + string_encoding: StringEncoding::Utf8, } } } -impl From - for arrow2::io::parquet::read::schema::SchemaInferenceOptions -{ +impl From for SchemaInferenceOptions { fn from(value: ParquetSchemaInferenceOptions) -> Self { - arrow2::io::parquet::read::schema::SchemaInferenceOptions { + SchemaInferenceOptions { int96_coerce_to_timeunit: value.coerce_int96_timestamp_unit.to_arrow(), + string_encoding: value.string_encoding, } } } @@ -470,7 +516,7 @@ async fn read_parquet_single_into_arrow( schema_infer_options: ParquetSchemaInferenceOptions, field_id_mapping: Option>>, metadata: Option>, -) -> DaftResult<(arrow2::datatypes::SchemaRef, Vec, usize)> { +) -> DaftResult { let field_id_mapping_provided = field_id_mapping.is_some(); let (source_type, fixed_uri) = parse_url(uri)?; let (metadata, schema, all_arrays, num_rows_read) = if matches!(source_type, SourceType::File) { @@ -889,8 +935,7 @@ pub fn read_parquet_schema( let builder = builder.set_infer_schema_options(schema_inference_options); let metadata = builder.metadata; - let arrow_schema = - infer_schema_with_options(&metadata, &Some(schema_inference_options.into()))?; + let arrow_schema = infer_schema_with_options(&metadata, Some(schema_inference_options.into()))?; let schema = Schema::try_from(&arrow_schema)?; Ok((schema, metadata)) } @@ -1019,20 +1064,24 @@ pub fn read_parquet_statistics( #[cfg(test)] mod tests { - use std::sync::Arc; + use std::{path::PathBuf, sync::Arc}; + use arrow2::{datatypes::DataType, io::parquet::read::schema::StringEncoding}; use common_error::DaftResult; use daft_io::{IOClient, IOConfig}; use futures::StreamExt; - use parquet2::metadata::FileMetaData; + use parquet2::{ + metadata::FileMetaData, + schema::types::{ParquetType, PrimitiveConvertedType, PrimitiveLogicalType}, + }; - use super::{read_parquet, read_parquet_metadata, stream_parquet}; + use super::*; const PARQUET_FILE: &str = "s3://daft-public-data/test_fixtures/parquet-dev/mvp.parquet"; const PARQUET_FILE_LOCAL: &str = "tests/assets/parquet-data/mvp.parquet"; fn get_local_parquet_path() -> String { - let mut d = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); d.push("../../"); // CARGO_MANIFEST_DIR is at src/daft-parquet d.push(PARQUET_FILE_LOCAL); d.to_str().unwrap().to_string() @@ -1116,4 +1165,68 @@ mod tests { Ok(()) })? } + + #[test] + fn test_invalid_utf8_parquet_reading() { + let parquet: Arc = path_macro::path!( + env!("CARGO_MANIFEST_DIR") + / ".." + / ".." + / "tests" + / "assets" + / "parquet-data" + / "invalid_utf8.parquet" + ) + .to_str() + .unwrap() + .into(); + let io_config = IOConfig::default(); + let io_client = Arc::new(IOClient::new(io_config.into()).unwrap()); + let runtime_handle = daft_io::get_runtime(true).unwrap(); + let file_metadata = runtime_handle + .block_on_io_pool({ + let parquet = parquet.clone(); + let io_client = io_client.clone(); + async move { read_parquet_metadata(&parquet, io_client, None, None).await } + }) + .flatten() + .unwrap(); + let primitive_type = match file_metadata.schema_descr.fields() { + [parquet_type] => match parquet_type { + ParquetType::PrimitiveType(primitive_type) => primitive_type, + ParquetType::GroupType { .. 
} => { + panic!("Parquet type should be primitive type, not group type") + } + }, + _ => panic!("This test parquet file should have only 1 field"), + }; + assert_eq!( + primitive_type.logical_type, + Some(PrimitiveLogicalType::String) + ); + assert_eq!( + primitive_type.converted_type, + Some(PrimitiveConvertedType::Utf8) + ); + let (schema, _, _) = read_parquet_into_pyarrow( + &parquet, + None, + None, + None, + None, + io_client, + None, + true, + ParquetSchemaInferenceOptions { + string_encoding: StringEncoding::Raw, + ..Default::default() + }, + None, + ) + .unwrap(); + match schema.fields.as_slice() { + [field] => assert_eq!(field.data_type, DataType::Binary), + _ => panic!("There should only be one field in the schema"), + }; + } } diff --git a/src/daft-parquet/src/stream_reader.rs b/src/daft-parquet/src/stream_reader.rs index dd88834aaa..178141c64d 100644 --- a/src/daft-parquet/src/stream_reader.rs +++ b/src/daft-parquet/src/stream_reader.rs @@ -229,7 +229,7 @@ pub(crate) fn local_parquet_read_into_column_iters( })?, }; - let schema = infer_schema_with_options(&metadata, &Some(schema_infer_options.into())) + let schema = infer_schema_with_options(&metadata, Some(schema_infer_options.into())) .with_context(|_| super::UnableToParseSchemaFromMetadataSnafu { path: uri.to_string(), })?; @@ -325,7 +325,7 @@ pub(crate) fn local_parquet_read_into_arrow( }; // and infer a [`Schema`] from the `metadata`. - let schema = infer_schema_with_options(&metadata, &Some(schema_infer_options.into())) + let schema = infer_schema_with_options(&metadata, Some(schema_infer_options.into())) .with_context(|_| super::UnableToParseSchemaFromMetadataSnafu { path: uri.to_string(), })?; diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 5383235c67..90621e510d 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -166,9 +166,9 @@ impl GlobScanOperator { let schema = match infer_schema { true => { let inferred_schema = match file_format_config.as_ref() { - FileFormatConfig::Parquet(ParquetSourceConfig { + &FileFormatConfig::Parquet(ParquetSourceConfig { coerce_int96_timestamp_unit, - field_id_mapping, + ref field_id_mapping, .. 
}) => { let io_stats = IOStatsContext::new(format!( @@ -180,7 +180,8 @@ impl GlobScanOperator { io_client.clone(), Some(io_stats), ParquetSchemaInferenceOptions { - coerce_int96_timestamp_unit: *coerce_int96_timestamp_unit, + coerce_int96_timestamp_unit, + ..Default::default() }, field_id_mapping.clone(), )?; diff --git a/tests/assets/parquet-data/invalid_utf8.parquet b/tests/assets/parquet-data/invalid_utf8.parquet new file mode 100644 index 0000000000000000000000000000000000000000..56e9a9595e191d6dd8e2151e3222cc68adc75d24 GIT binary patch literal 517 zcmb_a&r1S96n;C(W*{QcEITlVy|{ExQ%r-P!&}jctW=1?ORSr$OY6F=zo3(J>~HCh z?9EDtI`fStDkExtW$(Q;sz_emFd0S6cRWFXKZjvmv)@GV4gGFRD S&T92qT}btW3k_gEAK)7hhHwf1 literal 0 HcmV?d00001 diff --git a/tests/table/table_io/test_parquet.py b/tests/table/table_io/test_parquet.py index e568a043f6..3a17edd4ec 100644 --- a/tests/table/table_io/test_parquet.py +++ b/tests/table/table_io/test_parquet.py @@ -5,6 +5,7 @@ import os import pathlib import tempfile +from pathlib import Path import pyarrow as pa import pyarrow.parquet as papq @@ -13,6 +14,7 @@ import daft from daft.daft import NativeStorageConfig, PythonStorageConfig, StorageConfig from daft.datatype import DataType, TimeUnit +from daft.exceptions import DaftCoreException from daft.logical.schema import Schema from daft.runners.partitioning import TableParseParquetOptions, TableReadOptions from daft.table import ( @@ -397,3 +399,22 @@ def test_read_parquet_file_missing_column_partial_read_with_pyarrow_bulk(tmpdir) read_back = read_parquet_into_pyarrow_bulk([file_path.as_posix()], columns=["x", "MISSING"]) assert len(read_back) == 1 assert tab.drop("y") == read_back[0] # only read "x" + + +@pytest.mark.parametrize( + "parquet_path", [Path(__file__).parents[2] / "assets" / "parquet-data" / "invalid_utf8.parquet"] +) +def test_parquet_read_string_utf8_into_binary(parquet_path: Path): + import pyarrow as pa + + assert parquet_path.exists() + + with pytest.raises(DaftCoreException, match="invalid utf-8 sequence"): + read_parquet_into_pyarrow(path=parquet_path.as_posix()) + + read_back = read_parquet_into_pyarrow(path=parquet_path.as_posix(), string_encoding="raw") + schema = read_back.schema + assert len(schema) == 1 + assert schema[0].name == "invalid_string" + assert schema[0].type == pa.binary() + assert read_back["invalid_string"][0].as_py() == b"\x80\x80\x80" From 05eeffb5487e495451d098309cfa1301149e06a1 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 24 Sep 2024 18:06:58 -0500 Subject: [PATCH 26/35] [FEAT]: [SQL] struct subscript and json_query (#2891) adds nested access for structs and maps, as well as `json_query` function --------- Co-authored-by: Jay Chia <17691182+jaychia@users.noreply.github.com> --- Cargo.lock | 1 + src/daft-sql/Cargo.toml | 3 +- src/daft-sql/src/modules/json.rs | 35 +++++++++-- src/daft-sql/src/modules/map.rs | 30 ++++++++-- src/daft-sql/src/modules/structs.rs | 34 +++++++++-- src/daft-sql/src/planner.rs | 91 +++++++++++++++++++---------- tests/sql/test_nested_access.py | 44 ++++++++++++++ 7 files changed, 194 insertions(+), 44 deletions(-) create mode 100644 tests/sql/test_nested_access.py diff --git a/Cargo.lock b/Cargo.lock index b666624144..592a0793ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2173,6 +2173,7 @@ dependencies = [ "daft-core", "daft-dsl", "daft-functions", + "daft-functions-json", "daft-plan", "once_cell", "pyo3", diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index c15a71f948..86f3baa11c 100644 --- a/src/daft-sql/Cargo.toml +++ 
b/src/daft-sql/Cargo.toml @@ -4,6 +4,7 @@ common-error = {path = "../common/error"} daft-core = {path = "../daft-core"} daft-dsl = {path = "../daft-dsl"} daft-functions = {path = "../daft-functions"} +daft-functions-json = {path = "../daft-functions-json"} daft-plan = {path = "../daft-plan"} once_cell = {workspace = true} pyo3 = {workspace = true, optional = true} @@ -14,7 +15,7 @@ snafu.workspace = true rstest = {workspace = true} [features] -python = ["dep:pyo3", "common-error/python", "daft-functions/python"] +python = ["dep:pyo3", "common-error/python", "daft-functions/python", "daft-functions-json/python"] [package] name = "daft-sql" diff --git a/src/daft-sql/src/modules/json.rs b/src/daft-sql/src/modules/json.rs index 845be622c0..f0d600daea 100644 --- a/src/daft-sql/src/modules/json.rs +++ b/src/daft-sql/src/modules/json.rs @@ -1,11 +1,38 @@ use super::SQLModule; -use crate::functions::SQLFunctions; +use crate::{ + functions::{SQLFunction, SQLFunctions}, + invalid_operation_err, +}; pub struct SQLModuleJson; impl SQLModule for SQLModuleJson { - fn register(_parent: &mut SQLFunctions) { - // use FunctionExpr::Json as f; - // TODO + fn register(parent: &mut SQLFunctions) { + parent.add_fn("json_query", JsonQuery); + } +} + +struct JsonQuery; + +impl SQLFunction for JsonQuery { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input, query] => { + let input = planner.plan_function_arg(input)?; + let query = planner.plan_function_arg(query)?; + if let Some(q) = query.as_literal().and_then(|l| l.as_str()) { + Ok(daft_functions_json::json_query(input, q)) + } else { + invalid_operation_err!("Expected a string literal for the query argument") + } + } + _ => invalid_operation_err!( + "invalid arguments for json_query. 
expected json_query(input, query)" + ), + } } } diff --git a/src/daft-sql/src/modules/map.rs b/src/daft-sql/src/modules/map.rs index 79d9a441f7..d3a328f3a4 100644 --- a/src/daft-sql/src/modules/map.rs +++ b/src/daft-sql/src/modules/map.rs @@ -1,11 +1,33 @@ use super::SQLModule; -use crate::functions::SQLFunctions; +use crate::{ + functions::{SQLFunction, SQLFunctions}, + invalid_operation_err, +}; pub struct SQLModuleMap; impl SQLModule for SQLModuleMap { - fn register(_parent: &mut SQLFunctions) { - // use FunctionExpr::Map as f; - // TODO + fn register(parent: &mut SQLFunctions) { + parent.add_fn("map_get", MapGet); + parent.add_fn("map_extract", MapGet); + } +} + +pub struct MapGet; + +impl SQLFunction for MapGet { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input, key] => { + let input = planner.plan_function_arg(input)?; + let key = planner.plan_function_arg(key)?; + Ok(daft_dsl::functions::map::get(input, key)) + } + _ => invalid_operation_err!("Expected 2 input args"), + } } } diff --git a/src/daft-sql/src/modules/structs.rs b/src/daft-sql/src/modules/structs.rs index 80e2ea7529..66be42d8e3 100644 --- a/src/daft-sql/src/modules/structs.rs +++ b/src/daft-sql/src/modules/structs.rs @@ -1,11 +1,37 @@ use super::SQLModule; -use crate::functions::SQLFunctions; +use crate::{ + functions::{SQLFunction, SQLFunctions}, + invalid_operation_err, +}; pub struct SQLModuleStructs; impl SQLModule for SQLModuleStructs { - fn register(_parent: &mut SQLFunctions) { - // use FunctionExpr::Struct as f; - // TODO + fn register(parent: &mut SQLFunctions) { + parent.add_fn("struct_get", StructGet); + parent.add_fn("struct_extract", StructGet); + } +} + +pub struct StructGet; + +impl SQLFunction for StructGet { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input, key] => { + let input = planner.plan_function_arg(input)?; + let key = planner.plan_function_arg(key)?; + if let Some(lit) = key.as_literal().and_then(|lit| lit.as_str()) { + Ok(daft_dsl::functions::struct_::get(input, lit)) + } else { + invalid_operation_err!("Expected key to be a string literal") + } + } + _ => invalid_operation_err!("Expected 2 input args"), + } } } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index af111b1532..b58f637783 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -551,9 +551,6 @@ impl SQLPlanner { .plan_compound_identifier(idents.as_slice()) .map(|e| e[0].clone()), - SQLExpr::JsonAccess { .. } => { - unsupported_sql_err!("json access") - } SQLExpr::CompositeAccess { .. } => { unsupported_sql_err!("composite access") } @@ -638,7 +635,6 @@ impl SQLPlanner { unsupported_sql_err!("TypedString with data type {:?}", dtype) } }, - SQLExpr::MapAccess { .. } => unsupported_sql_err!("MAP ACCESS"), SQLExpr::Function(func) => self.plan_function(func), SQLExpr::Case { operand, @@ -706,33 +702,7 @@ impl SQLPlanner { SQLExpr::Named { .. 
} => unsupported_sql_err!("NAMED"), SQLExpr::Dictionary(_) => unsupported_sql_err!("DICTIONARY"), SQLExpr::Map(_) => unsupported_sql_err!("MAP"), - SQLExpr::Subscript { expr, subscript } => match subscript.as_ref() { - Subscript::Index { index } => { - let index = self.plan_expr(index)?; - let expr = self.plan_expr(expr)?; - Ok(daft_functions::list::get(expr, index, null_lit())) - } - Subscript::Slice { - lower_bound, - upper_bound, - stride, - } => { - if stride.is_some() { - unsupported_sql_err!("stride"); - } - match (lower_bound, upper_bound) { - (Some(lower), Some(upper)) => { - let lower = self.plan_expr(lower)?; - let upper = self.plan_expr(upper)?; - let expr = self.plan_expr(expr)?; - Ok(daft_functions::list::slice(expr, lower, upper)) - } - _ => { - unsupported_sql_err!("slice with only one bound not yet supported"); - } - } - } - }, + SQLExpr::Subscript { expr, subscript } => self.plan_subscript(expr, subscript.as_ref()), SQLExpr::Array(_) => unsupported_sql_err!("ARRAY"), SQLExpr::Interval(_) => unsupported_sql_err!("INTERVAL"), SQLExpr::MatchAgainst { .. } => unsupported_sql_err!("MATCH AGAINST"), @@ -741,6 +711,9 @@ impl SQLPlanner { SQLExpr::OuterJoin(_) => unsupported_sql_err!("OUTER JOIN"), SQLExpr::Prior(_) => unsupported_sql_err!("PRIOR"), SQLExpr::Lambda(_) => unsupported_sql_err!("LAMBDA"), + SQLExpr::JsonAccess { .. } | SQLExpr::MapAccess { .. } => { + unreachable!("Not reachable in our dialect, should always be parsed as subscript") + } } } @@ -936,6 +909,62 @@ impl SQLPlanner { other => unsupported_sql_err!("unary operator {:?}", other), }) } + fn plan_subscript( + &self, + expr: &sqlparser::ast::Expr, + subscript: &Subscript, + ) -> SQLPlannerResult { + match subscript { + Subscript::Index { index } => { + let expr = self.plan_expr(expr)?; + let index = self.plan_expr(index)?; + let schema = self + .current_relation + .as_ref() + .ok_or_else(|| { + PlannerError::invalid_operation("subscript without a current relation") + }) + .map(|p| p.schema())?; + let expr_field = expr.to_field(schema.as_ref())?; + match expr_field.dtype { + DataType::List(_) | DataType::FixedSizeList(_, _) => { + Ok(daft_functions::list::get(expr, index, null_lit())) + } + DataType::Struct(_) => { + if let Some(s) = index.as_literal().and_then(|l| l.as_str()) { + Ok(daft_dsl::functions::struct_::get(expr, s)) + } else { + invalid_operation_err!("Index must be a string literal") + } + } + DataType::Map(_) => Ok(daft_dsl::functions::map::get(expr, index)), + dtype => { + invalid_operation_err!("nested access on column with type: {}", dtype) + } + } + } + Subscript::Slice { + lower_bound, + upper_bound, + stride, + } => { + if stride.is_some() { + unsupported_sql_err!("stride cannot be provided when slicing an expression"); + } + match (lower_bound, upper_bound) { + (Some(lower), Some(upper)) => { + let lower = self.plan_expr(lower)?; + let upper = self.plan_expr(upper)?; + let expr = self.plan_expr(expr)?; + Ok(daft_functions::list::slice(expr, lower, upper)) + } + _ => { + unsupported_sql_err!("slice with only one bound not yet supported"); + } + } + } + } + } } /// Checks if the SQL query is valid syntax and doesn't use unsupported features. 
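[editor's note, not part of the patch] A minimal sketch of the SQL surface that the new `plan_subscript`, `map_get`/`struct_get`, and `json_query` registrations enable, using the `daft.sql` + `SQLCatalog` API exercised by the test file added in the next hunk. The table name `t` and the column names here are hypothetical; subscripts dispatch on the column dtype (lists -> `list.get`/`list.slice`, structs -> `struct.get` with a string-literal key, maps -> `map.get`):

import daft
from daft.sql.sql import SQLCatalog

df = daft.from_pydict(
    {
        "json": ['{"a": 1}', '{"a": 2}'],
        "point": [{"x": 1, "y": 2}, {"x": 3, "y": 4}],  # inferred as a struct column
        "nums": [[1, 2, 3], [4, 5, 6]],                 # inferred as a list column
    }
)
catalog = SQLCatalog({"t": df})

out = daft.sql(
    "select point['x'] as x, nums[1] as n1, nums[0:2] as first_two, json_query(json, '.a') as a from t",
    catalog,
).collect()
print(out.to_pydict())

The struct key and the `json_query` filter must be string literals, mirroring the checks in `StructGet::to_expr` and `JsonQuery::to_expr` above; non-literal keys raise an "invalid operation" error at planning time.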
diff --git a/tests/sql/test_nested_access.py b/tests/sql/test_nested_access.py new file mode 100644 index 0000000000..c988713d63 --- /dev/null +++ b/tests/sql/test_nested_access.py @@ -0,0 +1,44 @@ +import daft +from daft.sql.sql import SQLCatalog + + +def test_nested_access(): + df = daft.from_pydict( + { + "json": ['{"a": 1, "b": {"c": 2}}', '{"a": 3, "b": {"c": 4}}', '{"a": 5, "b": {"c": 6}}'], + "list": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "dict": [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], + } + ) + + catalog = SQLCatalog({"test": df}) + + actual = daft.sql( + """ + select + json_query(json, '.b.c') as json_b_c, + list[1] as list_1, + list[0:1] as list_slice, + dict['a'] as dict_a, + struct_get(dict, 'a') as dict_a_2, + cast(list as int[3])[1] as fsl_1, + cast(list as int[3])[0:1] as fsl_slice + from test + """, + catalog, + ).collect() + + expected = df.select( + daft.col("json").json.query(".b.c").alias("json_b_c"), + daft.col("list").list.get(1).alias("list_1"), + daft.col("list").list.slice(0, 1).alias("list_slice"), + daft.col("dict").struct.get("a").alias("dict_a"), + daft.col("dict").struct.get("a").alias("dict_a_2"), + daft.col("list").cast(daft.DataType.fixed_size_list(daft.DataType.int32(), 3)).list.get(1).alias("fsl_1"), + daft.col("list") + .cast(daft.DataType.fixed_size_list(daft.DataType.int32(), 3)) + .list.slice(0, 1) + .alias("fsl_slice"), + ).collect() + + assert actual.to_pydict() == expected.to_pydict() From 195dd003e7e7938bec3ad5dcf64dbe6e1bdf7abb Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Tue, 24 Sep 2024 17:23:00 -0700 Subject: [PATCH 27/35] [FEAT] Add ability for RayRunner to run actor pool projects (beta feature) (#2881) --- daft/execution/physical_plan.py | 8 +- daft/runners/ray_runner.py | 158 +++++++++++++++++++++++- daft/udf.py | 1 - tests/actor_pool/test_ray_actor_pool.py | 55 +++++++++ 4 files changed, 217 insertions(+), 5 deletions(-) create mode 100644 tests/actor_pool/test_ray_actor_pool.py diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 220731b6b5..91be85e02a 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -212,7 +212,13 @@ def actor_pool_project( num_actors: int, ) -> InProgressPhysicalPlan[PartitionT]: stage_id = next(stage_id_counter) - actor_pool_name = f"ActorPool_stage{stage_id}" + + from daft.daft import extract_partial_stateful_udf_py + + stateful_udf_names = "-".join( + name for expr in projection for name in extract_partial_stateful_udf_py(expr._expr).keys() + ) + actor_pool_name = f"{stateful_udf_names}-stage={stage_id}" # Keep track of materializations of the children tasks # diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 84458815cb..7c84f7b9dc 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -38,6 +38,7 @@ IOConfig, PyDaftExecutionConfig, ResourceRequest, + extract_partial_stateful_udf_py, ) from daft.datatype import DataType from daft.execution.execution_step import ( @@ -48,7 +49,9 @@ ReduceInstruction, ScanWithTask, SingleOutputPartitionTask, + StatefulUDFProject, ) +from daft.expressions import ExpressionsProjection from daft.filesystem import glob_path_with_stats from daft.runners import runner_io from daft.runners.partitioning import ( @@ -69,7 +72,6 @@ from ray.data.block import Block as RayDatasetBlock from ray.data.dataset import Dataset as RayDataset - from daft.expressions import ExpressionsProjection from daft.logical.builder import 
LogicalPlanBuilder from daft.plan_scheduler import PhysicalPlanScheduler @@ -575,6 +577,8 @@ def __init__(self, max_task_backlog: int | None, use_ray_tqdm: bool) -> None: self.active_by_df: dict[str, bool] = dict() self.results_buffer_size_by_df: dict[str, int | None] = dict() + self._actor_pools: dict[str, RayRoundRobinActorPool] = {} + self.use_ray_tqdm = use_ray_tqdm def next(self, result_uuid: str) -> RayMaterializedResult | StopIteration: @@ -632,6 +636,24 @@ def stop_plan(self, result_uuid: str) -> None: del self.results_by_df[result_uuid] del self.results_buffer_size_by_df[result_uuid] + def get_actor_pool( + self, + name: str, + resource_request: ResourceRequest, + num_actors: int, + projection: ExpressionsProjection, + execution_config: PyDaftExecutionConfig, + ) -> str: + actor_pool = RayRoundRobinActorPool(name, num_actors, resource_request, projection, execution_config) + self._actor_pools[name] = actor_pool + self._actor_pools[name].setup() + return name + + def teardown_actor_pool(self, name: str) -> None: + if name in self._actor_pools: + self._actor_pools[name].teardown() + del self._actor_pools[name] + def _run_plan( self, plan_scheduler: PhysicalPlanScheduler, @@ -744,7 +766,12 @@ def place_in_queue(item): break for task in tasks_to_dispatch: - results = _build_partitions(daft_execution_config, task) + if task.actor_pool_id is None: + results = _build_partitions(daft_execution_config, task) + else: + actor_pool = self._actor_pools.get(task.actor_pool_id) + assert actor_pool is not None, "Ray actor pool must live for as long as the tasks." + results = _build_partitions_on_actor_pool(task, actor_pool) logger.debug("%s -> %s", task, results) inflight_tasks[task.id()] = task for result in results: @@ -874,6 +901,119 @@ def _build_partitions( return partitions +def _build_partitions_on_actor_pool( + task: PartitionTask[ray.ObjectRef], + actor_pool: RayRoundRobinActorPool, +) -> list[ray.ObjectRef]: + """Run a PartitionTask on an actor pool and return the resulting list of partitions.""" + [metadatas_ref, *partitions] = actor_pool.submit(task.instructions, task.partial_metadatas, task.inputs) + metadatas_accessor = PartitionMetadataAccessor(metadatas_ref) + task.set_result( + [ + RayMaterializedResult( + partition=partition, + metadatas=metadatas_accessor, + metadata_idx=i, + ) + for i, partition in enumerate(partitions) + ] + ) + return partitions + + +@ray.remote +class DaftRayActor: + def __init__(self, daft_execution_config: PyDaftExecutionConfig, uninitialized_projection: ExpressionsProjection): + self.daft_execution_config = daft_execution_config + partial_stateful_udfs = { + name: psu + for expr in uninitialized_projection + for name, psu in extract_partial_stateful_udf_py(expr._expr).items() + } + logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys())) + self.initialized_stateful_udfs = { + name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items() + } + + @ray.method(num_returns=2) + def run( + self, + uninitialized_projection: ExpressionsProjection, + partial_metadatas: list[PartitionMetadata], + *inputs: MicroPartition, + ) -> list[list[PartitionMetadata] | MicroPartition]: + with execution_config_ctx(config=self.daft_execution_config): + assert len(inputs) == 1, "DaftRayActor can only process single partitions" + assert len(partial_metadatas) == 1, "DaftRayActor can only process single partitions (and single metadata)" + part = inputs[0] + partial = partial_metadatas[0] + + # Bind the ExpressionsProjection to 
the initialized UDFs + initialized_projection = ExpressionsProjection( + [e._bind_stateful_udfs(self.initialized_stateful_udfs) for e in uninitialized_projection] + ) + new_part = part.eval_expression_list(initialized_projection) + + return [ + [PartitionMetadata.from_table(new_part).merge_with_partial(partial)], + new_part, + ] + + +class RayRoundRobinActorPool: + """Naive implementation of an ActorPool that performs round-robin task submission to the actors""" + + def __init__( + self, + pool_id: str, + num_actors: int, + resource_request: ResourceRequest, + projection: ExpressionsProjection, + execution_config: PyDaftExecutionConfig, + ): + self._actors: list[DaftRayActor] | None = None + self._task_idx = 0 + + self._execution_config = execution_config + self._num_actors = num_actors + self._resource_request_per_actor = resource_request + self._id = pool_id + self._projection = projection + + def setup(self) -> None: + self._actors = [ + DaftRayActor.options(name=f"rank={rank}-{self._id}").remote(self._execution_config, self._projection) # type: ignore + for rank in range(self._num_actors) + ] + + def teardown(self): + assert self._actors is not None, "Must have active Ray actors on teardown" + + # Delete the actors in the old pool so Ray can tear them down + old_actors = self._actors + self._actors = None + del old_actors + + def submit( + self, instruction_stack: list[Instruction], partial_metadatas: list[ray.ObjectRef], inputs: list[ray.ObjectRef] + ) -> list[ray.ObjectRef]: + assert self._actors is not None, "Must have active Ray actors during submission" + + assert ( + len(instruction_stack) == 1 + ), "RayRoundRobinActorPool can only handle single StatefulUDFProject instructions" + instruction = instruction_stack[0] + assert isinstance(instruction, StatefulUDFProject) + projection = instruction.projection + + # Determine which actor to schedule on in a round-robin fashion + idx = self._task_idx % self._num_actors + self._task_idx += 1 + actor = self._actors[idx] + + return actor.run.remote(projection, partial_metadatas, *inputs) + + class RayRunner(Runner[ray.ObjectRef]): def __init__( self, @@ -1012,7 +1152,19 @@ def run_iter_tables( def actor_pool_context( self, name: str, resource_request: ResourceRequest, num_actors: PartID, projection: ExpressionsProjection ) -> Iterator[str]: - raise NotImplementedError("actor_pool_context not yet implemented in RayRunner") + execution_config = get_context().daft_execution_config + if self.ray_client_mode: + try: + yield ray.get( + self.scheduler_actor.get_actor_pool.remote(name, resource_request, num_actors, projection) + ) + finally: + self.scheduler_actor.teardown_actor_pool.remote(name) + else: + try: + yield self.scheduler.get_actor_pool(name, resource_request, num_actors, projection, execution_config) + finally: + self.scheduler.teardown_actor_pool(name) def _collect_into_cache(self, results_iter: Iterator[RayMaterializedResult]) -> PartitionCacheEntry: result_pset = RayPartitionSet() diff --git a/daft/udf.py b/daft/udf.py index d560ddd4f5..e2afb495ff 100644 --- a/daft/udf.py +++ b/daft/udf.py @@ -430,7 +430,6 @@ def udf( num_gpus: float | None = None, memory_bytes: int | None = None, batch_size: int | None = None, - _concurrency: int | None = None, ) -> Callable[[UserProvidedPythonFunction | type], StatelessUDF | StatefulUDF]: """Decorator to convert a Python function into a UDF diff --git a/tests/actor_pool/test_ray_actor_pool.py b/tests/actor_pool/test_ray_actor_pool.py new file mode 100644 index 0000000000..239a5fd48b --- /dev/null 
+++ b/tests/actor_pool/test_ray_actor_pool.py @@ -0,0 +1,55 @@ +import ray + +import daft +from daft import DataType, ResourceRequest +from daft.daft import PyDaftExecutionConfig +from daft.execution.execution_step import StatefulUDFProject +from daft.expressions import ExpressionsProjection +from daft.runners.partitioning import PartialPartitionMetadata +from daft.runners.ray_runner import RayRoundRobinActorPool +from daft.table import MicroPartition + + +@daft.udf(return_dtype=DataType.int64()) +class MyStatefulUDF: + def __init__(self): + self.state = 0 + + def __call__(self, x): + self.state += 1 + return [i + self.state for i in x.to_pylist()] + + +def test_ray_actor_pool(): + projection = ExpressionsProjection([MyStatefulUDF(daft.col("x"))]) + pool = RayRoundRobinActorPool( + "my-pool", 1, ResourceRequest(num_cpus=1), projection, execution_config=PyDaftExecutionConfig.from_env() + ) + initial_partition = ray.put(MicroPartition.from_pydict({"x": [1, 1, 1]})) + ppm = PartialPartitionMetadata(num_rows=None, size_bytes=None) + instr = StatefulUDFProject(projection=projection) + pool.setup() + + result = pool.submit(instruction_stack=[instr], partial_metadatas=[ppm], inputs=[initial_partition]) + [partial_metadata, result_data] = ray.get(result) + assert len(partial_metadata) == 1 + pm = partial_metadata[0] + assert isinstance(pm, PartialPartitionMetadata) + assert pm.num_rows == 3 + assert result_data.to_pydict() == {"x": [2, 2, 2]} + + result = pool.submit(instruction_stack=[instr], partial_metadatas=[ppm], inputs=[initial_partition]) + [partial_metadata, result_data] = ray.get(result) + assert len(partial_metadata) == 1 + pm = partial_metadata[0] + assert isinstance(pm, PartialPartitionMetadata) + assert pm.num_rows == 3 + assert result_data.to_pydict() == {"x": [3, 3, 3]} + + result = pool.submit(instruction_stack=[instr], partial_metadatas=[ppm], inputs=[initial_partition]) + [partial_metadata, result_data] = ray.get(result) + assert len(partial_metadata) == 1 + pm = partial_metadata[0] + assert isinstance(pm, PartialPartitionMetadata) + assert pm.num_rows == 3 + assert result_data.to_pydict() == {"x": [4, 4, 4]} From d57433ad5a1d27f3efb90133ccc4fd8787f66bf0 Mon Sep 17 00:00:00 2001 From: Vignesh Date: Wed, 25 Sep 2024 09:38:41 +0530 Subject: [PATCH 28/35] [FEAT] `agg_concat` doesn't work on strings (#2847) Solves #2768 --------- Co-authored-by: Colin Ho Co-authored-by: Colin Ho --- src/daft-core/src/array/ops/concat_agg.rs | 77 +++++++++++++++++++++-- src/daft-core/src/series/ops/agg.rs | 11 +++- src/daft-dsl/src/expr.rs | 1 + tests/table/test_table_aggs.py | 50 +++++++++++++++ 4 files changed, 134 insertions(+), 5 deletions(-) diff --git a/src/daft-core/src/array/ops/concat_agg.rs b/src/daft-core/src/array/ops/concat_agg.rs index 6713597897..09ebb0876e 100644 --- a/src/daft-core/src/array/ops/concat_agg.rs +++ b/src/daft-core/src/array/ops/concat_agg.rs @@ -1,10 +1,18 @@ -use arrow2::{bitmap::utils::SlicesIterator, offset::OffsetsBuffer, types::Index}; +use arrow2::{ + array::{Array, Utf8Array}, + bitmap::utils::SlicesIterator, + offset::OffsetsBuffer, + types::Index, +}; use common_error::DaftResult; use super::{as_arrow::AsArrow, DaftConcatAggable}; -use crate::array::{ - growable::{make_growable, Growable}, - ListArray, +use crate::{ + array::{ + growable::{make_growable, Growable}, + DataArray, ListArray, + }, + prelude::Utf8Type, }; #[cfg(feature = "python")] @@ -146,6 +154,67 @@ impl DaftConcatAggable for ListArray { } } +impl DaftConcatAggable for DataArray { + type Output = 
DaftResult; + + fn concat(&self) -> Self::Output { + let new_validity = match self.validity() { + Some(validity) if validity.unset_bits() == self.len() => { + Some(arrow2::bitmap::Bitmap::from(vec![false])) + } + _ => None, + }; + + let arrow_array = self.as_arrow(); + let new_offsets = OffsetsBuffer::::try_from(vec![0, *arrow_array.offsets().last()])?; + let output = Utf8Array::new( + arrow_array.data_type().clone(), + new_offsets, + arrow_array.values().clone(), + new_validity, + ); + + let result_box = Box::new(output); + DataArray::new(self.field().clone().into(), result_box) + } + + fn grouped_concat(&self, groups: &super::GroupIndices) -> Self::Output { + let arrow_array = self.as_arrow(); + let concat_per_group = if arrow_array.null_count() > 0 { + Box::new(Utf8Array::from_trusted_len_iter(groups.iter().map(|g| { + let to_concat = g + .iter() + .filter_map(|index| { + let idx = *index as usize; + arrow_array.get(idx) + }) + .collect::>(); + if to_concat.is_empty() { + None + } else { + Some(to_concat.concat()) + } + }))) + } else { + Box::new(Utf8Array::from_trusted_len_values_iter(groups.iter().map( + |g| { + g.iter() + .map(|index| { + let idx = *index as usize; + arrow_array.value(idx) + }) + .collect::() + }, + ))) + }; + + Ok(DataArray::from(( + self.field.name.as_ref(), + concat_per_group, + ))) + } +} + #[cfg(test)] mod test { use std::iter::repeat; diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index 44a4c10348..353c6ca25d 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -244,8 +244,17 @@ impl Series { None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()), } } + DataType::Utf8 => { + let downcasted = self.downcast::()?; + match groups { + Some(groups) => { + Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series()) + } + None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()), + } + } _ => Err(DaftError::TypeError(format!( - "concat aggregation is only valid for List or Python types, got {}", + "concat aggregation is only valid for List, Python types, or Utf8, got {}", self.data_type() ))), } diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index affb5f08e3..f8c5deb247 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -390,6 +390,7 @@ impl AggExpr { let field = expr.to_field(schema)?; match field.dtype { DataType::List(..) 
=> Ok(field), + DataType::Utf8 => Ok(field), #[cfg(feature = "python")] DataType::Python => Ok(field), _ => Err(DaftError::TypeError(format!( diff --git a/tests/table/test_table_aggs.py b/tests/table/test_table_aggs.py index fa7a26b3e4..01749a1cdb 100644 --- a/tests/table/test_table_aggs.py +++ b/tests/table/test_table_aggs.py @@ -874,3 +874,53 @@ def test_groupby_struct(dtype) -> None: expected = [[0, 1, 4], [2, 6], [3, 5]] for lt in expected: assert lt in res["b"] + + +def test_agg_concat_on_string() -> None: + df3 = from_pydict({"a": ["the", " quick", " brown", " fox"]}) + res = df3.agg(col("a").agg_concat()).to_pydict() + assert res["a"] == ["the quick brown fox"] + + +def test_agg_concat_on_string_groupby() -> None: + df3 = from_pydict({"a": ["the", " quick", " brown", " fox"], "b": [1, 2, 1, 2]}) + res = df3.groupby("b").agg_concat("a").to_pydict() + expected = ["the brown", " quick fox"] + for txt in expected: + assert txt in res["a"] + + +def test_agg_concat_on_string_null() -> None: + df3 = from_pydict({"a": ["the", " quick", None, " fox"]}) + res = df3.agg(col("a").agg_concat()).to_pydict() + expected = ["the quick fox"] + assert res["a"] == expected + + +def test_agg_concat_on_string_groupby_null() -> None: + df3 = from_pydict({"a": ["the", " quick", None, " fox"], "b": [1, 2, 1, 2]}) + res = df3.groupby("b").agg_concat("a").to_pydict() + expected = ["the", " quick fox"] + for txt in expected: + assert txt in res["a"] + + +def test_agg_concat_on_string_null_list() -> None: + df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column( + "a", col("a").cast(DataType.string()) + ) + res = df3.agg(col("a").agg_concat()).to_pydict() + print(res) + expected = [None] + assert res["a"] == expected + assert len(res["a"]) == 1 + + +def test_agg_concat_on_string_groupby_null_list() -> None: + df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column( + "a", col("a").cast(DataType.string()) + ) + res = df3.groupby("b").agg_concat("a").to_pydict() + expected = [None, None] + assert res["a"] == expected + assert len(res["a"]) == len(expected) From 45e2944e252ccdd563dc20edd9b29762e05cec1d Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 25 Sep 2024 11:55:02 -0700 Subject: [PATCH 29/35] [CHORE] auto-fix prefer `Self` over explicit type (#2908) --- Cargo.toml | 3 + src/common/arrow-ffi/Cargo.toml | 3 + src/common/daft-config/Cargo.toml | 3 + src/common/daft-config/src/lib.rs | 2 +- src/common/daft-config/src/python.rs | 19 +- src/common/display/Cargo.toml | 3 + src/common/error/Cargo.toml | 3 + src/common/error/src/python.rs | 2 +- src/common/file-formats/Cargo.toml | 3 + src/common/hashable-float-wrapper/Cargo.toml | 3 + src/common/io-config/Cargo.toml | 3 + src/common/io-config/src/http.rs | 4 +- src/common/io-config/src/lib.rs | 4 +- src/common/io-config/src/python.rs | 22 +- src/common/io-config/src/s3.rs | 2 +- src/common/py-serde/Cargo.toml | 3 + src/common/resource-request/Cargo.toml | 3 + src/common/resource-request/src/lib.rs | 14 +- src/common/system-info/Cargo.toml | 3 + src/common/system-info/src/lib.rs | 2 +- src/common/tracing/Cargo.toml | 3 + src/common/treenode/Cargo.toml | 3 + src/common/treenode/src/lib.rs | 46 +-- src/common/version/Cargo.toml | 3 + src/daft-compression/Cargo.toml | 3 + src/daft-core/Cargo.toml | 3 + .../src/array/fixed_size_list_array.rs | 4 +- src/daft-core/src/array/from.rs | 38 +- src/daft-core/src/array/from_iter.rs | 18 +- src/daft-core/src/array/image_array.rs | 2 +- 
src/daft-core/src/array/list_array.rs | 4 +- src/daft-core/src/array/mod.rs | 10 +- src/daft-core/src/array/ops/apply.rs | 18 +- .../src/array/ops/approx_count_distinct.rs | 2 +- src/daft-core/src/array/ops/arange.rs | 2 +- src/daft-core/src/array/ops/between.rs | 4 +- src/daft-core/src/array/ops/bitwise.rs | 8 +- src/daft-core/src/array/ops/broadcast.rs | 14 +- src/daft-core/src/array/ops/cast.rs | 6 +- src/daft-core/src/array/ops/compare_agg.rs | 42 +-- src/daft-core/src/array/ops/comparison.rs | 230 ++++------- src/daft-core/src/array/ops/concat.rs | 8 +- src/daft-core/src/array/ops/concat_agg.rs | 13 +- src/daft-core/src/array/ops/filter.rs | 4 +- src/daft-core/src/array/ops/from_arrow.rs | 12 +- src/daft-core/src/array/ops/full.rs | 8 +- src/daft-core/src/array/ops/if_else.rs | 10 +- src/daft-core/src/array/ops/is_in.rs | 8 +- src/daft-core/src/array/ops/list_agg.rs | 8 +- src/daft-core/src/array/ops/sort.rs | 8 +- src/daft-core/src/array/ops/take.rs | 19 +- src/daft-core/src/array/ops/trigonometry.rs | 24 +- src/daft-core/src/array/ops/truncate.rs | 10 +- src/daft-core/src/array/ops/utf8.rs | 131 +++---- .../src/array/pseudo_arrow/compute.rs | 2 +- src/daft-core/src/array/pseudo_arrow/mod.rs | 2 +- .../src/array/pseudo_arrow/python.rs | 13 +- src/daft-core/src/array/serdes.rs | 2 +- src/daft-core/src/array/struct_array.rs | 4 +- src/daft-core/src/count_mode.rs | 10 +- src/daft-core/src/datatypes/logical.rs | 2 +- src/daft-core/src/join.rs | 26 +- src/daft-core/src/python/series.rs | 4 +- .../src/series/array_impl/data_array.rs | 2 +- .../src/series/array_impl/logical_array.rs | 2 +- src/daft-core/src/series/mod.rs | 4 +- src/daft-core/src/series/ops/abs.rs | 2 +- src/daft-core/src/series/ops/agg.rs | 28 +- src/daft-core/src/series/ops/between.rs | 8 +- src/daft-core/src/series/ops/broadcast.rs | 2 +- src/daft-core/src/series/ops/cast.rs | 2 +- src/daft-core/src/series/ops/cbrt.rs | 2 +- src/daft-core/src/series/ops/ceil.rs | 2 +- src/daft-core/src/series/ops/comparison.rs | 6 +- src/daft-core/src/series/ops/concat.rs | 2 +- src/daft-core/src/series/ops/exp.rs | 2 +- src/daft-core/src/series/ops/filter.rs | 2 +- src/daft-core/src/series/ops/float.rs | 6 +- src/daft-core/src/series/ops/floor.rs | 2 +- src/daft-core/src/series/ops/if_else.rs | 2 +- src/daft-core/src/series/ops/is_in.rs | 2 +- src/daft-core/src/series/ops/list.rs | 18 +- src/daft-core/src/series/ops/log.rs | 8 +- src/daft-core/src/series/ops/map.rs | 2 +- src/daft-core/src/series/ops/minhash.rs | 2 +- src/daft-core/src/series/ops/not.rs | 2 +- src/daft-core/src/series/ops/null.rs | 6 +- src/daft-core/src/series/ops/repeat.rs | 2 +- src/daft-core/src/series/ops/round.rs | 2 +- src/daft-core/src/series/ops/shift.rs | 4 +- src/daft-core/src/series/ops/sign.rs | 2 +- .../src/series/ops/sketch_percentile.rs | 2 +- src/daft-core/src/series/ops/sort.rs | 4 +- src/daft-core/src/series/ops/sqrt.rs | 2 +- src/daft-core/src/series/ops/struct_.rs | 2 +- src/daft-core/src/series/ops/take.rs | 6 +- src/daft-core/src/series/ops/utf8.rs | 69 ++-- src/daft-csv/Cargo.toml | 3 + src/daft-csv/src/lib.rs | 6 +- src/daft-decoding/Cargo.toml | 3 + src/daft-dsl/Cargo.toml | 3 + src/daft-dsl/src/expr.rs | 56 +-- src/daft-dsl/src/functions/python/mod.rs | 4 +- src/daft-dsl/src/lit.rs | 18 +- src/daft-dsl/src/python.rs | 4 +- src/daft-functions-json/Cargo.toml | 3 + src/daft-functions/Cargo.toml | 3 + src/daft-functions/src/distance/cosine.rs | 6 +- src/daft-functions/src/lib.rs | 8 +- src/daft-functions/src/tokenize/bpe.rs | 20 +- 
src/daft-functions/src/uri/download.rs | 2 +- src/daft-functions/src/uri/upload.rs | 2 +- src/daft-image/Cargo.toml | 3 + src/daft-image/src/image_buffer.rs | 7 +- src/daft-io/Cargo.toml | 3 + src/daft-io/src/azure_blob.rs | 14 +- src/daft-io/src/google_cloud.rs | 24 +- src/daft-io/src/http.rs | 10 +- src/daft-io/src/huggingface.rs | 14 +- src/daft-io/src/lib.rs | 44 +-- src/daft-io/src/local.rs | 16 +- src/daft-io/src/object_io.rs | 8 +- src/daft-io/src/object_store_glob.rs | 12 +- src/daft-io/src/s3_like.rs | 56 +-- src/daft-io/src/stats.rs | 2 +- src/daft-json/Cargo.toml | 3 + src/daft-json/src/lib.rs | 6 +- src/daft-local-execution/Cargo.toml | 3 + src/daft-local-execution/src/channel.rs | 4 +- .../anti_semi_hash_join_probe.rs | 6 +- .../src/intermediate_ops/hash_join_probe.rs | 6 +- src/daft-local-execution/src/lib.rs | 4 +- src/daft-local-execution/src/pipeline.rs | 10 +- .../src/sinks/blocking_sink.rs | 2 +- .../src/sinks/streaming_sink.rs | 2 +- src/daft-micropartition/Cargo.toml | 3 + src/daft-micropartition/src/lib.rs | 2 +- src/daft-micropartition/src/micropartition.rs | 8 +- src/daft-micropartition/src/ops/agg.rs | 4 +- .../src/ops/cast_to_schema.rs | 4 +- src/daft-micropartition/src/ops/concat.rs | 2 +- .../src/ops/eval_expressions.rs | 4 +- src/daft-micropartition/src/ops/join.rs | 2 +- src/daft-micropartition/src/ops/partition.rs | 18 +- src/daft-micropartition/src/ops/pivot.rs | 2 +- src/daft-micropartition/src/ops/slice.rs | 2 +- src/daft-micropartition/src/python.rs | 2 +- src/daft-minhash/Cargo.toml | 3 + src/daft-parquet/Cargo.toml | 3 + src/daft-parquet/src/file.rs | 6 +- src/daft-parquet/src/lib.rs | 8 +- src/daft-parquet/src/metadata.rs | 24 +- src/daft-parquet/src/read.rs | 8 +- src/daft-parquet/src/read_planner.rs | 2 +- src/daft-parquet/src/statistics/mod.rs | 6 +- src/daft-physical-plan/Cargo.toml | 3 + src/daft-physical-plan/src/local_plan.rs | 40 +- src/daft-plan/Cargo.toml | 3 + src/daft-plan/src/builder.rs | 8 +- src/daft-plan/src/display.rs | 62 +-- .../src/logical_ops/actor_pool_project.rs | 2 +- .../src/logical_optimization/optimizer.rs | 4 +- src/daft-plan/src/logical_plan.rs | 8 +- src/daft-plan/src/partitioning.rs | 2 +- .../src/physical_ops/actor_pool_project.rs | 2 +- .../src/physical_optimization/optimizer.rs | 8 +- .../src/physical_optimization/rules/rule.rs | 2 +- src/daft-plan/src/physical_plan.rs | 8 +- src/daft-plan/src/physical_planner/planner.rs | 10 +- src/daft-plan/src/source_info/mod.rs | 2 +- src/daft-scan/Cargo.toml | 3 + src/daft-scan/src/glob.rs | 2 +- src/daft-scan/src/lib.rs | 16 +- src/daft-scan/src/python.rs | 14 +- src/daft-scan/src/storage_config.rs | 4 +- src/daft-scheduler/Cargo.toml | 3 + src/daft-scheduler/src/adaptive.rs | 7 +- src/daft-schema/Cargo.toml | 3 + src/daft-schema/src/dtype.rs | 357 +++++++++--------- src/daft-schema/src/image_format.rs | 4 +- src/daft-schema/src/image_mode.rs | 8 +- src/daft-schema/src/python/datatype.rs | 10 +- src/daft-schema/src/python/field.rs | 4 +- src/daft-schema/src/python/schema.rs | 18 +- src/daft-schema/src/schema.rs | 18 +- src/daft-schema/src/time_unit.rs | 24 +- src/daft-sketch/Cargo.toml | 3 + src/daft-sketch/src/arrow2_serde.rs | 2 +- src/daft-sql/Cargo.toml | 3 + src/daft-sql/src/catalog.rs | 4 +- src/daft-sql/src/error.rs | 16 +- src/daft-sql/src/modules/aggs.rs | 2 +- src/daft-sql/src/modules/partitioning.rs | 20 +- src/daft-sql/src/planner.rs | 6 +- src/daft-sql/src/python.rs | 4 +- src/daft-stats/Cargo.toml | 3 + src/daft-stats/src/column_stats/comparison.rs | 82 ++-- 
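The mechanical change in this patch is stylistic: with `use-self = "deny"` added to the workspace Clippy lints (see the Cargo.toml hunk below), impl blocks that spell out the concrete type name where `Self` would do now fail `cargo clippy`. A minimal sketch of the rewrite pattern, using an illustrative `Config` type that is not from the Daft codebase:

```rust
#[derive(Clone, Default)]
struct Config {
    retries: u32,
}

impl Config {
    // Denied by clippy::use_self: repeats the type name inside its own impl.
    // fn with_retries(&self, retries: u32) -> Config {
    //     Config { retries, ..self.clone() }
    // }

    // Preferred form after the auto-fix: `Self` refers to the type of the surrounding impl.
    fn with_retries(&self, retries: u32) -> Self {
        Self {
            retries,
            ..self.clone()
        }
    }
}

fn main() {
    let cfg = Config::default().with_retries(3);
    assert_eq!(cfg.retries, 3);
}
```

The same pattern appears throughout the hunks that follow, for example in the `ResourceRequest::with_num_cpus` and `PyDaftExecutionConfig::with_config_values` changes, where return types and struct literals switch from the named type to `Self`.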
src/daft-stats/src/column_stats/mod.rs | 22 +- src/daft-stats/src/lib.rs | 2 +- src/daft-stats/src/table_stats.rs | 12 +- src/daft-table/Cargo.toml | 3 + src/daft-table/src/lib.rs | 32 +- src/daft-table/src/ops/agg.rs | 14 +- src/daft-table/src/ops/joins/mod.rs | 2 +- src/daft-table/src/ops/pivot.rs | 2 +- src/daft-table/src/ops/sort.rs | 2 +- src/daft-table/src/ops/unpivot.rs | 2 +- src/daft-table/src/python.rs | 17 +- src/hyperloglog/Cargo.toml | 3 + 209 files changed, 1205 insertions(+), 1331 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a0f4cc19dc..6af21986e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -224,6 +224,9 @@ version = "0.11.0" features = ["derive", "rc"] version = "1.0.200" +[workspace.lints.clippy] +use-self = "deny" + [workspace.package] edition = "2021" version = "0.3.0-dev0" diff --git a/src/common/arrow-ffi/Cargo.toml b/src/common/arrow-ffi/Cargo.toml index b45af25939..88575a3e1e 100644 --- a/src/common/arrow-ffi/Cargo.toml +++ b/src/common/arrow-ffi/Cargo.toml @@ -5,6 +5,9 @@ pyo3 = {workspace = true, optional = true} [features] python = ["dep:pyo3"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-arrow-ffi" diff --git a/src/common/daft-config/Cargo.toml b/src/common/daft-config/Cargo.toml index 6a208fdcae..212ce37ab2 100644 --- a/src/common/daft-config/Cargo.toml +++ b/src/common/daft-config/Cargo.toml @@ -7,6 +7,9 @@ serde = {workspace = true} [features] python = ["dep:pyo3", "common-io-config/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-daft-config" diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index dcaef0a2f8..077d4b7e83 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -60,7 +60,7 @@ pub struct DaftExecutionConfig { impl Default for DaftExecutionConfig { fn default() -> Self { - DaftExecutionConfig { + Self { scan_tasks_min_size_bytes: 96 * 1024 * 1024, // 96MB scan_tasks_max_size_bytes: 384 * 1024 * 1024, // 384MB broadcast_join_size_bytes_threshold: 10 * 1024 * 1024, // 10 MiB diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 5dda71eda8..44bb95c1b0 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -17,27 +17,24 @@ pub struct PyDaftPlanningConfig { impl PyDaftPlanningConfig { #[new] pub fn new() -> Self { - PyDaftPlanningConfig::default() + Self::default() } #[staticmethod] pub fn from_env() -> Self { - PyDaftPlanningConfig { + Self { config: Arc::new(DaftPlanningConfig::from_env()), } } - fn with_config_values( - &mut self, - default_io_config: Option, - ) -> PyResult { + fn with_config_values(&mut self, default_io_config: Option) -> PyResult { let mut config = self.config.as_ref().clone(); if let Some(default_io_config) = default_io_config { config.default_io_config = default_io_config.config; } - Ok(PyDaftPlanningConfig { + Ok(Self { config: Arc::new(config), }) } @@ -67,12 +64,12 @@ pub struct PyDaftExecutionConfig { impl PyDaftExecutionConfig { #[new] pub fn new() -> Self { - PyDaftExecutionConfig::default() + Self::default() } #[staticmethod] pub fn from_env() -> Self { - PyDaftExecutionConfig { + Self { config: Arc::new(DaftExecutionConfig::from_env()), } } @@ -98,7 +95,7 @@ impl PyDaftExecutionConfig { enable_aqe: Option, enable_native_executor: Option, default_morsel_size: Option, - ) -> PyResult { + ) -> PyResult { let mut config = self.config.as_ref().clone(); if let 
Some(scan_tasks_max_size_bytes) = scan_tasks_max_size_bytes { @@ -161,7 +158,7 @@ impl PyDaftExecutionConfig { config.default_morsel_size = default_morsel_size; } - Ok(PyDaftExecutionConfig { + Ok(Self { config: Arc::new(config), }) } diff --git a/src/common/display/Cargo.toml b/src/common/display/Cargo.toml index 3fe4ea2774..55fd06966c 100644 --- a/src/common/display/Cargo.toml +++ b/src/common/display/Cargo.toml @@ -8,6 +8,9 @@ textwrap = {version = "0.16.1"} [features] python = ["dep:pyo3"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-display" diff --git a/src/common/error/Cargo.toml b/src/common/error/Cargo.toml index 8da47b9fee..b64ef5c901 100644 --- a/src/common/error/Cargo.toml +++ b/src/common/error/Cargo.toml @@ -8,6 +8,9 @@ thiserror = {workspace = true} [features] python = ["dep:pyo3"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-error" diff --git a/src/common/error/src/python.rs b/src/common/error/src/python.rs index b6b4e48523..917dafdc78 100644 --- a/src/common/error/src/python.rs +++ b/src/common/error/src/python.rs @@ -10,7 +10,7 @@ import_exception!(daft.exceptions, ByteStreamError); import_exception!(daft.exceptions, SocketError); impl std::convert::From for pyo3::PyErr { - fn from(err: DaftError) -> pyo3::PyErr { + fn from(err: DaftError) -> Self { match err { DaftError::PyO3Error(pyerr) => pyerr, DaftError::FileNotFound { path, source } => { diff --git a/src/common/file-formats/Cargo.toml b/src/common/file-formats/Cargo.toml index 9dc2121eb2..2b3495b6d1 100644 --- a/src/common/file-formats/Cargo.toml +++ b/src/common/file-formats/Cargo.toml @@ -9,6 +9,9 @@ serde_json = {workspace = true, optional = true} [features] python = ["dep:pyo3", "dep:serde_json", "common-error/python", "common-py-serde/python", "daft-schema/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-file-formats" diff --git a/src/common/hashable-float-wrapper/Cargo.toml b/src/common/hashable-float-wrapper/Cargo.toml index ce370ed965..541535f260 100644 --- a/src/common/hashable-float-wrapper/Cargo.toml +++ b/src/common/hashable-float-wrapper/Cargo.toml @@ -1,6 +1,9 @@ [dependencies] serde = {workspace = true} +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-hashable-float-wrapper" diff --git a/src/common/io-config/Cargo.toml b/src/common/io-config/Cargo.toml index b1273b4499..a66e9bfa23 100644 --- a/src/common/io-config/Cargo.toml +++ b/src/common/io-config/Cargo.toml @@ -12,6 +12,9 @@ typetag = "0.2.16" [features] python = ["dep:pyo3", "common-error/python", "common-py-serde/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-io-config" diff --git a/src/common/io-config/src/http.rs b/src/common/io-config/src/http.rs index 275c55f106..6241de3028 100644 --- a/src/common/io-config/src/http.rs +++ b/src/common/io-config/src/http.rs @@ -12,7 +12,7 @@ pub struct HTTPConfig { impl Default for HTTPConfig { fn default() -> Self { - HTTPConfig { + Self { user_agent: "daft/0.0.1".to_string(), // NOTE: Ideally we grab the version of Daft, but that requires a dependency on daft-core bearer_token: None, } @@ -21,7 +21,7 @@ impl Default for HTTPConfig { impl HTTPConfig { pub fn new>(bearer_token: Option) -> Self { - HTTPConfig { + Self { bearer_token: bearer_token.map(|t| t.into()), ..Default::default() } diff --git a/src/common/io-config/src/lib.rs b/src/common/io-config/src/lib.rs index b50b4e4185..46e4de278d 100644 --- 
a/src/common/io-config/src/lib.rs +++ b/src/common/io-config/src/lib.rs @@ -73,12 +73,12 @@ impl<'de> Deserialize<'de> for ObfuscatedString { D: Deserializer<'de>, { let s = String::deserialize(deserializer)?; - Ok(ObfuscatedString(s.into())) + Ok(Self(s.into())) } } impl From for ObfuscatedString { fn from(value: String) -> Self { - ObfuscatedString(value.into()) + Self(value.into()) } } diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index ef1276bbc6..6ae67a2443 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -154,7 +154,7 @@ impl IOConfig { gcs: Option, http: Option, ) -> Self { - IOConfig { + Self { config: config::IOConfig { s3: s3.unwrap_or_default().config, azure: azure.unwrap_or_default().config, @@ -171,7 +171,7 @@ impl IOConfig { gcs: Option, http: Option, ) -> Self { - IOConfig { + Self { config: config::IOConfig { s3: s3.map(|s3| s3.config).unwrap_or(self.config.s3.clone()), azure: azure @@ -274,7 +274,7 @@ impl S3Config { profile_name: Option, ) -> PyResult { let def = crate::S3Config::default(); - Ok(S3Config { + Ok(Self { config: crate::S3Config { region_name: region_name.or(def.region_name), endpoint_url: endpoint_url.or(def.endpoint_url), @@ -333,7 +333,7 @@ impl S3Config { force_virtual_addressing: Option, profile_name: Option, ) -> PyResult { - Ok(S3Config { + Ok(Self { config: crate::S3Config { region_name: region_name.or_else(|| self.config.region_name.clone()), endpoint_url: endpoint_url.or_else(|| self.config.endpoint_url.clone()), @@ -545,7 +545,7 @@ impl S3Credentials { }) .transpose()?; - Ok(S3Credentials { + Ok(Self { credentials: crate::S3Credentials { key_id, access_key, @@ -606,7 +606,7 @@ pub struct PyS3CredentialsProvider { impl PyS3CredentialsProvider { pub fn new(provider: Bound) -> PyResult { let hash = provider.hash()?; - Ok(PyS3CredentialsProvider { + Ok(Self { provider: provider.into(), hash, }) @@ -693,7 +693,7 @@ impl AzureConfig { use_ssl: Option, ) -> Self { let def = crate::AzureConfig::default(); - AzureConfig { + Self { config: crate::AzureConfig { storage_account: storage_account.or(def.storage_account), access_key: access_key.map(|v| v.into()).or(def.access_key), @@ -725,7 +725,7 @@ impl AzureConfig { endpoint_url: Option, use_ssl: Option, ) -> Self { - AzureConfig { + Self { config: crate::AzureConfig { storage_account: storage_account.or_else(|| self.config.storage_account.clone()), access_key: access_key @@ -835,7 +835,7 @@ impl GCSConfig { anonymous: Option, ) -> Self { let def = crate::GCSConfig::default(); - GCSConfig { + Self { config: crate::GCSConfig { project_id: project_id.or(def.project_id), credentials: credentials.map(|v| v.into()).or(def.credentials), @@ -852,7 +852,7 @@ impl GCSConfig { token: Option, anonymous: Option, ) -> Self { - GCSConfig { + Self { config: crate::GCSConfig { project_id: project_id.or_else(|| self.config.project_id.clone()), credentials: credentials @@ -907,7 +907,7 @@ impl From for IOConfig { impl HTTPConfig { #[new] pub fn new(bearer_token: Option) -> Self { - HTTPConfig { + Self { config: crate::HTTPConfig::new(bearer_token), } } diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs index a6e4fc97b5..cb02fad7fb 100644 --- a/src/common/io-config/src/s3.rs +++ b/src/common/io-config/src/s3.rs @@ -138,7 +138,7 @@ impl S3Config { impl Default for S3Config { fn default() -> Self { - S3Config { + Self { region_name: None, endpoint_url: None, key_id: None, diff --git a/src/common/py-serde/Cargo.toml 
b/src/common/py-serde/Cargo.toml index e799b0a574..60117ef209 100644 --- a/src/common/py-serde/Cargo.toml +++ b/src/common/py-serde/Cargo.toml @@ -6,6 +6,9 @@ serde = {workspace = true} [features] python = ["dep:pyo3"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-py-serde" diff --git a/src/common/resource-request/Cargo.toml b/src/common/resource-request/Cargo.toml index a2db514585..d72d63b796 100644 --- a/src/common/resource-request/Cargo.toml +++ b/src/common/resource-request/Cargo.toml @@ -7,6 +7,9 @@ serde = {workspace = true} [features] python = ["dep:pyo3", "common-py-serde/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-resource-request" diff --git a/src/common/resource-request/src/lib.rs b/src/common/resource-request/src/lib.rs index 9367a7af06..a422c91475 100644 --- a/src/common/resource-request/src/lib.rs +++ b/src/common/resource-request/src/lib.rs @@ -85,7 +85,7 @@ impl ResourceRequest { /// /// Currently, this returns true unless one resource request has a non-zero CPU request and the other task has a /// non-zero GPU request. - pub fn is_pipeline_compatible_with(&self, other: &ResourceRequest) -> bool { + pub fn is_pipeline_compatible_with(&self, other: &Self) -> bool { let self_num_cpus = self.num_cpus; let self_num_gpus = self.num_gpus; let other_num_cpus = other.num_cpus; @@ -100,7 +100,7 @@ impl ResourceRequest { } } - pub fn max(&self, other: &ResourceRequest) -> Self { + pub fn max(&self, other: &Self) -> Self { let max_num_cpus = lift(float_max, self.num_cpus, other.num_cpus); let max_num_gpus = lift(float_max, self.num_gpus, other.num_gpus); let max_memory_bytes = lift(std::cmp::max, self.memory_bytes, other.memory_bytes); @@ -152,8 +152,8 @@ impl Hash for ResourceRequest { } } -impl AsRef for ResourceRequest { - fn as_ref(&self) -> &ResourceRequest { +impl AsRef for ResourceRequest { + fn as_ref(&self) -> &Self { self } } @@ -200,21 +200,21 @@ impl ResourceRequest { } pub fn with_num_cpus(&self, num_cpus: Option) -> Self { - ResourceRequest { + Self { num_cpus, ..self.clone() } } pub fn with_num_gpus(&self, num_gpus: Option) -> Self { - ResourceRequest { + Self { num_gpus, ..self.clone() } } pub fn with_memory_bytes(&self, memory_bytes: Option) -> Self { - ResourceRequest { + Self { memory_bytes, ..self.clone() } diff --git a/src/common/system-info/Cargo.toml b/src/common/system-info/Cargo.toml index 6f1f367ee7..6548cca9db 100644 --- a/src/common/system-info/Cargo.toml +++ b/src/common/system-info/Cargo.toml @@ -5,6 +5,9 @@ sysinfo = "0.30.7" [features] python = ["dep:pyo3"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-system-info" diff --git a/src/common/system-info/src/lib.rs b/src/common/system-info/src/lib.rs index 237e75ba8f..3ef6ba180e 100644 --- a/src/common/system-info/src/lib.rs +++ b/src/common/system-info/src/lib.rs @@ -9,7 +9,7 @@ pub struct SystemInfo { impl Default for SystemInfo { fn default() -> Self { - SystemInfo { + Self { info: sysinfo::System::new_with_specifics( RefreshKind::new() .with_cpu(CpuRefreshKind::everything()) diff --git a/src/common/tracing/Cargo.toml b/src/common/tracing/Cargo.toml index e4f1c87c7a..d72c27a8dc 100644 --- a/src/common/tracing/Cargo.toml +++ b/src/common/tracing/Cargo.toml @@ -4,6 +4,9 @@ tracing = {workspace = true} tracing-chrome = "0.7.2" tracing-subscriber = "0.3" +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-tracing" diff --git a/src/common/treenode/Cargo.toml 
b/src/common/treenode/Cargo.toml index 15e2771271..2a7ebda4f4 100644 --- a/src/common/treenode/Cargo.toml +++ b/src/common/treenode/Cargo.toml @@ -4,6 +4,9 @@ common-error = {path = "../error", default-features = false} [features] python = ["common-error/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-treenode" diff --git a/src/common/treenode/src/lib.rs b/src/common/treenode/src/lib.rs index 68da6c47c8..2507de4986 100644 --- a/src/common/treenode/src/lib.rs +++ b/src/common/treenode/src/lib.rs @@ -540,38 +540,29 @@ pub enum TreeNodeRecursion { impl TreeNodeRecursion { /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`] /// value and the fact that `f` is visiting the current node's children. - pub fn visit_children Result>( - self, - f: F, - ) -> Result { + pub fn visit_children Result>(self, f: F) -> Result { match self { - TreeNodeRecursion::Continue => f(), - TreeNodeRecursion::Jump => Ok(TreeNodeRecursion::Continue), - TreeNodeRecursion::Stop => Ok(self), + Self::Continue => f(), + Self::Jump => Ok(Self::Continue), + Self::Stop => Ok(self), } } /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`] /// value and the fact that `f` is visiting the current node's sibling. - pub fn visit_sibling Result>( - self, - f: F, - ) -> Result { + pub fn visit_sibling Result>(self, f: F) -> Result { match self { - TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => f(), - TreeNodeRecursion::Stop => Ok(self), + Self::Continue | Self::Jump => f(), + Self::Stop => Ok(self), } } /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`] /// value and the fact that `f` is visiting the current node's parent. - pub fn visit_parent Result>( - self, - f: F, - ) -> Result { + pub fn visit_parent Result>(self, f: F) -> Result { match self { - TreeNodeRecursion::Continue => f(), - TreeNodeRecursion::Jump | TreeNodeRecursion::Stop => Ok(self), + Self::Continue => f(), + Self::Jump | Self::Stop => Ok(self), } } } @@ -670,10 +661,7 @@ impl Transformed { /// Maps the [`Transformed`] object to the result of the given `f` depending on the /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current /// node's children. - pub fn transform_children Result>>( - mut self, - f: F, - ) -> Result> { + pub fn transform_children Result>(mut self, f: F) -> Result { match self.tnr { TreeNodeRecursion::Continue => { return f(self.data).map(|mut t| { @@ -692,10 +680,7 @@ impl Transformed { /// Maps the [`Transformed`] object to the result of the given `f` depending on the /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current /// node's sibling. - pub fn transform_sibling Result>>( - self, - f: F, - ) -> Result> { + pub fn transform_sibling Result>(self, f: F) -> Result { match self.tnr { TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => f(self.data).map(|mut t| { t.transformed |= self.transformed; @@ -708,10 +693,7 @@ impl Transformed { /// Maps the [`Transformed`] object to the result of the given `f` depending on the /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current /// node's parent. 
- pub fn transform_parent Result>>( - self, - f: F, - ) -> Result> { + pub fn transform_parent Result>(self, f: F) -> Result { match self.tnr { TreeNodeRecursion::Continue => f(self.data).map(|mut t| { t.transformed |= self.transformed; @@ -951,7 +933,7 @@ mod tests { } impl TestTreeNode { - fn new(children: Vec>, data: T) -> Self { + fn new(children: Vec, data: T) -> Self { Self { children, data } } } diff --git a/src/common/version/Cargo.toml b/src/common/version/Cargo.toml index 69f162f811..be36a90711 100644 --- a/src/common/version/Cargo.toml +++ b/src/common/version/Cargo.toml @@ -1,3 +1,6 @@ +[lints] +workspace = true + [package] edition = {workspace = true} name = "common-version" diff --git a/src/daft-compression/Cargo.toml b/src/daft-compression/Cargo.toml index 6b695535e5..9ce9b1f862 100644 --- a/src/daft-compression/Cargo.toml +++ b/src/daft-compression/Cargo.toml @@ -3,6 +3,9 @@ async-compression = {workspace = true} tokio = {workspace = true} url = {workspace = true} +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-compression" diff --git a/src/daft-core/Cargo.toml b/src/daft-core/Cargo.toml index 36ea9b68df..ec15924316 100644 --- a/src/daft-core/Cargo.toml +++ b/src/daft-core/Cargo.toml @@ -63,6 +63,9 @@ python = [ "daft-schema/python" ] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-core" diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs index d265f42929..a8b5048b82 100644 --- a/src/daft-core/src/array/fixed_size_list_array.rs +++ b/src/daft-core/src/array/fixed_size_list_array.rs @@ -53,7 +53,7 @@ impl FixedSizeListArray { field ), } - FixedSizeListArray { + Self { field, flat_child, validity, @@ -90,7 +90,7 @@ impl FixedSizeListArray { growable .build() - .map(|s| s.downcast::().unwrap().clone()) + .map(|s| s.downcast::().unwrap().clone()) } pub fn len(&self) -> usize { diff --git a/src/daft-core/src/array/from.rs b/src/daft-core/src/array/from.rs index 3ef75a23a7..b48c16a4ba 100644 --- a/src/daft-core/src/array/from.rs +++ b/src/daft-core/src/array/from.rs @@ -15,35 +15,35 @@ impl From<(&str, Box>)) -> Self { let (name, array) = item; - DataArray::new(Field::new(name, T::get_dtype()).into(), array).unwrap() + Self::new(Field::new(name, T::get_dtype()).into(), array).unwrap() } } impl From<(&str, Box)> for NullArray { fn from(item: (&str, Box)) -> Self { let (name, array) = item; - DataArray::new(Field::new(name, DataType::Null).into(), array).unwrap() + Self::new(Field::new(name, DataType::Null).into(), array).unwrap() } } impl From<(&str, Box>)> for Utf8Array { fn from(item: (&str, Box>)) -> Self { let (name, array) = item; - DataArray::new(Field::new(name, DataType::Utf8).into(), array).unwrap() + Self::new(Field::new(name, DataType::Utf8).into(), array).unwrap() } } impl From<(&str, Box>)> for BinaryArray { fn from(item: (&str, Box>)) -> Self { let (name, array) = item; - DataArray::new(Field::new(name, DataType::Binary).into(), array).unwrap() + Self::new(Field::new(name, DataType::Binary).into(), array).unwrap() } } impl From<(&str, Box)> for FixedSizeBinaryArray { fn from(item: (&str, Box)) -> Self { let (name, array) = item; - DataArray::new( + Self::new( Field::new(name, DataType::FixedSizeBinary(array.size())).into(), array, ) @@ -58,7 +58,7 @@ where fn from((name, array, length): (&str, I, usize)) -> Self { let array = Cow::from(array); let array = array.into_owned(); - DataArray::new( + Self::new( Field::new(name, 
DataType::FixedSizeBinary(length)).into(), Box::new(arrow2::array::FixedSizeBinaryArray::new( arrow2::datatypes::DataType::FixedSizeBinary(length), @@ -79,7 +79,7 @@ where let arrow_array = Box::new(arrow2::array::PrimitiveArray::::from_slice( slice, )); - DataArray::new(Field::new(name, T::get_dtype()).into(), arrow_array).unwrap() + Self::new(Field::new(name, T::get_dtype()).into(), arrow_array).unwrap() } } @@ -90,7 +90,7 @@ where fn from(item: (&str, Vec)) -> Self { let (name, v) = item; let arrow_array = Box::new(arrow2::array::PrimitiveArray::::from_vec(v)); - DataArray::new(Field::new(name, T::get_dtype()).into(), arrow_array).unwrap() + Self::new(Field::new(name, T::get_dtype()).into(), arrow_array).unwrap() } } @@ -98,7 +98,7 @@ impl From<(&str, &[bool])> for BooleanArray { fn from(item: (&str, &[bool])) -> Self { let (name, slice) = item; let arrow_array = Box::new(arrow2::array::BooleanArray::from_slice(slice)); - DataArray::new(Field::new(name, DataType::Boolean).into(), arrow_array).unwrap() + Self::new(Field::new(name, DataType::Boolean).into(), arrow_array).unwrap() } } @@ -108,14 +108,14 @@ impl From<(&str, &[Option])> for BooleanArray { let arrow_array = Box::new(arrow2::array::BooleanArray::from_trusted_len_iter( slice.iter().cloned(), )); - DataArray::new(Field::new(name, DataType::Boolean).into(), arrow_array).unwrap() + Self::new(Field::new(name, DataType::Boolean).into(), arrow_array).unwrap() } } impl From<(&str, arrow2::array::BooleanArray)> for BooleanArray { fn from(item: (&str, arrow2::array::BooleanArray)) -> Self { let (name, arrow_array) = item; - DataArray::new( + Self::new( Field::new(name, DataType::Boolean).into(), Box::new(arrow_array), ) @@ -126,7 +126,7 @@ impl From<(&str, arrow2::array::BooleanArray)> for BooleanArray { impl From<(&str, arrow2::bitmap::Bitmap)> for BooleanArray { fn from(item: (&str, arrow2::bitmap::Bitmap)) -> Self { let (name, bitmap) = item; - DataArray::new( + Self::new( Field::new(name, DataType::Boolean).into(), Box::new(arrow2::array::BooleanArray::new( arrow2::datatypes::DataType::Boolean, @@ -141,7 +141,7 @@ impl From<(&str, arrow2::bitmap::Bitmap)> for BooleanArray { impl From<(&str, Box)> for BooleanArray { fn from(item: (&str, Box)) -> Self { let (name, arrow_array) = item; - DataArray::new(Field::new(name, DataType::Boolean).into(), arrow_array).unwrap() + Self::new(Field::new(name, DataType::Boolean).into(), arrow_array).unwrap() } } @@ -155,7 +155,7 @@ impl From<(&str, Vec)> for crate::datatypes::PythonArray { PseudoArrowArray::::from_pyobj_vec(vec_pyobj), ); let field = Field::new(name, DataType::Python); - DataArray::new(field.into(), arrow_array).unwrap() + Self::new(field.into(), arrow_array).unwrap() } } @@ -163,7 +163,7 @@ impl> From<(&str, &[T])> for DataArray { fn from(item: (&str, &[T])) -> Self { let (name, slice) = item; let arrow_array = Box::new(arrow2::array::Utf8Array::::from_slice(slice)); - DataArray::new(Field::new(name, DataType::Utf8).into(), arrow_array).unwrap() + Self::new(Field::new(name, DataType::Utf8).into(), arrow_array).unwrap() } } @@ -171,7 +171,7 @@ impl From<(&str, &[u8])> for BinaryArray { fn from(item: (&str, &[u8])) -> Self { let (name, slice) = item; let arrow_array = Box::new(arrow2::array::BinaryArray::::from_slice([slice])); - DataArray::new(Field::new(name, DataType::Binary).into(), arrow_array).unwrap() + Self::new(Field::new(name, DataType::Binary).into(), arrow_array).unwrap() } } @@ -183,7 +183,7 @@ impl>> TryFrom<(F, Box)) -> DaftResult { let (field, array) = item; let field: 
Arc = field.into(); - DataArray::new(field, array) + Self::new(field, array) } } @@ -211,7 +211,7 @@ impl TryFrom<(&str, Vec, Vec)> for BinaryArray { data.into(), None, )?; - DataArray::new( + Self::new( Field::new(name, DataType::Binary).into(), Box::new(bin_array), ) @@ -234,6 +234,6 @@ impl ), ) -> DaftResult { let (name, array) = item; - DataArray::new(Field::new(name, DataType::Python).into(), Box::new(array)) + Self::new(Field::new(name, DataType::Python).into(), Box::new(array)) } } diff --git a/src/daft-core/src/array/from_iter.rs b/src/daft-core/src/array/from_iter.rs index 484b676f30..45e0e112a9 100644 --- a/src/daft-core/src/array/from_iter.rs +++ b/src/daft-core/src/array/from_iter.rs @@ -11,7 +11,7 @@ where ) -> Self { let arrow_array = Box::new(arrow2::array::PrimitiveArray::::from_trusted_len_iter(iter)); - DataArray::new(Field::new(name, T::get_dtype()).into(), arrow_array).unwrap() + Self::new(Field::new(name, T::get_dtype()).into(), arrow_array).unwrap() } } @@ -21,7 +21,7 @@ impl Utf8Array { iter: impl arrow2::trusted_len::TrustedLen>, ) -> Self { let arrow_array = Box::new(arrow2::array::Utf8Array::::from_trusted_len_iter(iter)); - DataArray::new( + Self::new( Field::new(name, crate::datatypes::DataType::Utf8).into(), arrow_array, ) @@ -37,7 +37,7 @@ impl BinaryArray { let arrow_array = Box::new(arrow2::array::BinaryArray::::from_trusted_len_iter( iter, )); - DataArray::new( + Self::new( Field::new(name, crate::datatypes::DataType::Binary).into(), arrow_array, ) @@ -52,7 +52,7 @@ impl FixedSizeBinaryArray { size: usize, ) -> Self { let arrow_array = Box::new(arrow2::array::FixedSizeBinaryArray::from_iter(iter, size)); - DataArray::new( + Self::new( Field::new(name, crate::datatypes::DataType::FixedSizeBinary(size)).into(), arrow_array, ) @@ -66,7 +66,7 @@ impl BooleanArray { iter: impl arrow2::trusted_len::TrustedLen>, ) -> Self { let arrow_array = Box::new(arrow2::array::BooleanArray::from_trusted_len_iter(iter)); - DataArray::new( + Self::new( Field::new(name, crate::datatypes::DataType::Boolean).into(), arrow_array, ) @@ -85,7 +85,7 @@ where let arrow_array = Box::new( arrow2::array::PrimitiveArray::::from_trusted_len_values_iter(iter), ); - DataArray::new(Field::new(name, T::get_dtype()).into(), arrow_array).unwrap() + Self::new(Field::new(name, T::get_dtype()).into(), arrow_array).unwrap() } } @@ -96,7 +96,7 @@ impl Utf8Array { ) -> Self { let arrow_array = Box::new(arrow2::array::Utf8Array::::from_trusted_len_values_iter(iter)); - DataArray::new(Field::new(name, DataType::Utf8).into(), arrow_array).unwrap() + Self::new(Field::new(name, DataType::Utf8).into(), arrow_array).unwrap() } } @@ -107,7 +107,7 @@ impl BinaryArray { ) -> Self { let arrow_array = Box::new(arrow2::array::BinaryArray::::from_trusted_len_values_iter(iter)); - DataArray::new(Field::new(name, DataType::Binary).into(), arrow_array).unwrap() + Self::new(Field::new(name, DataType::Binary).into(), arrow_array).unwrap() } } @@ -119,6 +119,6 @@ impl BooleanArray { let arrow_array = Box::new(arrow2::array::BooleanArray::from_trusted_len_values_iter( iter, )); - DataArray::new(Field::new(name, DataType::Boolean).into(), arrow_array).unwrap() + Self::new(Field::new(name, DataType::Boolean).into(), arrow_array).unwrap() } } diff --git a/src/daft-core/src/array/image_array.rs b/src/daft-core/src/array/image_array.rs index 205075efbf..5daa11d42d 100644 --- a/src/daft-core/src/array/image_array.rs +++ b/src/daft-core/src/array/image_array.rs @@ -19,7 +19,7 @@ impl BBox { .downcast_ref::() .unwrap() .iter(); - 
BBox( + Self( *iter.next().unwrap().unwrap(), *iter.next().unwrap().unwrap(), *iter.next().unwrap().unwrap(), diff --git a/src/daft-core/src/array/list_array.rs b/src/daft-core/src/array/list_array.rs index 964503b271..538c24e716 100644 --- a/src/daft-core/src/array/list_array.rs +++ b/src/daft-core/src/array/list_array.rs @@ -53,7 +53,7 @@ impl ListArray { field ), } - ListArray { + Self { field, flat_child, offsets, @@ -102,7 +102,7 @@ impl ListArray { growable .build() - .map(|s| s.downcast::().unwrap().clone()) + .map(|s| s.downcast::().unwrap().clone()) } pub fn len(&self) -> usize { diff --git a/src/daft-core/src/array/mod.rs b/src/daft-core/src/array/mod.rs index 21a811b403..7c300c6a38 100644 --- a/src/daft-core/src/array/mod.rs +++ b/src/daft-core/src/array/mod.rs @@ -30,7 +30,7 @@ pub struct DataArray { impl Clone for DataArray { fn clone(&self) -> Self { - DataArray::new(self.field.clone(), self.data.clone()).unwrap() + Self::new(self.field.clone(), self.data.clone()).unwrap() } } @@ -44,7 +44,7 @@ impl DataArray where T: DaftPhysicalType, { - pub fn new(field: Arc, data: Box) -> DaftResult> { + pub fn new(field: Arc, data: Box) -> DaftResult { assert!( field.dtype.is_physical(), "Can only construct DataArray for PhysicalTypes, got {}", @@ -61,7 +61,7 @@ where } } - Ok(DataArray { + Ok(Self { field, data, marker_: PhantomData, @@ -93,7 +93,7 @@ where ))); } let with_bitmap = self.data.with_validity(Some(Bitmap::from(validity))); - DataArray::new(self.field.clone(), with_bitmap) + Self::new(self.field.clone(), with_bitmap) } pub fn with_validity(&self, validity: Option) -> DaftResult { @@ -107,7 +107,7 @@ where ))); } let with_bitmap = self.data.with_validity(validity); - DataArray::new(self.field.clone(), with_bitmap) + Self::new(self.field.clone(), with_bitmap) } pub fn validity(&self) -> Option<&Bitmap> { diff --git a/src/daft-core/src/array/ops/apply.rs b/src/daft-core/src/array/ops/apply.rs index 159232e602..f5388bfbc6 100644 --- a/src/daft-core/src/array/ops/apply.rs +++ b/src/daft-core/src/array/ops/apply.rs @@ -20,7 +20,7 @@ where PrimitiveArray::from_trusted_len_values_iter(arr.values_iter().map(|v| func(*v))) .with_validity(arr.validity().cloned()); - Ok(DataArray::from((self.name(), Box::new(result_arr)))) + Ok(Self::from((self.name(), Box::new(result_arr)))) } // applies a native binary function to two DataArrays, maintaining validity. 
@@ -44,17 +44,13 @@ where zip(lhs_arr.values_iter(), rhs_arr.values_iter()).map(|(a, b)| func(*a, *b)), ) .with_validity(validity); - Ok(DataArray::from((self.name(), Box::new(result_arr)))) + Ok(Self::from((self.name(), Box::new(result_arr)))) } (l_size, 1) => { if let Some(value) = rhs.get(0) { self.apply(|v| func(v, value)) } else { - Ok(DataArray::::full_null( - self.name(), - self.data_type(), - l_size, - )) + Ok(Self::full_null(self.name(), self.data_type(), l_size)) } } (1, r_size) => { @@ -65,13 +61,9 @@ where rhs_arr.values_iter().map(|v| func(value, *v)), ) .with_validity(rhs_arr.validity().cloned()); - Ok(DataArray::from((self.name(), Box::new(result_arr)))) + Ok(Self::from((self.name(), Box::new(result_arr)))) } else { - Ok(DataArray::::full_null( - self.name(), - self.data_type(), - r_size, - )) + Ok(Self::full_null(self.name(), self.data_type(), r_size)) } } (l, r) => Err(DaftError::ValueError(format!( diff --git a/src/daft-core/src/array/ops/approx_count_distinct.rs b/src/daft-core/src/array/ops/approx_count_distinct.rs index 068275d2ed..66f2008ed1 100644 --- a/src/daft-core/src/array/ops/approx_count_distinct.rs +++ b/src/daft-core/src/array/ops/approx_count_distinct.rs @@ -10,7 +10,7 @@ use crate::{ }; impl DaftApproxCountDistinctAggable for UInt64Array { - type Output = DaftResult; + type Output = DaftResult; fn approx_count_distinct(&self) -> Self::Output { let mut set = HashSet::with_capacity_and_hasher(self.len(), IdentityBuildHasher::default()); diff --git a/src/daft-core/src/array/ops/arange.rs b/src/daft-core/src/array/ops/arange.rs index 33da976928..a9c4e62b1d 100644 --- a/src/daft-core/src/array/ops/arange.rs +++ b/src/daft-core/src/array/ops/arange.rs @@ -19,7 +19,7 @@ where let arrow_array = Box::new(arrow2::array::PrimitiveArray::::from_vec(data)); let data_array = Int64Array::from((name.as_ref(), arrow_array)); let casted_array = data_array.cast(&T::get_dtype())?; - let downcasted = casted_array.downcast::>()?; + let downcasted = casted_array.downcast::()?; Ok(downcasted.clone()) } } diff --git a/src/daft-core/src/array/ops/between.rs b/src/daft-core/src/array/ops/between.rs index 90c22d2151..09e1914372 100644 --- a/src/daft-core/src/array/ops/between.rs +++ b/src/daft-core/src/array/ops/between.rs @@ -6,13 +6,13 @@ use crate::{ datatypes::{BooleanArray, DaftNumericType}, }; -impl DaftBetween<&DataArray, &DataArray> for DataArray +impl DaftBetween<&Self, &Self> for DataArray where T: DaftNumericType, { type Output = DaftResult; - fn between(&self, lower: &DataArray, upper: &DataArray) -> Self::Output { + fn between(&self, lower: &Self, upper: &Self) -> Self::Output { let are_two_equal_and_single_one = |v_size, l_size, u_size: usize| { [v_size, l_size, u_size] .iter() diff --git a/src/daft-core/src/array/ops/bitwise.rs b/src/daft-core/src/array/ops/bitwise.rs index 6a38eee06a..4d4630bbad 100644 --- a/src/daft-core/src/array/ops/bitwise.rs +++ b/src/daft-core/src/array/ops/bitwise.rs @@ -8,7 +8,7 @@ use crate::{ datatypes::{DaftIntegerType, DaftNumericType}, }; -impl DaftLogical<&DataArray> for DataArray +impl DaftLogical<&Self> for DataArray where T: DaftIntegerType, ::Native: @@ -16,15 +16,15 @@ where { type Output = DaftResult; - fn and(&self, rhs: &DataArray) -> Self::Output { + fn and(&self, rhs: &Self) -> Self::Output { self.binary_apply(rhs, |lhs, rhs| lhs.bitand(rhs)) } - fn or(&self, rhs: &DataArray) -> Self::Output { + fn or(&self, rhs: &Self) -> Self::Output { self.binary_apply(rhs, |lhs, rhs| lhs.bitor(rhs)) } - fn xor(&self, rhs: &DataArray) -> 
Self::Output { + fn xor(&self, rhs: &Self) -> Self::Output { self.binary_apply(rhs, |lhs, rhs| lhs.bitxor(rhs)) } } diff --git a/src/daft-core/src/array/ops/broadcast.rs b/src/daft-core/src/array/ops/broadcast.rs index b372ab2ef6..2a81dd5a5e 100644 --- a/src/daft-core/src/array/ops/broadcast.rs +++ b/src/daft-core/src/array/ops/broadcast.rs @@ -35,7 +35,7 @@ where impl Broadcastable for DataArray where T: DaftPhysicalType + 'static, - DataArray: GrowableArray, + Self: GrowableArray, { fn broadcast(&self, num: usize) -> DaftResult { if self.len() != 1 { @@ -48,7 +48,7 @@ where if self.is_valid(0) { generic_growable_broadcast(self, num, self.name(), self.data_type()) } else { - Ok(DataArray::full_null(self.name(), self.data_type(), num)) + Ok(Self::full_null(self.name(), self.data_type(), num)) } } } @@ -65,11 +65,7 @@ impl Broadcastable for FixedSizeListArray { if self.is_valid(0) { generic_growable_broadcast(self, num, self.name(), self.data_type()) } else { - Ok(FixedSizeListArray::full_null( - self.name(), - self.data_type(), - num, - )) + Ok(Self::full_null(self.name(), self.data_type(), num)) } } } @@ -86,7 +82,7 @@ impl Broadcastable for ListArray { if self.is_valid(0) { generic_growable_broadcast(self, num, self.name(), self.data_type()) } else { - Ok(ListArray::full_null(self.name(), self.data_type(), num)) + Ok(Self::full_null(self.name(), self.data_type(), num)) } } } @@ -103,7 +99,7 @@ impl Broadcastable for StructArray { if self.is_valid(0) { generic_growable_broadcast(self, num, self.name(), self.data_type()) } else { - Ok(StructArray::full_null(self.name(), self.data_type(), num)) + Ok(Self::full_null(self.name(), self.data_type(), num)) } } } diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index 3f14b8f2f5..c3dbe0c209 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -1956,7 +1956,7 @@ impl FixedSizeListArray { ))); } let casted_child = self.flat_child.cast(child_dtype.as_ref())?; - Ok(FixedSizeListArray::new( + Ok(Self::new( Field::new(self.name().to_string(), dtype.clone()), casted_child, self.validity().cloned(), @@ -2018,7 +2018,7 @@ impl FixedSizeListArray { impl ListArray { pub fn cast(&self, dtype: &DataType) -> DaftResult { match dtype { - DataType::List(child_dtype) => Ok(ListArray::new( + DataType::List(child_dtype) => Ok(Self::new( Field::new(self.name(), dtype.clone()), self.flat_child.cast(child_dtype.as_ref())?, self.offsets().clone(), @@ -2138,7 +2138,7 @@ impl StructArray { }, ) .collect::>>(); - Ok(StructArray::new( + Ok(Self::new( Field::new(self.name(), dtype.clone()), casted_series?, self.validity().cloned(), diff --git a/src/daft-core/src/array/ops/compare_agg.rs b/src/daft-core/src/array/ops/compare_agg.rs index 0fc139b36d..5d5237d9c1 100644 --- a/src/daft-core/src/array/ops/compare_agg.rs +++ b/src/daft-core/src/array/ops/compare_agg.rs @@ -64,7 +64,7 @@ where T::Native: PartialOrd, ::Simd: arrow2::compute::aggregate::SimdOrd, { - type Output = DaftResult>; + type Output = DaftResult; fn min(&self) -> Self::Output { let primitive_arr = self.as_arrow(); @@ -72,7 +72,7 @@ where let result = arrow2::compute::aggregate::min_primitive(primitive_arr); let arrow_array = Box::new(arrow2::array::PrimitiveArray::from([result])); - DataArray::new(self.field.clone(), arrow_array) + Self::new(self.field.clone(), arrow_array) } fn max(&self) -> Self::Output { @@ -81,7 +81,7 @@ where let result = arrow2::compute::aggregate::max_primitive(primitive_arr); let arrow_array = 
Box::new(arrow2::array::PrimitiveArray::from([result])); - DataArray::new(self.field.clone(), arrow_array) + Self::new(self.field.clone(), arrow_array) } fn grouped_min(&self, groups: &GroupIndices) -> Self::Output { grouped_cmp_native( @@ -157,14 +157,14 @@ where } impl DaftCompareAggable for DataArray { - type Output = DaftResult>; + type Output = DaftResult; fn min(&self) -> Self::Output { let arrow_array: &arrow2::array::Utf8Array = self.as_arrow(); let result = arrow2::compute::aggregate::min_string(arrow_array); let res_arrow_array = arrow2::array::Utf8Array::::from([result]); - DataArray::new(self.field.clone(), Box::new(res_arrow_array)) + Self::new(self.field.clone(), Box::new(res_arrow_array)) } fn max(&self) -> Self::Output { let arrow_array: &arrow2::array::Utf8Array = self.as_arrow(); @@ -172,7 +172,7 @@ impl DaftCompareAggable for DataArray { let result = arrow2::compute::aggregate::max_string(arrow_array); let res_arrow_array = arrow2::array::Utf8Array::::from([result]); - DataArray::new(self.field.clone(), Box::new(res_arrow_array)) + Self::new(self.field.clone(), Box::new(res_arrow_array)) } fn grouped_min(&self, groups: &GroupIndices) -> Self::Output { @@ -237,14 +237,14 @@ where } impl DaftCompareAggable for DataArray { - type Output = DaftResult>; + type Output = DaftResult; fn min(&self) -> Self::Output { let arrow_array: &arrow2::array::BinaryArray = self.as_arrow(); let result = arrow2::compute::aggregate::min_binary(arrow_array); let res_arrow_array = arrow2::array::BinaryArray::::from([result]); - DataArray::new(self.field.clone(), Box::new(res_arrow_array)) + Self::new(self.field.clone(), Box::new(res_arrow_array)) } fn max(&self) -> Self::Output { let arrow_array: &arrow2::array::BinaryArray = self.as_arrow(); @@ -252,7 +252,7 @@ impl DaftCompareAggable for DataArray { let result = arrow2::compute::aggregate::max_binary(arrow_array); let res_arrow_array = arrow2::array::BinaryArray::::from([result]); - DataArray::new(self.field.clone(), Box::new(res_arrow_array)) + Self::new(self.field.clone(), Box::new(res_arrow_array)) } fn grouped_min(&self, groups: &GroupIndices) -> Self::Output { @@ -354,7 +354,7 @@ where } impl DaftCompareAggable for DataArray { - type Output = DaftResult>; + type Output = DaftResult; fn min(&self) -> Self::Output { cmp_fixed_size_binary(self, |l, r| l.min(r)) } @@ -423,14 +423,14 @@ fn grouped_cmp_bool( } impl DaftCompareAggable for DataArray { - type Output = DaftResult>; + type Output = DaftResult; fn min(&self) -> Self::Output { let arrow_array: &arrow2::array::BooleanArray = self.as_arrow(); let result = arrow2::compute::aggregate::min_boolean(arrow_array); let res_arrow_array = arrow2::array::BooleanArray::from([result]); - DataArray::new(self.field.clone(), Box::new(res_arrow_array)) + Self::new(self.field.clone(), Box::new(res_arrow_array)) } fn max(&self) -> Self::Output { let arrow_array: &arrow2::array::BooleanArray = self.as_arrow(); @@ -438,7 +438,7 @@ impl DaftCompareAggable for DataArray { let result = arrow2::compute::aggregate::max_boolean(arrow_array); let res_arrow_array = arrow2::array::BooleanArray::from([result]); - DataArray::new(self.field.clone(), Box::new(res_arrow_array)) + Self::new(self.field.clone(), Box::new(res_arrow_array)) } fn grouped_min(&self, groups: &GroupIndices) -> Self::Output { @@ -451,11 +451,11 @@ impl DaftCompareAggable for DataArray { } impl DaftCompareAggable for DataArray { - type Output = DaftResult>; + type Output = DaftResult; fn min(&self) -> Self::Output { let res_arrow_array = 
arrow2::array::NullArray::new(arrow2::datatypes::DataType::Null, 1); - DataArray::new(self.field.clone(), Box::new(res_arrow_array)) + Self::new(self.field.clone(), Box::new(res_arrow_array)) } fn max(&self) -> Self::Output { @@ -464,19 +464,11 @@ impl DaftCompareAggable for DataArray { } fn grouped_min(&self, groups: &super::GroupIndices) -> Self::Output { - Ok(DataArray::full_null( - self.name(), - self.data_type(), - groups.len(), - )) + Ok(Self::full_null(self.name(), self.data_type(), groups.len())) } fn grouped_max(&self, groups: &super::GroupIndices) -> Self::Output { - Ok(DataArray::full_null( - self.name(), - self.data_type(), - groups.len(), - )) + Ok(Self::full_null(self.name(), self.data_type(), groups.len())) } } diff --git a/src/daft-core/src/array/ops/comparison.rs b/src/daft-core/src/array/ops/comparison.rs index 0f76b338ff..aee84893de 100644 --- a/src/daft-core/src/array/ops/comparison.rs +++ b/src/daft-core/src/array/ops/comparison.rs @@ -23,13 +23,13 @@ where } } -impl DaftCompare<&DataArray> for DataArray +impl DaftCompare<&Self> for DataArray where T: DaftNumericType, { type Output = DaftResult; - fn equal(&self, rhs: &DataArray) -> Self::Output { + fn equal(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -69,7 +69,7 @@ where } } - fn not_equal(&self, rhs: &DataArray) -> Self::Output { + fn not_equal(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -109,7 +109,7 @@ where } } - fn lt(&self, rhs: &DataArray) -> Self::Output { + fn lt(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -149,7 +149,7 @@ where } } - fn lte(&self, rhs: &DataArray) -> Self::Output { + fn lte(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -189,7 +189,7 @@ where } } - fn gt(&self, rhs: &DataArray) -> Self::Output { + fn gt(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -229,7 +229,7 @@ where } } - fn gte(&self, rhs: &DataArray) -> Self::Output { + fn gte(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -337,15 +337,15 @@ where } } -impl DaftCompare<&BooleanArray> for BooleanArray { - type Output = DaftResult; +impl DaftCompare<&Self> for BooleanArray { + type Output = DaftResult; - fn equal(&self, rhs: &BooleanArray) -> Self::Output { + fn equal(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = arrow_bitmap_and_helper(self.as_arrow().validity(), rhs.as_arrow().validity()); - Ok(BooleanArray::from(( + Ok(Self::from(( self.name(), comparison::eq(self.as_arrow(), rhs.as_arrow()).with_validity(validity), ))) @@ -354,22 +354,14 @@ impl DaftCompare<&BooleanArray> for BooleanArray { if let Some(value) = rhs.get(0) { self.equal(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - l_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, l_size)) } } (1, r_size) => { if let Some(value) = self.get(0) { rhs.equal(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - r_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, r_size)) } } (l, r) => Err(DaftError::ValueError(format!( @@ -380,12 +372,12 @@ impl DaftCompare<&BooleanArray> for BooleanArray { } } - fn not_equal(&self, rhs: &BooleanArray) -> Self::Output { + 
fn not_equal(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = arrow_bitmap_and_helper(self.as_arrow().validity(), rhs.as_arrow().validity()); - Ok(BooleanArray::from(( + Ok(Self::from(( self.name(), comparison::neq(self.as_arrow(), rhs.as_arrow()).with_validity(validity), ))) @@ -394,22 +386,14 @@ impl DaftCompare<&BooleanArray> for BooleanArray { if let Some(value) = rhs.get(0) { self.not_equal(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - l_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, l_size)) } } (1, r_size) => { if let Some(value) = self.get(0) { rhs.not_equal(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - r_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, r_size)) } } (l, r) => Err(DaftError::ValueError(format!( @@ -420,12 +404,12 @@ impl DaftCompare<&BooleanArray> for BooleanArray { } } - fn lt(&self, rhs: &BooleanArray) -> Self::Output { + fn lt(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = arrow_bitmap_and_helper(self.as_arrow().validity(), rhs.as_arrow().validity()); - Ok(BooleanArray::from(( + Ok(Self::from(( self.name(), comparison::lt(self.as_arrow(), rhs.as_arrow()).with_validity(validity), ))) @@ -434,22 +418,14 @@ impl DaftCompare<&BooleanArray> for BooleanArray { if let Some(value) = rhs.get(0) { self.lt(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - l_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, l_size)) } } (1, r_size) => { if let Some(value) = self.get(0) { rhs.gt(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - r_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, r_size)) } } (l, r) => Err(DaftError::ValueError(format!( @@ -460,12 +436,12 @@ impl DaftCompare<&BooleanArray> for BooleanArray { } } - fn lte(&self, rhs: &BooleanArray) -> Self::Output { + fn lte(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = arrow_bitmap_and_helper(self.as_arrow().validity(), rhs.as_arrow().validity()); - Ok(BooleanArray::from(( + Ok(Self::from(( self.name(), comparison::lt_eq(self.as_arrow(), rhs.as_arrow()).with_validity(validity), ))) @@ -474,22 +450,14 @@ impl DaftCompare<&BooleanArray> for BooleanArray { if let Some(value) = rhs.get(0) { self.lte(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - l_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, l_size)) } } (1, r_size) => { if let Some(value) = self.get(0) { rhs.gte(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - r_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, r_size)) } } (l, r) => Err(DaftError::ValueError(format!( @@ -500,12 +468,12 @@ impl DaftCompare<&BooleanArray> for BooleanArray { } } - fn gt(&self, rhs: &BooleanArray) -> Self::Output { + fn gt(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = arrow_bitmap_and_helper(self.as_arrow().validity(), rhs.as_arrow().validity()); - Ok(BooleanArray::from(( + Ok(Self::from(( self.name(), comparison::gt(self.as_arrow(), rhs.as_arrow()).with_validity(validity), ))) @@ -514,22 +482,14 @@ impl DaftCompare<&BooleanArray> for BooleanArray { if let Some(value) = rhs.get(0) { self.gt(value) } else { - Ok(BooleanArray::full_null( - self.name(), - 
&DataType::Boolean, - l_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, l_size)) } } (1, r_size) => { if let Some(value) = self.get(0) { rhs.lt(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - r_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, r_size)) } } (l, r) => Err(DaftError::ValueError(format!( @@ -540,12 +500,12 @@ impl DaftCompare<&BooleanArray> for BooleanArray { } } - fn gte(&self, rhs: &BooleanArray) -> Self::Output { + fn gte(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = arrow_bitmap_and_helper(self.as_arrow().validity(), rhs.as_arrow().validity()); - Ok(BooleanArray::from(( + Ok(Self::from(( self.name(), comparison::gt_eq(self.as_arrow(), rhs.as_arrow()).with_validity(validity), ))) @@ -554,22 +514,14 @@ impl DaftCompare<&BooleanArray> for BooleanArray { if let Some(value) = rhs.get(0) { self.gte(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - l_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, l_size)) } } (1, r_size) => { if let Some(value) = self.get(0) { rhs.lte(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - r_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, r_size)) } } (l, r) => Err(DaftError::ValueError(format!( @@ -582,14 +534,14 @@ impl DaftCompare<&BooleanArray> for BooleanArray { } impl DaftCompare for BooleanArray { - type Output = DaftResult; + type Output = DaftResult; fn equal(&self, rhs: bool) -> Self::Output { let validity = self.as_arrow().validity().cloned(); let arrow_result = comparison::boolean::eq_scalar(self.as_arrow(), rhs).with_validity(validity); - Ok(BooleanArray::from((self.name(), arrow_result))) + Ok(Self::from((self.name(), arrow_result))) } fn not_equal(&self, rhs: bool) -> Self::Output { @@ -597,7 +549,7 @@ impl DaftCompare for BooleanArray { let arrow_result = comparison::boolean::neq_scalar(self.as_arrow(), rhs).with_validity(validity); - Ok(BooleanArray::from((self.name(), arrow_result))) + Ok(Self::from((self.name(), arrow_result))) } fn lt(&self, rhs: bool) -> Self::Output { @@ -605,7 +557,7 @@ impl DaftCompare for BooleanArray { let arrow_result = comparison::boolean::lt_scalar(self.as_arrow(), rhs).with_validity(validity); - Ok(BooleanArray::from((self.name(), arrow_result))) + Ok(Self::from((self.name(), arrow_result))) } fn lte(&self, rhs: bool) -> Self::Output { @@ -613,7 +565,7 @@ impl DaftCompare for BooleanArray { let arrow_result = comparison::boolean::lt_eq_scalar(self.as_arrow(), rhs).with_validity(validity); - Ok(BooleanArray::from((self.name(), arrow_result))) + Ok(Self::from((self.name(), arrow_result))) } fn gt(&self, rhs: bool) -> Self::Output { @@ -621,7 +573,7 @@ impl DaftCompare for BooleanArray { let arrow_result = comparison::boolean::gt_scalar(self.as_arrow(), rhs).with_validity(validity); - Ok(BooleanArray::from((self.name(), arrow_result))) + Ok(Self::from((self.name(), arrow_result))) } fn gte(&self, rhs: bool) -> Self::Output { @@ -629,7 +581,7 @@ impl DaftCompare for BooleanArray { let arrow_result = comparison::boolean::gt_eq_scalar(self.as_arrow(), rhs).with_validity(validity); - Ok(BooleanArray::from((self.name(), arrow_result))) + Ok(Self::from((self.name(), arrow_result))) } } @@ -646,9 +598,9 @@ impl Not for &BooleanArray { } } -impl DaftLogical<&BooleanArray> for BooleanArray { - type Output = DaftResult; - fn and(&self, rhs: &BooleanArray) -> Self::Output { +impl 
DaftLogical<&Self> for BooleanArray { + type Output = DaftResult; + fn and(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -656,7 +608,7 @@ impl DaftLogical<&BooleanArray> for BooleanArray { let result_bitmap = arrow2::bitmap::and(self.as_arrow().values(), rhs.as_arrow().values()); - Ok(BooleanArray::from(( + Ok(Self::from(( self.name(), arrow2::array::BooleanArray::new( arrow2::datatypes::DataType::Boolean, @@ -669,22 +621,14 @@ impl DaftLogical<&BooleanArray> for BooleanArray { if let Some(value) = rhs.get(0) { self.and(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - l_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, l_size)) } } (1, r_size) => { if let Some(value) = self.get(0) { rhs.and(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - r_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, r_size)) } } (l, r) => Err(DaftError::ValueError(format!( @@ -695,7 +639,7 @@ impl DaftLogical<&BooleanArray> for BooleanArray { } } - fn or(&self, rhs: &BooleanArray) -> Self::Output { + fn or(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -703,7 +647,7 @@ impl DaftLogical<&BooleanArray> for BooleanArray { let result_bitmap = arrow2::bitmap::or(self.as_arrow().values(), rhs.as_arrow().values()); - Ok(BooleanArray::from(( + Ok(Self::from(( self.name(), arrow2::array::BooleanArray::new( arrow2::datatypes::DataType::Boolean, @@ -716,22 +660,14 @@ impl DaftLogical<&BooleanArray> for BooleanArray { if let Some(value) = rhs.get(0) { self.or(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - l_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, l_size)) } } (1, r_size) => { if let Some(value) = self.get(0) { rhs.or(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - r_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, r_size)) } } (l, r) => Err(DaftError::ValueError(format!( @@ -742,7 +678,7 @@ impl DaftLogical<&BooleanArray> for BooleanArray { } } - fn xor(&self, rhs: &BooleanArray) -> Self::Output { + fn xor(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -750,7 +686,7 @@ impl DaftLogical<&BooleanArray> for BooleanArray { let result_bitmap = arrow2::bitmap::xor(self.as_arrow().values(), rhs.as_arrow().values()); - Ok(BooleanArray::from(( + Ok(Self::from(( self.name(), arrow2::array::BooleanArray::new( arrow2::datatypes::DataType::Boolean, @@ -763,22 +699,14 @@ impl DaftLogical<&BooleanArray> for BooleanArray { if let Some(value) = rhs.get(0) { self.xor(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - l_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, l_size)) } } (1, r_size) => { if let Some(value) = self.get(0) { rhs.xor(value) } else { - Ok(BooleanArray::full_null( - self.name(), - &DataType::Boolean, - r_size, - )) + Ok(Self::full_null(self.name(), &DataType::Boolean, r_size)) } } (l, r) => Err(DaftError::ValueError(format!( @@ -815,7 +743,7 @@ macro_rules! 
null_array_comparison_method { }; } -impl DaftCompare<&NullArray> for NullArray { +impl DaftCompare<&Self> for NullArray { type Output = DaftResult; null_array_comparison_method!(equal); null_array_comparison_method!(not_equal); @@ -826,7 +754,7 @@ impl DaftCompare<&NullArray> for NullArray { } impl DaftLogical for BooleanArray { - type Output = DaftResult; + type Output = DaftResult; fn and(&self, rhs: bool) -> Self::Output { let validity = self.as_arrow().validity(); if rhs { @@ -838,7 +766,7 @@ impl DaftLogical for BooleanArray { Bitmap::new_zeroed(self.len()), validity.cloned(), ); - return Ok(BooleanArray::from((self.name(), arrow_array))); + return Ok(Self::from((self.name(), arrow_array))); } } @@ -851,7 +779,7 @@ impl DaftLogical for BooleanArray { Bitmap::new_zeroed(self.len()).not(), validity.cloned(), ); - return Ok(BooleanArray::from((self.name(), arrow_array))); + return Ok(Self::from((self.name(), arrow_array))); } else { Ok(self.clone()) } @@ -866,10 +794,10 @@ impl DaftLogical for BooleanArray { } } -impl DaftCompare<&Utf8Array> for Utf8Array { +impl DaftCompare<&Self> for Utf8Array { type Output = DaftResult; - fn equal(&self, rhs: &Utf8Array) -> Self::Output { + fn equal(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -909,7 +837,7 @@ impl DaftCompare<&Utf8Array> for Utf8Array { } } - fn not_equal(&self, rhs: &Utf8Array) -> Self::Output { + fn not_equal(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -949,7 +877,7 @@ impl DaftCompare<&Utf8Array> for Utf8Array { } } - fn lt(&self, rhs: &Utf8Array) -> Self::Output { + fn lt(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -989,7 +917,7 @@ impl DaftCompare<&Utf8Array> for Utf8Array { } } - fn lte(&self, rhs: &Utf8Array) -> Self::Output { + fn lte(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -1029,7 +957,7 @@ impl DaftCompare<&Utf8Array> for Utf8Array { } } - fn gt(&self, rhs: &Utf8Array) -> Self::Output { + fn gt(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -1069,7 +997,7 @@ impl DaftCompare<&Utf8Array> for Utf8Array { } } - fn gte(&self, rhs: &Utf8Array) -> Self::Output { + fn gte(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -1162,10 +1090,10 @@ impl DaftCompare<&str> for Utf8Array { } } -impl DaftCompare<&BinaryArray> for BinaryArray { +impl DaftCompare<&Self> for BinaryArray { type Output = DaftResult; - fn equal(&self, rhs: &BinaryArray) -> Self::Output { + fn equal(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -1205,7 +1133,7 @@ impl DaftCompare<&BinaryArray> for BinaryArray { } } - fn not_equal(&self, rhs: &BinaryArray) -> Self::Output { + fn not_equal(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -1245,7 +1173,7 @@ impl DaftCompare<&BinaryArray> for BinaryArray { } } - fn lt(&self, rhs: &BinaryArray) -> Self::Output { + fn lt(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -1285,7 +1213,7 @@ impl DaftCompare<&BinaryArray> for BinaryArray { } } - fn lte(&self, rhs: &BinaryArray) -> Self::Output { + fn lte(&self, rhs: &Self) -> Self::Output { match (self.len(), 
rhs.len()) { (x, y) if x == y => { let validity = @@ -1325,7 +1253,7 @@ impl DaftCompare<&BinaryArray> for BinaryArray { } } - fn gt(&self, rhs: &BinaryArray) -> Self::Output { + fn gt(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -1365,7 +1293,7 @@ impl DaftCompare<&BinaryArray> for BinaryArray { } } - fn gte(&self, rhs: &BinaryArray) -> Self::Output { + fn gte(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => { let validity = @@ -1515,10 +1443,10 @@ where ) } -impl DaftCompare<&FixedSizeBinaryArray> for FixedSizeBinaryArray { +impl DaftCompare<&Self> for FixedSizeBinaryArray { type Output = DaftResult; - fn equal(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + fn equal(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs == rhs), (l_size, 1) => { @@ -1551,7 +1479,7 @@ impl DaftCompare<&FixedSizeBinaryArray> for FixedSizeBinaryArray { } } - fn not_equal(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + fn not_equal(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs != rhs), (l_size, 1) => { @@ -1584,7 +1512,7 @@ impl DaftCompare<&FixedSizeBinaryArray> for FixedSizeBinaryArray { } } - fn lt(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + fn lt(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs < rhs), (l_size, 1) => { @@ -1617,7 +1545,7 @@ impl DaftCompare<&FixedSizeBinaryArray> for FixedSizeBinaryArray { } } - fn lte(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + fn lte(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs <= rhs), (l_size, 1) => { @@ -1650,7 +1578,7 @@ impl DaftCompare<&FixedSizeBinaryArray> for FixedSizeBinaryArray { } } - fn gt(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + fn gt(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs > rhs), (l_size, 1) => { @@ -1683,7 +1611,7 @@ impl DaftCompare<&FixedSizeBinaryArray> for FixedSizeBinaryArray { } } - fn gte(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + fn gte(&self, rhs: &Self) -> Self::Output { match (self.len(), rhs.len()) { (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs >= rhs), (l_size, 1) => { diff --git a/src/daft-core/src/array/ops/concat.rs b/src/daft-core/src/array/ops/concat.rs index b83f2b150b..3424b46811 100644 --- a/src/daft-core/src/array/ops/concat.rs +++ b/src/daft-core/src/array/ops/concat.rs @@ -94,20 +94,20 @@ where }) .collect(), )); - DataArray::new(field.clone(), cat_array) + Self::new(field.clone(), cat_array) } crate::datatypes::DataType::Utf8 => { let cat_array = utf8_concat(arrow_arrays.as_slice())?; - DataArray::new(field.clone(), cat_array) + Self::new(field.clone(), cat_array) } crate::datatypes::DataType::Binary => { let cat_array = binary_concat(arrow_arrays.as_slice())?; - DataArray::new(field.clone(), cat_array) + Self::new(field.clone(), cat_array) } _ => { let cat_array: Box = arrow2::compute::concatenate::concatenate(arrow_arrays.as_slice())?; - DataArray::try_from((field.clone(), cat_array)) + Self::try_from((field.clone(), cat_array)) } } } diff --git 
a/src/daft-core/src/array/ops/concat_agg.rs b/src/daft-core/src/array/ops/concat_agg.rs index 09ebb0876e..d3681ea3a5 100644 --- a/src/daft-core/src/array/ops/concat_agg.rs +++ b/src/daft-core/src/array/ops/concat_agg.rs @@ -69,7 +69,7 @@ impl DaftConcatAggable for ListArray { fn concat(&self) -> Self::Output { if self.null_count() == 0 { let new_offsets = OffsetsBuffer::::try_from(vec![0, *self.offsets().last()])?; - return Ok(ListArray::new( + return Ok(Self::new( self.field.clone(), self.flat_child.clone(), new_offsets, @@ -102,7 +102,7 @@ impl DaftConcatAggable for ListArray { let new_child = child_growable.build()?; let new_offsets = OffsetsBuffer::::try_from(vec![0, new_child.len() as i64])?; - Ok(ListArray::new( + Ok(Self::new( self.field.clone(), new_child, new_offsets, @@ -145,7 +145,7 @@ impl DaftConcatAggable for ListArray { Some(arrow2::bitmap::Bitmap::from(group_valids)) }; - Ok(ListArray::new( + Ok(Self::new( self.field.clone(), child_array_growable.build()?, new_offsets.into(), @@ -175,7 +175,7 @@ impl DaftConcatAggable for DataArray { ); let result_box = Box::new(output); - DataArray::new(self.field().clone().into(), result_box) + Self::new(self.field().clone().into(), result_box) } fn grouped_concat(&self, groups: &super::GroupIndices) -> Self::Output { @@ -208,10 +208,7 @@ impl DaftConcatAggable for DataArray { ))) }; - Ok(DataArray::from(( - self.field.name.as_ref(), - concat_per_group, - ))) + Ok(Self::from((self.field.name.as_ref(), concat_per_group))) } } diff --git a/src/daft-core/src/array/ops/filter.rs b/src/daft-core/src/array/ops/filter.rs index e255b10119..f17740a549 100644 --- a/src/daft-core/src/array/ops/filter.rs +++ b/src/daft-core/src/array/ops/filter.rs @@ -28,7 +28,7 @@ impl crate::datatypes::PythonArray { use arrow2::array::Array; use pyo3::PyObject; - use crate::{array::pseudo_arrow::PseudoArrowArray, datatypes::PythonType}; + use crate::array::pseudo_arrow::PseudoArrowArray; let mask = mask.as_arrow(); @@ -71,7 +71,7 @@ impl crate::datatypes::PythonArray { let arrow_array: Box = Box::new(PseudoArrowArray::new(new_values.into(), new_validity)); - DataArray::::new(self.field().clone().into(), arrow_array) + Self::new(self.field().clone().into(), arrow_array) } } diff --git a/src/daft-core/src/array/ops/from_arrow.rs b/src/daft-core/src/array/ops/from_arrow.rs index a635fe6e21..1739b524a9 100644 --- a/src/daft-core/src/array/ops/from_arrow.rs +++ b/src/daft-core/src/array/ops/from_arrow.rs @@ -21,7 +21,7 @@ where impl FromArrow for DataArray { fn from_arrow(field: FieldRef, arrow_arr: Box) -> DaftResult { - DataArray::::try_from((field.clone(), arrow_arr)) + Self::try_from((field.clone(), arrow_arr)) } } @@ -42,7 +42,7 @@ where data_array_field, physical_arrow_arr, )?; - Ok(LogicalArray::::new(field.clone(), physical)) + Ok(Self::new(field.clone(), physical)) } } @@ -57,7 +57,7 @@ impl FromArrow for FixedSizeListArray { let arrow_arr = arrow_arr.as_ref().as_any().downcast_ref::().unwrap(); let arrow_child_array = arrow_arr.values(); let child_series = Series::from_arrow(Arc::new(Field::new("item", daft_child_dtype.as_ref().clone())), arrow_child_array.clone())?; - Ok(FixedSizeListArray::new( + Ok(Self::new( field.clone(), child_series, arrow_arr.validity().cloned(), @@ -91,7 +91,7 @@ impl FromArrow for ListArray { Arc::new(Field::new("list", daft_child_dtype.as_ref().clone())), arrow_child_array.clone(), )?; - Ok(ListArray::new( + Ok(Self::new( field.clone(), child_series, arrow_arr.offsets().clone(), @@ -108,7 +108,7 @@ impl FromArrow for ListArray { 
Arc::new(Field::new("map", daft_child_dtype.as_ref().clone())), arrow_child_array.clone(), )?; - Ok(ListArray::new( + Ok(Self::new( field.clone(), child_series, map_arr.offsets().into(), @@ -138,7 +138,7 @@ impl FromArrow for StructArray { Series::from_arrow(Arc::new(daft_field.clone()), arrow_arr.to_boxed()) }).collect::>>()?; - Ok(StructArray::new( + Ok(Self::new( field.clone(), child_series, arrow_arr.validity().cloned(), diff --git a/src/daft-core/src/array/ops/full.rs b/src/daft-core/src/array/ops/full.rs index 9950116cce..ac65be6a7a 100644 --- a/src/daft-core/src/array/ops/full.rs +++ b/src/daft-core/src/array/ops/full.rs @@ -36,7 +36,7 @@ where if dtype.is_python() { let py_none = Python::with_gil(|py: Python| py.None()); - return DataArray::new( + return Self::new( field.into(), Box::new(PseudoArrowArray::from_pyobj_vec(vec![py_none; length])), ) @@ -45,7 +45,7 @@ where let arrow_dtype = dtype.to_arrow(); match arrow_dtype { - Ok(arrow_dtype) => DataArray::::new( + Ok(arrow_dtype) => Self::new( Arc::new(Field::new(name.to_string(), dtype.clone())), arrow2::array::new_null_array(arrow_dtype, length), ) @@ -58,7 +58,7 @@ where let field = Field::new(name, dtype.clone()); #[cfg(feature = "python")] if dtype.is_python() { - return DataArray::new( + return Self::new( field.into(), Box::new(PseudoArrowArray::from_pyobj_vec(vec![])), ) @@ -67,7 +67,7 @@ where let arrow_dtype = dtype.to_arrow(); match arrow_dtype { - Ok(arrow_dtype) => DataArray::::new( + Ok(arrow_dtype) => Self::new( Arc::new(Field::new(name.to_string(), dtype.clone())), arrow2::array::new_empty_array(arrow_dtype), ) diff --git a/src/daft-core/src/array/ops/if_else.rs b/src/daft-core/src/array/ops/if_else.rs index b92db36528..8981ac2e1f 100644 --- a/src/daft-core/src/array/ops/if_else.rs +++ b/src/daft-core/src/array/ops/if_else.rs @@ -115,13 +115,9 @@ fn generic_if_else( impl DataArray where T: DaftPhysicalType, - DataArray: GrowableArray + IntoSeries, + Self: GrowableArray + IntoSeries, { - pub fn if_else( - &self, - other: &DataArray, - predicate: &BooleanArray, - ) -> DaftResult> { + pub fn if_else(&self, other: &Self, predicate: &BooleanArray) -> DaftResult { generic_if_else( predicate, self.name(), @@ -131,7 +127,7 @@ where self.len(), other.len(), )? - .downcast::>() + .downcast::() .cloned() } } diff --git a/src/daft-core/src/array/ops/is_in.rs b/src/daft-core/src/array/ops/is_in.rs index 304ec2df4b..24e78e8f29 100644 --- a/src/daft-core/src/array/ops/is_in.rs +++ b/src/daft-core/src/array/ops/is_in.rs @@ -24,7 +24,7 @@ macro_rules! 
collect_to_set_and_check_membership { }}; } -impl DaftIsIn<&DataArray> for DataArray +impl DaftIsIn<&Self> for DataArray where T: DaftIntegerType, ::Native: Ord, @@ -33,7 +33,7 @@ where { type Output = DaftResult; - fn is_in(&self, rhs: &DataArray) -> Self::Output { + fn is_in(&self, rhs: &Self) -> Self::Output { collect_to_set_and_check_membership!(self, rhs) } } @@ -76,10 +76,10 @@ impl_is_in_non_numeric_array!(Utf8Array); impl_is_in_non_numeric_array!(BinaryArray); impl_is_in_non_numeric_array!(FixedSizeBinaryArray); -impl DaftIsIn<&NullArray> for NullArray { +impl DaftIsIn<&Self> for NullArray { type Output = DaftResult; - fn is_in(&self, _rhs: &NullArray) -> Self::Output { + fn is_in(&self, _rhs: &Self) -> Self::Output { // If self and rhs are null array then return a full null array Ok(BooleanArray::full_null( self.name(), diff --git a/src/daft-core/src/array/ops/list_agg.rs b/src/daft-core/src/array/ops/list_agg.rs index 6e47d011ac..0792a17675 100644 --- a/src/daft-core/src/array/ops/list_agg.rs +++ b/src/daft-core/src/array/ops/list_agg.rs @@ -13,8 +13,8 @@ use crate::{ impl DaftListAggable for DataArray where T: DaftArrowBackedType, - DataArray: IntoSeries, - DataArray: GrowableArray, + Self: IntoSeries, + Self: GrowableArray, { type Output = DaftResult; fn list(&self) -> Self::Output { @@ -60,7 +60,7 @@ where #[cfg(feature = "python")] impl DaftListAggable for crate::datatypes::PythonArray { - type Output = DaftResult; + type Output = DaftResult; fn list(&self) -> Self::Output { use pyo3::{prelude::*, types::PyList}; @@ -97,7 +97,7 @@ impl DaftListAggable for crate::datatypes::PythonArray { } impl DaftListAggable for ListArray { - type Output = DaftResult; + type Output = DaftResult; fn list(&self) -> Self::Output { // TODO(FixedSizeList) diff --git a/src/daft-core/src/array/ops/sort.rs b/src/daft-core/src/array/ops/sort.rs index d28c920419..ba2d791101 100644 --- a/src/daft-core/src/array/ops/sort.rs +++ b/src/daft-core/src/array/ops/sort.rs @@ -148,7 +148,7 @@ where None, ); - Ok(DataArray::::from((self.name(), Box::new(result)))) + Ok(Self::from((self.name(), Box::new(result)))) } } @@ -240,7 +240,7 @@ impl Float32Array { None, ); - Ok(Float32Array::from((self.name(), Box::new(result)))) + Ok(Self::from((self.name(), Box::new(result)))) } } @@ -332,7 +332,7 @@ impl Float64Array { None, ); - Ok(Float64Array::from((self.name(), Box::new(result)))) + Ok(Self::from((self.name(), Box::new(result)))) } } @@ -462,7 +462,7 @@ impl BooleanArray { let result = arrow2::compute::sort::sort(self.data(), &options, None)?; - BooleanArray::try_from((self.field.clone(), result)) + Self::try_from((self.field.clone(), result)) } } diff --git a/src/daft-core/src/array/ops/take.rs b/src/daft-core/src/array/ops/take.rs index 301a311594..fbd42e8615 100644 --- a/src/daft-core/src/array/ops/take.rs +++ b/src/daft-core/src/array/ops/take.rs @@ -80,7 +80,7 @@ impl FixedSizeBinaryArray { I: DaftIntegerType, ::Native: arrow2::types::Index, { - let mut growable = FixedSizeBinaryArray::make_growable( + let mut growable = Self::make_growable( self.name(), self.data_type(), vec![self], @@ -99,10 +99,7 @@ impl FixedSizeBinaryArray { } } - Ok(growable - .build()? - .downcast::()? 
- .clone()) + Ok(growable.build()?.downcast::()?.clone()) } } @@ -116,7 +113,7 @@ impl crate::datatypes::PythonArray { use arrow2::array::Array; use pyo3::prelude::*; - use crate::{array::pseudo_arrow::PseudoArrowArray, datatypes::PythonType}; + use crate::array::pseudo_arrow::PseudoArrowArray; let indices = idx.as_arrow(); @@ -165,7 +162,7 @@ impl crate::datatypes::PythonArray { let arrow_array: Box = Box::new(PseudoArrowArray::new(new_values.into(), new_validity)); - DataArray::::new(self.field().clone().into(), arrow_array) + Self::new(self.field().clone().into(), arrow_array) } } @@ -175,7 +172,7 @@ impl FixedSizeListArray { I: DaftIntegerType, ::Native: arrow2::types::Index, { - let mut growable = FixedSizeListArray::make_growable( + let mut growable = Self::make_growable( self.name(), self.data_type(), vec![self], @@ -194,7 +191,7 @@ impl FixedSizeListArray { } } - Ok(growable.build()?.downcast::()?.clone()) + Ok(growable.build()?.downcast::()?.clone()) } } @@ -215,7 +212,7 @@ impl ListArray { } }) .sum(); - let mut growable = ::GrowableType::new( + let mut growable = ::GrowableType::new( self.name(), self.data_type(), vec![self], @@ -235,7 +232,7 @@ impl ListArray { } } - Ok(growable.build()?.downcast::()?.clone()) + Ok(growable.build()?.downcast::()?.clone()) } } diff --git a/src/daft-core/src/array/ops/trigonometry.rs b/src/daft-core/src/array/ops/trigonometry.rs index 659e11e39e..673b2e1a38 100644 --- a/src/daft-core/src/array/ops/trigonometry.rs +++ b/src/daft-core/src/array/ops/trigonometry.rs @@ -27,18 +27,18 @@ pub enum TrigonometricFunction { impl TrigonometricFunction { pub fn fn_name(&self) -> &'static str { match self { - TrigonometricFunction::Sin => "sin", - TrigonometricFunction::Cos => "cos", - TrigonometricFunction::Tan => "tan", - TrigonometricFunction::Cot => "cot", - TrigonometricFunction::ArcSin => "arcsin", - TrigonometricFunction::ArcCos => "arccos", - TrigonometricFunction::ArcTan => "arctan", - TrigonometricFunction::Radians => "radians", - TrigonometricFunction::Degrees => "degrees", - TrigonometricFunction::ArcTanh => "arctanh", - TrigonometricFunction::ArcCosh => "arccosh", - TrigonometricFunction::ArcSinh => "arcsinh", + Self::Sin => "sin", + Self::Cos => "cos", + Self::Tan => "tan", + Self::Cot => "cot", + Self::ArcSin => "arcsin", + Self::ArcCos => "arccos", + Self::ArcTan => "arctan", + Self::Radians => "radians", + Self::Degrees => "degrees", + Self::ArcTanh => "arctanh", + Self::ArcCosh => "arccosh", + Self::ArcSinh => "arcsinh", } } } diff --git a/src/daft-core/src/array/ops/truncate.rs b/src/daft-core/src/array/ops/truncate.rs index 83a5c1b0b4..c939cd89ac 100644 --- a/src/daft-core/src/array/ops/truncate.rs +++ b/src/daft-core/src/array/ops/truncate.rs @@ -44,7 +44,7 @@ impl_int_truncate!(UInt32Type); impl_int_truncate!(UInt64Type); impl Decimal128Array { - pub fn iceberg_truncate(&self, w: i64) -> DaftResult { + pub fn iceberg_truncate(&self, w: i64) -> DaftResult { let as_arrow = self.as_arrow(); let trun_value = as_arrow.into_iter().map(|v| { v.map(|i| { @@ -62,17 +62,17 @@ impl Decimal128Array { } impl Utf8Array { - pub fn iceberg_truncate(&self, w: i64) -> DaftResult { + pub fn iceberg_truncate(&self, w: i64) -> DaftResult { let as_arrow = self.as_arrow(); let substring = arrow2::compute::substring::utf8_substring(as_arrow, 0, &Some(w)); - Ok(Utf8Array::from((self.name(), Box::new(substring)))) + Ok(Self::from((self.name(), Box::new(substring)))) } } impl BinaryArray { - pub fn iceberg_truncate(&self, w: i64) -> DaftResult { + pub fn 
iceberg_truncate(&self, w: i64) -> DaftResult { let as_arrow = self.as_arrow(); let substring = arrow2::compute::substring::binary_substring(as_arrow, 0, &Some(w)); - Ok(BinaryArray::from((self.name(), Box::new(substring)))) + Ok(Self::from((self.name(), Box::new(substring)))) } } diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index a8e80623c1..ebac895e20 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -351,7 +351,7 @@ pub struct Utf8NormalizeOptions { } impl Utf8Array { - pub fn endswith(&self, pattern: &Utf8Array) -> DaftResult { + pub fn endswith(&self, pattern: &Self) -> DaftResult { self.binary_broadcasted_compare( pattern, |data: &str, pat: &str| Ok(data.ends_with(pat)), @@ -359,7 +359,7 @@ impl Utf8Array { ) } - pub fn startswith(&self, pattern: &Utf8Array) -> DaftResult { + pub fn startswith(&self, pattern: &Self) -> DaftResult { self.binary_broadcasted_compare( pattern, |data: &str, pat: &str| Ok(data.starts_with(pat)), @@ -367,7 +367,7 @@ impl Utf8Array { ) } - pub fn contains(&self, pattern: &Utf8Array) -> DaftResult { + pub fn contains(&self, pattern: &Self) -> DaftResult { self.binary_broadcasted_compare( pattern, |data: &str, pat: &str| Ok(data.contains(pat)), @@ -375,7 +375,7 @@ impl Utf8Array { ) } - pub fn match_(&self, pattern: &Utf8Array) -> DaftResult { + pub fn match_(&self, pattern: &Self) -> DaftResult { if pattern.len() == 1 { let pattern_scalar_value = pattern.get(0); return match pattern_scalar_value { @@ -403,7 +403,7 @@ impl Utf8Array { ) } - pub fn split(&self, pattern: &Utf8Array, regex: bool) -> DaftResult { + pub fn split(&self, pattern: &Self, regex: bool) -> DaftResult { let (is_full_null, expected_size) = parse_inputs(self, &[pattern]) .map_err(|e| DaftError::ValueError(format!("Error in split: {e}")))?; if is_full_null { @@ -483,18 +483,14 @@ impl Utf8Array { Ok(result) } - pub fn extract(&self, pattern: &Utf8Array, index: usize) -> DaftResult { + pub fn extract(&self, pattern: &Self, index: usize) -> DaftResult { let (is_full_null, expected_size) = parse_inputs(self, &[pattern]) .map_err(|e| DaftError::ValueError(format!("Error in extract: {e}")))?; if is_full_null { - return Ok(Utf8Array::full_null( - self.name(), - &DataType::Utf8, - expected_size, - )); + return Ok(Self::full_null(self.name(), &DataType::Utf8, expected_size)); } if expected_size == 0 { - return Ok(Utf8Array::empty(self.name(), &DataType::Utf8)); + return Ok(Self::empty(self.name(), &DataType::Utf8)); } let self_iter = create_broadcasted_str_iter(self, expected_size); @@ -516,7 +512,7 @@ impl Utf8Array { Ok(result) } - pub fn extract_all(&self, pattern: &Utf8Array, index: usize) -> DaftResult { + pub fn extract_all(&self, pattern: &Self, index: usize) -> DaftResult { let (is_full_null, expected_size) = parse_inputs(self, &[pattern]) .map_err(|e| DaftError::ValueError(format!("Error in extract_all: {e}")))?; if is_full_null { @@ -552,23 +548,14 @@ impl Utf8Array { Ok(result) } - pub fn replace( - &self, - pattern: &Utf8Array, - replacement: &Utf8Array, - regex: bool, - ) -> DaftResult { + pub fn replace(&self, pattern: &Self, replacement: &Self, regex: bool) -> DaftResult { let (is_full_null, expected_size) = parse_inputs(self, &[pattern, replacement]) .map_err(|e| DaftError::ValueError(format!("Error in replace: {e}")))?; if is_full_null { - return Ok(Utf8Array::full_null( - self.name(), - &DataType::Utf8, - expected_size, - )); + return Ok(Self::full_null(self.name(), &DataType::Utf8, 
expected_size)); } if expected_size == 0 { - return Ok(Utf8Array::empty(self.name(), &DataType::Utf8)); + return Ok(Self::empty(self.name(), &DataType::Utf8)); } let self_iter = create_broadcasted_str_iter(self, expected_size); @@ -622,27 +609,27 @@ impl Utf8Array { Ok(UInt64Array::from((self.name(), Box::new(arrow_result)))) } - pub fn lower(&self) -> DaftResult { + pub fn lower(&self) -> DaftResult { self.unary_broadcasted_op(|val| val.to_lowercase().into()) } - pub fn upper(&self) -> DaftResult { + pub fn upper(&self) -> DaftResult { self.unary_broadcasted_op(|val| val.to_uppercase().into()) } - pub fn lstrip(&self) -> DaftResult { + pub fn lstrip(&self) -> DaftResult { self.unary_broadcasted_op(|val| val.trim_start().into()) } - pub fn rstrip(&self) -> DaftResult { + pub fn rstrip(&self) -> DaftResult { self.unary_broadcasted_op(|val| val.trim_end().into()) } - pub fn reverse(&self) -> DaftResult { + pub fn reverse(&self) -> DaftResult { self.unary_broadcasted_op(|val| val.chars().rev().collect::().into()) } - pub fn capitalize(&self) -> DaftResult { + pub fn capitalize(&self) -> DaftResult { self.unary_broadcasted_op(|val| { let mut chars = val.chars(); match chars.next() { @@ -658,7 +645,7 @@ impl Utf8Array { }) } - pub fn find(&self, substr: &Utf8Array) -> DaftResult { + pub fn find(&self, substr: &Self) -> DaftResult { let (is_full_null, expected_size) = parse_inputs(self, &[substr]) .map_err(|e| DaftError::ValueError(format!("Error in find: {e}")))?; if is_full_null { @@ -689,7 +676,7 @@ impl Utf8Array { Ok(result) } - pub fn like(&self, pattern: &Utf8Array) -> DaftResult { + pub fn like(&self, pattern: &Self) -> DaftResult { let (is_full_null, expected_size) = parse_inputs(self, &[pattern]) .map_err(|e| DaftError::ValueError(format!("Error in like: {e}")))?; if is_full_null { @@ -734,7 +721,7 @@ impl Utf8Array { Ok(result) } - pub fn ilike(&self, pattern: &Utf8Array) -> DaftResult { + pub fn ilike(&self, pattern: &Self) -> DaftResult { let (is_full_null, expected_size) = parse_inputs(self, &[pattern]) .map_err(|e| DaftError::ValueError(format!("Error in ilike: {e}")))?; if is_full_null { @@ -781,7 +768,7 @@ impl Utf8Array { Ok(result) } - pub fn left(&self, nchars: &DataArray) -> DaftResult + pub fn left(&self, nchars: &DataArray) -> DaftResult where I: DaftIntegerType, ::Native: Ord, @@ -789,14 +776,10 @@ impl Utf8Array { let (is_full_null, expected_size) = parse_inputs(self, &[nchars]) .map_err(|e| DaftError::ValueError(format!("Error in left: {e}")))?; if is_full_null { - return Ok(Utf8Array::full_null( - self.name(), - &DataType::Utf8, - expected_size, - )); + return Ok(Self::full_null(self.name(), &DataType::Utf8, expected_size)); } if expected_size == 0 { - return Ok(Utf8Array::empty(self.name(), &DataType::Utf8)); + return Ok(Self::empty(self.name(), &DataType::Utf8)); } fn left_most_chars(val: &str, n: usize) -> &str { @@ -808,7 +791,7 @@ impl Utf8Array { } let self_iter = create_broadcasted_str_iter(self, expected_size); - let result: Utf8Array = match nchars.len() { + let result: Self = match nchars.len() { 1 => { let n = nchars.get(0).unwrap(); let n: usize = NumCast::from(n).ok_or_else(|| { @@ -819,7 +802,7 @@ impl Utf8Array { let arrow_result = self_iter .map(|val| Some(left_most_chars(val?, n))) .collect::>(); - Utf8Array::from((self.name(), Box::new(arrow_result))) + Self::from((self.name(), Box::new(arrow_result))) } _ => { let arrow_result = self_iter @@ -837,14 +820,14 @@ impl Utf8Array { }) .collect::>>()?; - Utf8Array::from((self.name(), 
Box::new(arrow_result))) + Self::from((self.name(), Box::new(arrow_result))) } }; assert_eq!(result.len(), expected_size); Ok(result) } - pub fn right(&self, nchars: &DataArray) -> DaftResult + pub fn right(&self, nchars: &DataArray) -> DaftResult where I: DaftIntegerType, ::Native: Ord, @@ -852,14 +835,10 @@ impl Utf8Array { let (is_full_null, expected_size) = parse_inputs(self, &[nchars]) .map_err(|e| DaftError::ValueError(format!("Error in right: {e}")))?; if is_full_null { - return Ok(Utf8Array::full_null( - self.name(), - &DataType::Utf8, - expected_size, - )); + return Ok(Self::full_null(self.name(), &DataType::Utf8, expected_size)); } if expected_size == 0 { - return Ok(Utf8Array::empty(self.name(), &DataType::Utf8)); + return Ok(Self::empty(self.name(), &DataType::Utf8)); } fn right_most_chars(val: &str, nchar: usize) -> &str { @@ -872,7 +851,7 @@ impl Utf8Array { } let self_iter = create_broadcasted_str_iter(self, expected_size); - let result: Utf8Array = match nchars.len() { + let result: Self = match nchars.len() { 1 => { let n = nchars.get(0).unwrap(); let n: usize = NumCast::from(n).ok_or_else(|| { @@ -883,7 +862,7 @@ impl Utf8Array { let arrow_result = self_iter .map(|val| Some(right_most_chars(val?, n))) .collect::>(); - Utf8Array::from((self.name(), Box::new(arrow_result))) + Self::from((self.name(), Box::new(arrow_result))) } _ => { let arrow_result = self_iter @@ -901,7 +880,7 @@ impl Utf8Array { }) .collect::>>()?; - Utf8Array::from((self.name(), Box::new(arrow_result))) + Self::from((self.name(), Box::new(arrow_result))) } }; assert_eq!(result.len(), expected_size); @@ -993,7 +972,7 @@ impl Utf8Array { Ok(result) } - pub fn repeat(&self, n: &DataArray) -> DaftResult + pub fn repeat(&self, n: &DataArray) -> DaftResult where I: DaftIntegerType, ::Native: Ord, @@ -1001,19 +980,15 @@ impl Utf8Array { let (is_full_null, expected_size) = parse_inputs(self, &[n]) .map_err(|e| DaftError::ValueError(format!("Error in repeat: {e}")))?; if is_full_null { - return Ok(Utf8Array::full_null( - self.name(), - &DataType::Utf8, - expected_size, - )); + return Ok(Self::full_null(self.name(), &DataType::Utf8, expected_size)); } if expected_size == 0 { - return Ok(Utf8Array::empty(self.name(), &DataType::Utf8)); + return Ok(Self::empty(self.name(), &DataType::Utf8)); } let self_iter = create_broadcasted_str_iter(self, expected_size); - let result: Utf8Array = match n.len() { + let result: Self = match n.len() { 1 => { let n = n.get(0).unwrap(); let n: usize = NumCast::from(n).ok_or_else(|| { @@ -1024,7 +999,7 @@ impl Utf8Array { let arrow_result = self_iter .map(|val| Some(val?.repeat(n))) .collect::>(); - Utf8Array::from((self.name(), Box::new(arrow_result))) + Self::from((self.name(), Box::new(arrow_result))) } _ => { let arrow_result = self_iter @@ -1042,7 +1017,7 @@ impl Utf8Array { }) .collect::>>()?; - Utf8Array::from((self.name(), Box::new(arrow_result))) + Self::from((self.name(), Box::new(arrow_result))) } }; @@ -1054,7 +1029,7 @@ impl Utf8Array { &self, start: &DataArray, length: Option<&DataArray>, - ) -> DaftResult + ) -> DaftResult where I: DaftIntegerType, ::Native: Ord, @@ -1066,7 +1041,7 @@ impl Utf8Array { .map_err(|e| DaftError::ValueError(format!("Error in substr: {e}")))?; if is_full_null { - return Ok(Utf8Array::full_null(name, &DataType::Utf8, expected_size)); + return Ok(Self::full_null(name, &DataType::Utf8, expected_size)); } let self_iter = create_broadcasted_str_iter(self, expected_size); @@ -1172,9 +1147,9 @@ impl Utf8Array { pub fn pad( &self, length: 
&DataArray, - padchar: &Utf8Array, + padchar: &Self, placement: PadPlacement, - ) -> DaftResult + ) -> DaftResult where I: DaftIntegerType, ::Native: Ord, @@ -1216,11 +1191,7 @@ impl Utf8Array { || length.null_count() == length.len() || padchar.null_count() == padchar.len() { - return Ok(Utf8Array::full_null( - self.name(), - &DataType::Utf8, - expected_size, - )); + return Ok(Self::full_null(self.name(), &DataType::Utf8, expected_size)); } fn pad_str( @@ -1260,7 +1231,7 @@ impl Utf8Array { let self_iter = create_broadcasted_str_iter(self, expected_size); let padchar_iter = create_broadcasted_str_iter(padchar, expected_size); - let result: Utf8Array = match length.len() { + let result: Self = match length.len() { 1 => { let len = length.get(0).unwrap(); let len: usize = NumCast::from(len).ok_or_else(|| { @@ -1278,7 +1249,7 @@ impl Utf8Array { }) .collect::>>()?; - Utf8Array::from((self.name(), Box::new(arrow_result))) + Self::from((self.name(), Box::new(arrow_result))) } _ => { let length_iter = length.as_arrow().iter(); @@ -1298,7 +1269,7 @@ impl Utf8Array { }) .collect::>>()?; - Utf8Array::from((self.name(), Box::new(arrow_result))) + Self::from((self.name(), Box::new(arrow_result))) } }; @@ -1343,8 +1314,8 @@ impl Utf8Array { Ok(result) } - pub fn normalize(&self, opts: Utf8NormalizeOptions) -> DaftResult { - Ok(Utf8Array::from_iter( + pub fn normalize(&self, opts: Utf8NormalizeOptions) -> DaftResult { + Ok(Self::from_iter( self.name(), self.as_arrow().iter().map(|maybe_s| { if let Some(s) = maybe_s { @@ -1432,7 +1403,7 @@ impl Utf8Array { Ok(UInt64Array::from_iter(self.name(), iter)) } - fn unary_broadcasted_op(&self, operation: ScalarKernel) -> DaftResult + fn unary_broadcasted_op(&self, operation: ScalarKernel) -> DaftResult where ScalarKernel: Fn(&str) -> Cow<'_, str>, { @@ -1442,7 +1413,7 @@ impl Utf8Array { .map(|val| Some(operation(val?))) .collect::>() .with_validity(self_arrow.validity().cloned()); - Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) + Ok(Self::from((self.name(), Box::new(arrow_result)))) } } diff --git a/src/daft-core/src/array/pseudo_arrow/compute.rs b/src/daft-core/src/array/pseudo_arrow/compute.rs index 8f513a5857..65b11a69c1 100644 --- a/src/daft-core/src/array/pseudo_arrow/compute.rs +++ b/src/daft-core/src/array/pseudo_arrow/compute.rs @@ -26,6 +26,6 @@ impl PseudoArrowArray { let concatenated_validity = Bitmap::from_iter(bitmaps.iter().flat_map(|bitmap| bitmap.iter())); - PseudoArrowArray::new(concatenated_values.into(), Some(concatenated_validity)) + Self::new(concatenated_values.into(), Some(concatenated_validity)) } } diff --git a/src/daft-core/src/array/pseudo_arrow/mod.rs b/src/daft-core/src/array/pseudo_arrow/mod.rs index d54abaac7d..ff1b1ab9fb 100644 --- a/src/daft-core/src/array/pseudo_arrow/mod.rs +++ b/src/daft-core/src/array/pseudo_arrow/mod.rs @@ -298,7 +298,7 @@ impl Array for PseudoArrowArray { } fn with_validity(&self, validity: Option) -> Box { - Box::new(PseudoArrowArray { + Box::new(Self { values: self.values.clone(), validity, }) diff --git a/src/daft-core/src/array/pseudo_arrow/python.rs b/src/daft-core/src/array/pseudo_arrow/python.rs index 0beb5fd950..1bdc73cb2c 100644 --- a/src/daft-core/src/array/pseudo_arrow/python.rs +++ b/src/daft-core/src/array/pseudo_arrow/python.rs @@ -11,7 +11,7 @@ impl PseudoArrowArray { let validity: arrow2::bitmap::Bitmap = Python::with_gil(|py| { arrow2::bitmap::Bitmap::from_iter(pyobj_vec.iter().map(|pyobj| !pyobj.is_none(py))) }); - PseudoArrowArray::new(pyobj_vec.into(), Some(validity)) + 
Self::new(pyobj_vec.into(), Some(validity)) } pub fn to_pyobj_vec(&self) -> Vec { @@ -42,15 +42,10 @@ impl PseudoArrowArray { let (new_values, new_validity): (Vec, Vec) = { lhs.as_any() - .downcast_ref::>() + .downcast_ref::() .unwrap() .iter() - .zip( - rhs.as_any() - .downcast_ref::>() - .unwrap() - .iter(), - ) + .zip(rhs.as_any().downcast_ref::().unwrap().iter()) .zip(predicate.iter()) .map(|((self_val, other_val), pred_val)| match pred_val { None => None, @@ -66,6 +61,6 @@ impl PseudoArrowArray { let new_validity: Option = Some(Bitmap::from_iter(new_validity)); - PseudoArrowArray::new(new_values.into(), new_validity) + Self::new(new_values.into(), new_validity) } } diff --git a/src/daft-core/src/array/serdes.rs b/src/daft-core/src/array/serdes.rs index 9ba3905680..cc908c0dd6 100644 --- a/src/daft-core/src/array/serdes.rs +++ b/src/daft-core/src/array/serdes.rs @@ -27,7 +27,7 @@ where ::Item: serde::Serialize, { fn new(iter: I) -> Self { - IterSer { + Self { iter: RefCell::new(Some(iter)), } } diff --git a/src/daft-core/src/array/struct_array.rs b/src/daft-core/src/array/struct_array.rs index fb0c50fb25..996680ede5 100644 --- a/src/daft-core/src/array/struct_array.rs +++ b/src/daft-core/src/array/struct_array.rs @@ -64,7 +64,7 @@ impl StructArray { ) } - StructArray { + Self { field, children, validity, @@ -108,7 +108,7 @@ impl StructArray { growable .build() - .map(|s| s.downcast::().unwrap().clone()) + .map(|s| s.downcast::().unwrap().clone()) } pub fn len(&self) -> usize { diff --git a/src/daft-core/src/count_mode.rs b/src/daft-core/src/count_mode.rs index 7ef22f452f..0b1ea12368 100644 --- a/src/daft-core/src/count_mode.rs +++ b/src/daft-core/src/count_mode.rs @@ -40,7 +40,7 @@ impl CountMode { impl_bincode_py_state_serialization!(CountMode); impl CountMode { - pub fn iterator() -> std::slice::Iter<'static, CountMode> { + pub fn iterator() -> std::slice::Iter<'static, Self> { static COUNT_MODES: [CountMode; 3] = [CountMode::All, CountMode::Valid, CountMode::Null]; COUNT_MODES.iter() } @@ -51,13 +51,13 @@ impl FromStr for CountMode { fn from_str(count_mode: &str) -> DaftResult { match count_mode { - "all" => Ok(CountMode::All), - "valid" => Ok(CountMode::Valid), - "null" => Ok(CountMode::Null), + "all" => Ok(Self::All), + "valid" => Ok(Self::Valid), + "null" => Ok(Self::Null), _ => Err(DaftError::TypeError(format!( "Count mode {} is not supported; only the following modes are supported: {:?}", count_mode, - CountMode::iterator().as_slice() + Self::iterator().as_slice() ))), } } diff --git a/src/daft-core/src/datatypes/logical.rs b/src/daft-core/src/datatypes/logical.rs index df48b30524..86d84535e1 100644 --- a/src/daft-core/src/datatypes/logical.rs +++ b/src/daft-core/src/datatypes/logical.rs @@ -44,7 +44,7 @@ impl LogicalArrayImpl { &field.dtype.to_physical(), physical.data_type() ); - LogicalArrayImpl { + Self { physical, field, marker_: PhantomData, diff --git a/src/daft-core/src/join.rs b/src/daft-core/src/join.rs index 62746fbfb1..13e682fe14 100644 --- a/src/daft-core/src/join.rs +++ b/src/daft-core/src/join.rs @@ -38,7 +38,7 @@ impl JoinType { impl_bincode_py_state_serialization!(JoinType); impl JoinType { - pub fn iterator() -> std::slice::Iter<'static, JoinType> { + pub fn iterator() -> std::slice::Iter<'static, Self> { static JOIN_TYPES: [JoinType; 6] = [ JoinType::Inner, JoinType::Left, @@ -56,16 +56,16 @@ impl FromStr for JoinType { fn from_str(join_type: &str) -> DaftResult { match join_type { - "inner" => Ok(JoinType::Inner), - "left" => Ok(JoinType::Left), - "right" => 
Ok(JoinType::Right), - "outer" => Ok(JoinType::Outer), - "anti" => Ok(JoinType::Anti), - "semi" => Ok(JoinType::Semi), + "inner" => Ok(Self::Inner), + "left" => Ok(Self::Left), + "right" => Ok(Self::Right), + "outer" => Ok(Self::Outer), + "anti" => Ok(Self::Anti), + "semi" => Ok(Self::Semi), _ => Err(DaftError::TypeError(format!( "Join type {} is not supported; only the following types are supported: {:?}", join_type, - JoinType::iterator().as_slice() + Self::iterator().as_slice() ))), } } @@ -98,7 +98,7 @@ impl JoinStrategy { impl_bincode_py_state_serialization!(JoinStrategy); impl JoinStrategy { - pub fn iterator() -> std::slice::Iter<'static, JoinStrategy> { + pub fn iterator() -> std::slice::Iter<'static, Self> { static JOIN_STRATEGIES: [JoinStrategy; 3] = [ JoinStrategy::Hash, JoinStrategy::SortMerge, @@ -113,13 +113,13 @@ impl FromStr for JoinStrategy { fn from_str(join_strategy: &str) -> DaftResult { match join_strategy { - "hash" => Ok(JoinStrategy::Hash), - "sort_merge" => Ok(JoinStrategy::SortMerge), - "broadcast" => Ok(JoinStrategy::Broadcast), + "hash" => Ok(Self::Hash), + "sort_merge" => Ok(Self::SortMerge), + "broadcast" => Ok(Self::Broadcast), _ => Err(DaftError::TypeError(format!( "Join strategy {} is not supported; only the following strategies are supported: {:?}", join_strategy, - JoinStrategy::iterator().as_slice() + Self::iterator().as_slice() ))), } } diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 5764d5d610..f57bf3f829 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -299,7 +299,7 @@ impl PySeries { Ok(self.series.argsort(descending)?.into()) } - pub fn hash(&self, seed: Option) -> PyResult { + pub fn hash(&self, seed: Option) -> PyResult { let seed_series; let mut seed_array = None; if let Some(s) = seed { @@ -710,7 +710,7 @@ impl PySeries { impl From for PySeries { fn from(value: series::Series) -> Self { - PySeries { series: value } + Self { series: value } } } diff --git a/src/daft-core/src/series/array_impl/data_array.rs b/src/daft-core/src/series/array_impl/data_array.rs index f1cac0d31b..c210d5cdb2 100644 --- a/src/daft-core/src/series/array_impl/data_array.rs +++ b/src/daft-core/src/series/array_impl/data_array.rs @@ -18,7 +18,7 @@ use crate::{ impl IntoSeries for DataArray where - ArrayWrapper>: SeriesLike, + ArrayWrapper: SeriesLike, { fn into_series(self) -> Series { Series { diff --git a/src/daft-core/src/series/array_impl/logical_array.rs b/src/daft-core/src/series/array_impl/logical_array.rs index bec7f069f7..9076907579 100644 --- a/src/daft-core/src/series/array_impl/logical_array.rs +++ b/src/daft-core/src/series/array_impl/logical_array.rs @@ -11,7 +11,7 @@ use crate::{ impl IntoSeries for LogicalArray where L: DaftLogicalType, - ArrayWrapper>: SeriesLike, + ArrayWrapper: SeriesLike, { fn into_series(self) -> Series { Series { diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs index 92384296d5..0aa91d281c 100644 --- a/src/daft-core/src/series/mod.rs +++ b/src/daft-core/src/series/mod.rs @@ -87,7 +87,7 @@ impl Series { pub fn field(&self) -> &Field { self.inner.field() } - pub fn as_physical(&self) -> DaftResult { + pub fn as_physical(&self) -> DaftResult { let physical_dtype = self.data_type().to_physical(); if &physical_dtype == self.data_type() { Ok(self.clone()) @@ -108,7 +108,7 @@ impl Series { ) } - pub fn with_validity(&self, validity: Option) -> DaftResult { + pub fn with_validity(&self, validity: Option) -> DaftResult { 
self.inner.with_validity(validity) } diff --git a/src/daft-core/src/series/ops/abs.rs b/src/daft-core/src/series/ops/abs.rs index 1ea9b47587..6a32bc65bb 100644 --- a/src/daft-core/src/series/ops/abs.rs +++ b/src/daft-core/src/series/ops/abs.rs @@ -6,7 +6,7 @@ use crate::{ }; impl Series { - pub fn abs(&self) -> DaftResult { + pub fn abs(&self) -> DaftResult { match self.data_type() { DataType::Int8 => Ok(self.i8().unwrap().abs()?.into_series()), DataType::Int16 => Ok(self.i16().unwrap().abs()?.into_series()), diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index 353c6ca25d..541fe5c556 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -14,7 +14,7 @@ use crate::{ }; impl Series { - pub fn count(&self, groups: Option<&GroupIndices>, mode: CountMode) -> DaftResult { + pub fn count(&self, groups: Option<&GroupIndices>, mode: CountMode) -> DaftResult { use crate::array::ops::DaftCountAggable; let s = self.as_physical()?; with_match_physical_daft_types!(s.data_type(), |$T| { @@ -25,7 +25,7 @@ impl Series { }) } - pub fn sum(&self, groups: Option<&GroupIndices>) -> DaftResult { + pub fn sum(&self, groups: Option<&GroupIndices>) -> DaftResult { use crate::{array::ops::DaftSumAggable, datatypes::DataType::*}; match self.data_type() { @@ -94,7 +94,7 @@ impl Series { } } - pub fn approx_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult { + pub fn approx_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult { use crate::{array::ops::DaftApproxSketchAggable, datatypes::DataType::*}; // Upcast all numeric types to float64 and compute approx_sketch. @@ -119,7 +119,7 @@ impl Series { } } - pub fn merge_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult { + pub fn merge_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult { use crate::{array::ops::DaftMergeSketchAggable, datatypes::DataType::*}; match self.data_type() { @@ -138,7 +138,7 @@ impl Series { } } - pub fn hll_merge(&self, groups: Option<&GroupIndices>) -> DaftResult { + pub fn hll_merge(&self, groups: Option<&GroupIndices>) -> DaftResult { let downcasted_self = self.downcast::()?; let series = match groups { Some(groups) => downcasted_self.grouped_hll_merge(groups), @@ -148,7 +148,7 @@ impl Series { Ok(series) } - pub fn mean(&self, groups: Option<&GroupIndices>) -> DaftResult { + pub fn mean(&self, groups: Option<&GroupIndices>) -> DaftResult { use crate::{array::ops::DaftMeanAggable, datatypes::DataType::*}; // Upcast all numeric types to float64 and use f64 mean kernel. @@ -169,19 +169,15 @@ impl Series { } } - pub fn min(&self, groups: Option<&GroupIndices>) -> DaftResult { + pub fn min(&self, groups: Option<&GroupIndices>) -> DaftResult { self.inner.min(groups) } - pub fn max(&self, groups: Option<&GroupIndices>) -> DaftResult { + pub fn max(&self, groups: Option<&GroupIndices>) -> DaftResult { self.inner.max(groups) } - pub fn any_value( - &self, - groups: Option<&GroupIndices>, - ignore_nulls: bool, - ) -> DaftResult { + pub fn any_value(&self, groups: Option<&GroupIndices>, ignore_nulls: bool) -> DaftResult { let indices = match groups { Some(groups) => { if self.data_type().is_null() { @@ -212,17 +208,17 @@ impl Series { } }; - self.take(&Series::from_arrow( + self.take(&Self::from_arrow( Field::new("", DataType::UInt64).into(), indices, )?) 
} - pub fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult { + pub fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult { self.inner.agg_list(groups) } - pub fn agg_concat(&self, groups: Option<&GroupIndices>) -> DaftResult { + pub fn agg_concat(&self, groups: Option<&GroupIndices>) -> DaftResult { use crate::array::ops::DaftConcatAggable; match self.data_type() { DataType::List(..) => { diff --git a/src/daft-core/src/series/ops/between.rs b/src/daft-core/src/series/ops/between.rs index 55f9c2e283..4e3d8c89d5 100644 --- a/src/daft-core/src/series/ops/between.rs +++ b/src/daft-core/src/series/ops/between.rs @@ -10,7 +10,7 @@ use crate::{ }; impl Series { - pub fn between(&self, lower: &Series, upper: &Series) -> DaftResult { + pub fn between(&self, lower: &Self, upper: &Self) -> DaftResult { let (_output_type, _intermediate, lower_comp_type) = InferDataType::from(self.data_type()) .comparison_op(&InferDataType::from(lower.data_type()))?; let (_output_type, _intermediate, upper_comp_type) = InferDataType::from(self.data_type()) @@ -29,11 +29,7 @@ impl Series { .downcast::()? .clone() .into_series()), - DataType::Null => Ok(Series::full_null( - self.name(), - &DataType::Boolean, - self.len(), - )), + DataType::Null => Ok(Self::full_null(self.name(), &DataType::Boolean, self.len())), _ => with_match_numeric_daft_types!(comp_type, |$T| { let casted_value = it_value.cast(&comp_type)?; let casted_lower = it_lower.cast(&comp_type)?; diff --git a/src/daft-core/src/series/ops/broadcast.rs b/src/daft-core/src/series/ops/broadcast.rs index 809364c4b8..e0bd28c9ea 100644 --- a/src/daft-core/src/series/ops/broadcast.rs +++ b/src/daft-core/src/series/ops/broadcast.rs @@ -2,7 +2,7 @@ use common_error::DaftResult; use crate::series::Series; impl Series { - pub fn broadcast(&self, num: usize) -> DaftResult { + pub fn broadcast(&self, num: usize) -> DaftResult { self.inner.broadcast(num) } } diff --git a/src/daft-core/src/series/ops/cast.rs b/src/daft-core/src/series/ops/cast.rs index 3d31edddf4..c83f24b33b 100644 --- a/src/daft-core/src/series/ops/cast.rs +++ b/src/daft-core/src/series/ops/cast.rs @@ -3,7 +3,7 @@ use common_error::DaftResult; use crate::{datatypes::DataType, series::Series}; impl Series { - pub fn cast(&self, datatype: &DataType) -> DaftResult { + pub fn cast(&self, datatype: &DataType) -> DaftResult { self.inner.cast(datatype) } } diff --git a/src/daft-core/src/series/ops/cbrt.rs b/src/daft-core/src/series/ops/cbrt.rs index 3d6cb2b8de..8eb1a46758 100644 --- a/src/daft-core/src/series/ops/cbrt.rs +++ b/src/daft-core/src/series/ops/cbrt.rs @@ -6,7 +6,7 @@ use crate::{ }; impl Series { - pub fn cbrt(&self) -> DaftResult { + pub fn cbrt(&self) -> DaftResult { let casted_dtype = self.to_floating_data_type()?; let casted_self = self .cast(&casted_dtype) diff --git a/src/daft-core/src/series/ops/ceil.rs b/src/daft-core/src/series/ops/ceil.rs index 0097acc353..e65e416fe1 100644 --- a/src/daft-core/src/series/ops/ceil.rs +++ b/src/daft-core/src/series/ops/ceil.rs @@ -6,7 +6,7 @@ use crate::{ }; impl Series { - pub fn ceil(&self) -> DaftResult { + pub fn ceil(&self) -> DaftResult { match self.data_type() { DataType::Int8 | DataType::Int16 diff --git a/src/daft-core/src/series/ops/comparison.rs b/src/daft-core/src/series/ops/comparison.rs index df411a6c9f..67ac7c66ec 100644 --- a/src/daft-core/src/series/ops/comparison.rs +++ b/src/daft-core/src/series/ops/comparison.rs @@ -14,7 +14,7 @@ macro_rules! 
call_inner { }; } -impl DaftCompare<&Series> for Series { +impl DaftCompare<&Self> for Series { type Output = DaftResult; call_inner!(equal); @@ -25,8 +25,8 @@ impl DaftCompare<&Series> for Series { call_inner!(gte); } -impl DaftLogical<&Series> for Series { - type Output = DaftResult; +impl DaftLogical<&Self> for Series { + type Output = DaftResult; call_inner!(and); call_inner!(or); diff --git a/src/daft-core/src/series/ops/concat.rs b/src/daft-core/src/series/ops/concat.rs index c83038ef82..9103255faf 100644 --- a/src/daft-core/src/series/ops/concat.rs +++ b/src/daft-core/src/series/ops/concat.rs @@ -6,7 +6,7 @@ use crate::{ }; impl Series { - pub fn concat(series: &[&Series]) -> DaftResult { + pub fn concat(series: &[&Self]) -> DaftResult { if series.is_empty() { return Err(DaftError::ValueError( "Need at least 1 series to perform concat".to_string(), diff --git a/src/daft-core/src/series/ops/exp.rs b/src/daft-core/src/series/ops/exp.rs index 95a43fc3c2..915e8baefe 100644 --- a/src/daft-core/src/series/ops/exp.rs +++ b/src/daft-core/src/series/ops/exp.rs @@ -6,7 +6,7 @@ use crate::{ }; impl Series { - pub fn exp(&self) -> DaftResult { + pub fn exp(&self) -> DaftResult { match self.data_type() { DataType::Float32 => Ok(self.f32().unwrap().exp()?.into_series()), DataType::Float64 => Ok(self.f64().unwrap().exp()?.into_series()), diff --git a/src/daft-core/src/series/ops/filter.rs b/src/daft-core/src/series/ops/filter.rs index 06488ac92e..15847237d1 100644 --- a/src/daft-core/src/series/ops/filter.rs +++ b/src/daft-core/src/series/ops/filter.rs @@ -3,7 +3,7 @@ use common_error::{DaftError, DaftResult}; use crate::{datatypes::BooleanArray, series::Series}; impl Series { - pub fn filter(&self, mask: &BooleanArray) -> DaftResult { + pub fn filter(&self, mask: &BooleanArray) -> DaftResult { match (self.len(), mask.len()) { (_, 1) => { if Some(true) == mask.get(0) { diff --git a/src/daft-core/src/series/ops/float.rs b/src/daft-core/src/series/ops/float.rs index 92fc9c48b7..c2da199591 100644 --- a/src/daft-core/src/series/ops/float.rs +++ b/src/daft-core/src/series/ops/float.rs @@ -6,21 +6,21 @@ use crate::{ }; impl Series { - pub fn is_nan(&self) -> DaftResult { + pub fn is_nan(&self) -> DaftResult { use crate::array::ops::DaftIsNan; with_match_float_and_null_daft_types!(self.data_type(), |$T| { Ok(DaftIsNan::is_nan(self.downcast::<<$T as DaftDataType>::ArrayType>()?)?.into_series()) }) } - pub fn is_inf(&self) -> DaftResult { + pub fn is_inf(&self) -> DaftResult { use crate::array::ops::DaftIsInf; with_match_float_and_null_daft_types!(self.data_type(), |$T| { Ok(DaftIsInf::is_inf(self.downcast::<<$T as DaftDataType>::ArrayType>()?)?.into_series()) }) } - pub fn not_nan(&self) -> DaftResult { + pub fn not_nan(&self) -> DaftResult { use crate::array::ops::DaftNotNan; with_match_float_and_null_daft_types!(self.data_type(), |$T| { Ok(DaftNotNan::not_nan(self.downcast::<<$T as DaftDataType>::ArrayType>()?)?.into_series()) diff --git a/src/daft-core/src/series/ops/floor.rs b/src/daft-core/src/series/ops/floor.rs index c574c3da3e..f216c3412d 100644 --- a/src/daft-core/src/series/ops/floor.rs +++ b/src/daft-core/src/series/ops/floor.rs @@ -6,7 +6,7 @@ use crate::{ }; impl Series { - pub fn floor(&self) -> DaftResult { + pub fn floor(&self) -> DaftResult { match self.data_type() { DataType::Int8 | DataType::Int16 diff --git a/src/daft-core/src/series/ops/if_else.rs b/src/daft-core/src/series/ops/if_else.rs index 26cc553f54..fa573d4f40 100644 --- a/src/daft-core/src/series/ops/if_else.rs +++ 
b/src/daft-core/src/series/ops/if_else.rs @@ -4,7 +4,7 @@ use super::cast_series_to_supertype; use crate::series::Series; impl Series { - pub fn if_else(&self, other: &Series, predicate: &Series) -> DaftResult { + pub fn if_else(&self, other: &Self, predicate: &Self) -> DaftResult { let casted_series = cast_series_to_supertype(&[self, other])?; assert!(casted_series.len() == 2); diff --git a/src/daft-core/src/series/ops/is_in.rs b/src/daft-core/src/series/ops/is_in.rs index 7acc601a37..7b2386b745 100644 --- a/src/daft-core/src/series/ops/is_in.rs +++ b/src/daft-core/src/series/ops/is_in.rs @@ -14,7 +14,7 @@ fn default(name: &str, size: usize) -> DaftResult { } impl Series { - pub fn is_in(&self, items: &Self) -> DaftResult { + pub fn is_in(&self, items: &Self) -> DaftResult { if items.is_empty() { return default(self.name(), self.len()); } diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index d00f5440c2..d9a17dd087 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -7,7 +7,7 @@ use crate::{ }; impl Series { - pub fn explode(&self) -> DaftResult { + pub fn explode(&self) -> DaftResult { match self.data_type() { DataType::List(_) => self.list()?.explode(), DataType::FixedSizeList(..) => self.fixed_size_list()?.explode(), @@ -55,7 +55,7 @@ impl Series { } } - pub fn list_get(&self, idx: &Series, default: &Series) -> DaftResult { + pub fn list_get(&self, idx: &Self, default: &Self) -> DaftResult { let idx = idx.cast(&DataType::Int64)?; let idx_arr = idx.i64().unwrap(); @@ -69,7 +69,7 @@ impl Series { } } - pub fn list_slice(&self, start: &Series, end: &Series) -> DaftResult { + pub fn list_slice(&self, start: &Self, end: &Self) -> DaftResult { let start = start.cast(&DataType::Int64)?; let start_arr = start.i64().unwrap(); let end_arr = if end.data_type().is_integer() { @@ -89,7 +89,7 @@ impl Series { } } - pub fn list_chunk(&self, size: usize) -> DaftResult { + pub fn list_chunk(&self, size: usize) -> DaftResult { match self.data_type() { DataType::List(_) => self.list()?.get_chunks(size), DataType::FixedSizeList(..) => self.fixed_size_list()?.get_chunks(size), @@ -99,7 +99,7 @@ impl Series { } } - pub fn list_sum(&self) -> DaftResult { + pub fn list_sum(&self) -> DaftResult { match self.data_type() { DataType::List(_) => self.list()?.sum(), DataType::FixedSizeList(..) => self.fixed_size_list()?.sum(), @@ -110,7 +110,7 @@ impl Series { } } - pub fn list_mean(&self) -> DaftResult { + pub fn list_mean(&self) -> DaftResult { match self.data_type() { DataType::List(_) => self.list()?.mean(), DataType::FixedSizeList(..) => self.fixed_size_list()?.mean(), @@ -121,7 +121,7 @@ impl Series { } } - pub fn list_min(&self) -> DaftResult { + pub fn list_min(&self) -> DaftResult { match self.data_type() { DataType::List(_) => self.list()?.min(), DataType::FixedSizeList(..) => self.fixed_size_list()?.min(), @@ -132,7 +132,7 @@ impl Series { } } - pub fn list_max(&self) -> DaftResult { + pub fn list_max(&self) -> DaftResult { match self.data_type() { DataType::List(_) => self.list()?.max(), DataType::FixedSizeList(..) 
=> self.fixed_size_list()?.max(), @@ -143,7 +143,7 @@ impl Series { } } - pub fn list_sort(&self, desc: &Series) -> DaftResult { + pub fn list_sort(&self, desc: &Self) -> DaftResult { let desc_arr = desc.bool()?; match self.data_type() { diff --git a/src/daft-core/src/series/ops/log.rs b/src/daft-core/src/series/ops/log.rs index 73ded27a88..842fb7cb0a 100644 --- a/src/daft-core/src/series/ops/log.rs +++ b/src/daft-core/src/series/ops/log.rs @@ -6,7 +6,7 @@ use crate::{ }; impl Series { - pub fn log2(&self) -> DaftResult { + pub fn log2(&self) -> DaftResult { match self.data_type() { DataType::Int8 | DataType::Int16 @@ -28,7 +28,7 @@ impl Series { } } - pub fn log10(&self) -> DaftResult { + pub fn log10(&self) -> DaftResult { match self.data_type() { DataType::Int8 | DataType::Int16 @@ -50,7 +50,7 @@ impl Series { } } - pub fn log(&self, base: f64) -> DaftResult { + pub fn log(&self, base: f64) -> DaftResult { match self.data_type() { DataType::Int8 | DataType::Int16 @@ -72,7 +72,7 @@ impl Series { } } - pub fn ln(&self) -> DaftResult { + pub fn ln(&self) -> DaftResult { use crate::series::array_impl::IntoSeries; match self.data_type() { DataType::Int8 diff --git a/src/daft-core/src/series/ops/map.rs b/src/daft-core/src/series/ops/map.rs index 85461b1fe0..b624cd8aac 100644 --- a/src/daft-core/src/series/ops/map.rs +++ b/src/daft-core/src/series/ops/map.rs @@ -3,7 +3,7 @@ use common_error::{DaftError, DaftResult}; use crate::{datatypes::DataType, series::Series}; impl Series { - pub fn map_get(&self, key: &Series) -> DaftResult { + pub fn map_get(&self, key: &Self) -> DaftResult { match self.data_type() { DataType::Map(_) => self.map()?.map_get(key), dt => Err(DaftError::TypeError(format!( diff --git a/src/daft-core/src/series/ops/minhash.rs b/src/daft-core/src/series/ops/minhash.rs index 527dc75fdd..a6a7bb9247 100644 --- a/src/daft-core/src/series/ops/minhash.rs +++ b/src/daft-core/src/series/ops/minhash.rs @@ -7,7 +7,7 @@ use crate::{ }; impl Series { - pub fn minhash(&self, num_hashes: usize, ngram_size: usize, seed: u32) -> DaftResult { + pub fn minhash(&self, num_hashes: usize, ngram_size: usize, seed: u32) -> DaftResult { match self.data_type() { DataType::Utf8 => Ok(self .utf8()? 
diff --git a/src/daft-core/src/series/ops/not.rs b/src/daft-core/src/series/ops/not.rs
index c6a8216614..8372fde4be 100644
--- a/src/daft-core/src/series/ops/not.rs
+++ b/src/daft-core/src/series/ops/not.rs
@@ -16,7 +16,7 @@ impl Not for &Series {
 }
 impl Not for Series {
-    type Output = DaftResult<Series>;
+    type Output = DaftResult<Self>;
     fn not(self) -> Self::Output {
         (&self).not()
     }
 }
diff --git a/src/daft-core/src/series/ops/null.rs b/src/daft-core/src/series/ops/null.rs
index 00df3b5860..d8ffeb0933 100644
--- a/src/daft-core/src/series/ops/null.rs
+++ b/src/daft-core/src/series/ops/null.rs
@@ -3,15 +3,15 @@ use common_error::DaftResult;
 use crate::series::Series;
 
 impl Series {
-    pub fn is_null(&self) -> DaftResult<Series> {
+    pub fn is_null(&self) -> DaftResult<Self> {
         self.inner.is_null()
     }
 
-    pub fn not_null(&self) -> DaftResult<Series> {
+    pub fn not_null(&self) -> DaftResult<Self> {
         self.inner.not_null()
     }
 
-    pub fn fill_null(&self, fill_value: &Series) -> DaftResult<Series> {
+    pub fn fill_null(&self, fill_value: &Self) -> DaftResult<Self> {
         let predicate = self.not_null()?;
         self.if_else(fill_value, &predicate)
     }
diff --git a/src/daft-core/src/series/ops/repeat.rs b/src/daft-core/src/series/ops/repeat.rs
index eddde7e463..1bd1e438cc 100644
--- a/src/daft-core/src/series/ops/repeat.rs
+++ b/src/daft-core/src/series/ops/repeat.rs
@@ -6,6 +6,6 @@ use crate::series::Series;
 impl Series {
     pub(crate) fn repeat(&self, n: usize) -> DaftResult<Self> {
         let many_self = std::iter::repeat(self).take(n).collect_vec();
-        Series::concat(&many_self)
+        Self::concat(&many_self)
     }
 }
diff --git a/src/daft-core/src/series/ops/round.rs b/src/daft-core/src/series/ops/round.rs
index 6f968063fb..8c74bd90c5 100644
--- a/src/daft-core/src/series/ops/round.rs
+++ b/src/daft-core/src/series/ops/round.rs
@@ -6,7 +6,7 @@ use crate::{
 };
 
 impl Series {
-    pub fn round(&self, decimal: i32) -> DaftResult<Series> {
+    pub fn round(&self, decimal: i32) -> DaftResult<Self> {
         match self.data_type() {
             DataType::Int8
             | DataType::Int16
diff --git a/src/daft-core/src/series/ops/shift.rs b/src/daft-core/src/series/ops/shift.rs
index 1ba5275ae5..9f6cfcc38b 100644
--- a/src/daft-core/src/series/ops/shift.rs
+++ b/src/daft-core/src/series/ops/shift.rs
@@ -3,7 +3,7 @@ use common_error::{DaftError, DaftResult};
 use crate::{datatypes::DataType, series::Series};
 
 impl Series {
-    pub fn shift_left(&self, bits: &Self) -> DaftResult<Series> {
+    pub fn shift_left(&self, bits: &Self) -> DaftResult<Self> {
         use crate::series::array_impl::IntoSeries;
         if !bits.data_type().is_integer() {
             return Err(DaftError::TypeError(format!(
@@ -52,7 +52,7 @@ impl Series {
         }
     }
 
-    pub fn shift_right(&self, bits: &Self) -> DaftResult<Series> {
+    pub fn shift_right(&self, bits: &Self) -> DaftResult<Self> {
         use crate::series::array_impl::IntoSeries;
         if !bits.data_type().is_integer() {
             return Err(DaftError::TypeError(format!(
diff --git a/src/daft-core/src/series/ops/sign.rs b/src/daft-core/src/series/ops/sign.rs
index 53ecb67088..aedc027502 100644
--- a/src/daft-core/src/series/ops/sign.rs
+++ b/src/daft-core/src/series/ops/sign.rs
@@ -6,7 +6,7 @@ use crate::{
 };
 
 impl Series {
-    pub fn sign(&self) -> DaftResult<Series> {
+    pub fn sign(&self) -> DaftResult<Self> {
         match self.data_type() {
             DataType::UInt8 => Ok(self.u8().unwrap().sign_unsigned()?.into_series()),
             DataType::UInt16 => Ok(self.u16().unwrap().sign_unsigned()?.into_series()),
diff --git a/src/daft-core/src/series/ops/sketch_percentile.rs b/src/daft-core/src/series/ops/sketch_percentile.rs
index 23e85b6d89..0cebfa739c 100644
--- a/src/daft-core/src/series/ops/sketch_percentile.rs
+++ 
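// The daft-core hunks in this patch all make the same mechanical change: inside an
// `impl` block, the type is spelled `Self` instead of repeating its own name
// (clippy's `use_self` lint). A minimal sketch of the idiom, using a hypothetical
// `Wrapper` type rather than the real `Series`:
struct Wrapper(Vec<i64>);

impl Wrapper {
    // Before: `fn head(&self, n: usize) -> Wrapper`; `Self` stays correct even
    // if the type is later renamed.
    fn head(&self, n: usize) -> Self {
        Self(self.0.iter().copied().take(n).collect())
    }

    // Mirrors the shape of `Series::concat(series: &[&Self]) -> DaftResult<Self>`.
    fn concat(parts: &[&Self]) -> Self {
        Self(parts.iter().flat_map(|p| p.0.iter().copied()).collect())
    }
}

fn main() {
    let w = Wrapper(vec![1, 2, 3]);
    let h = w.head(2);
    assert_eq!(Wrapper::concat(&[&w, &h]).0, vec![1, 2, 3, 1, 2]);
}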
b/src/daft-core/src/series/ops/sketch_percentile.rs @@ -7,7 +7,7 @@ impl Series { &self, percentiles: &[f64], force_list_output: bool, - ) -> DaftResult { + ) -> DaftResult { use crate::datatypes::DataType::*; match self.data_type() { diff --git a/src/daft-core/src/series/ops/sort.rs b/src/daft-core/src/series/ops/sort.rs index 4c591cc744..48ad1288ba 100644 --- a/src/daft-core/src/series/ops/sort.rs +++ b/src/daft-core/src/series/ops/sort.rs @@ -6,7 +6,7 @@ use crate::{ }; impl Series { - pub fn argsort(&self, descending: bool) -> DaftResult { + pub fn argsort(&self, descending: bool) -> DaftResult { let series = self.as_physical()?; with_match_comparable_daft_types!(series.data_type(), |$T| { let downcasted = series.downcast::<<$T as DaftDataType>::ArrayType>()?; @@ -14,7 +14,7 @@ impl Series { }) } - pub fn argsort_multikey(sort_keys: &[Series], descending: &[bool]) -> DaftResult { + pub fn argsort_multikey(sort_keys: &[Self], descending: &[bool]) -> DaftResult { if sort_keys.len() != descending.len() { return Err(DaftError::ValueError(format!( "sort_keys and descending length must match, got {} vs {}", diff --git a/src/daft-core/src/series/ops/sqrt.rs b/src/daft-core/src/series/ops/sqrt.rs index 614f8f9fc2..f127b84d32 100644 --- a/src/daft-core/src/series/ops/sqrt.rs +++ b/src/daft-core/src/series/ops/sqrt.rs @@ -6,7 +6,7 @@ use crate::{ }; impl Series { - pub fn sqrt(&self) -> DaftResult { + pub fn sqrt(&self) -> DaftResult { let casted_dtype = self.to_floating_data_type()?; let casted_self = self .cast(&casted_dtype) diff --git a/src/daft-core/src/series/ops/struct_.rs b/src/daft-core/src/series/ops/struct_.rs index 4dabb8b0f3..2c5421e13a 100644 --- a/src/daft-core/src/series/ops/struct_.rs +++ b/src/daft-core/src/series/ops/struct_.rs @@ -3,7 +3,7 @@ use common_error::{DaftError, DaftResult}; use crate::{datatypes::DataType, series::Series}; impl Series { - pub fn struct_get(&self, name: &str) -> DaftResult { + pub fn struct_get(&self, name: &str) -> DaftResult { match self.data_type() { DataType::Struct(_) => self.struct_()?.get(name), dt => Err(DaftError::TypeError(format!( diff --git a/src/daft-core/src/series/ops/take.rs b/src/daft-core/src/series/ops/take.rs index 9ff218757e..80ec81f958 100644 --- a/src/daft-core/src/series/ops/take.rs +++ b/src/daft-core/src/series/ops/take.rs @@ -8,19 +8,19 @@ use crate::{ }; impl Series { - pub fn head(&self, num: usize) -> DaftResult { + pub fn head(&self, num: usize) -> DaftResult { if num >= self.len() { return Ok(self.clone()); } self.inner.head(num) } - pub fn slice(&self, start: usize, end: usize) -> DaftResult { + pub fn slice(&self, start: usize, end: usize) -> DaftResult { let l = self.len(); self.inner.slice(start.min(l), end.min(l)) } - pub fn take(&self, idx: &Series) -> DaftResult { + pub fn take(&self, idx: &Self) -> DaftResult { self.inner.take(idx) } diff --git a/src/daft-core/src/series/ops/utf8.rs b/src/daft-core/src/series/ops/utf8.rs index d4fe19bde3..64d112a984 100644 --- a/src/daft-core/src/series/ops/utf8.rs +++ b/src/daft-core/src/series/ops/utf8.rs @@ -8,10 +8,7 @@ use crate::{ }; impl Series { - pub fn with_utf8_array( - &self, - f: impl Fn(&Utf8Array) -> DaftResult, - ) -> DaftResult { + pub fn with_utf8_array(&self, f: impl Fn(&Utf8Array) -> DaftResult) -> DaftResult { match self.data_type() { DataType::Utf8 => f(self.utf8()?), DataType::Null => Ok(self.clone()), @@ -21,44 +18,44 @@ impl Series { } } - pub fn utf8_endswith(&self, pattern: &Series) -> DaftResult { + pub fn utf8_endswith(&self, pattern: &Self) -> 
DaftResult { self.with_utf8_array(|arr| { pattern.with_utf8_array(|pattern_arr| Ok(arr.endswith(pattern_arr)?.into_series())) }) } - pub fn utf8_startswith(&self, pattern: &Series) -> DaftResult { + pub fn utf8_startswith(&self, pattern: &Self) -> DaftResult { self.with_utf8_array(|arr| { pattern.with_utf8_array(|pattern_arr| Ok(arr.startswith(pattern_arr)?.into_series())) }) } - pub fn utf8_contains(&self, pattern: &Series) -> DaftResult { + pub fn utf8_contains(&self, pattern: &Self) -> DaftResult { self.with_utf8_array(|arr| { pattern.with_utf8_array(|pattern_arr| Ok(arr.contains(pattern_arr)?.into_series())) }) } - pub fn utf8_match(&self, pattern: &Series) -> DaftResult { + pub fn utf8_match(&self, pattern: &Self) -> DaftResult { self.with_utf8_array(|arr| { pattern.with_utf8_array(|pattern_arr| Ok(arr.match_(pattern_arr)?.into_series())) }) } - pub fn utf8_split(&self, pattern: &Series, regex: bool) -> DaftResult { + pub fn utf8_split(&self, pattern: &Self, regex: bool) -> DaftResult { self.with_utf8_array(|arr| { pattern.with_utf8_array(|pattern_arr| Ok(arr.split(pattern_arr, regex)?.into_series())) }) } - pub fn utf8_extract(&self, pattern: &Series, index: usize) -> DaftResult { + pub fn utf8_extract(&self, pattern: &Self, index: usize) -> DaftResult { self.with_utf8_array(|arr| { pattern .with_utf8_array(|pattern_arr| Ok(arr.extract(pattern_arr, index)?.into_series())) }) } - pub fn utf8_extract_all(&self, pattern: &Series, index: usize) -> DaftResult { + pub fn utf8_extract_all(&self, pattern: &Self, index: usize) -> DaftResult { self.with_utf8_array(|arr| { pattern.with_utf8_array(|pattern_arr| { Ok(arr.extract_all(pattern_arr, index)?.into_series()) @@ -68,10 +65,10 @@ impl Series { pub fn utf8_replace( &self, - pattern: &Series, - replacement: &Series, + pattern: &Self, + replacement: &Self, regex: bool, - ) -> DaftResult { + ) -> DaftResult { self.with_utf8_array(|arr| { pattern.with_utf8_array(|pattern_arr| { replacement.with_utf8_array(|replacement_arr| { @@ -83,39 +80,39 @@ impl Series { }) } - pub fn utf8_length(&self) -> DaftResult { + pub fn utf8_length(&self) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.length()?.into_series())) } - pub fn utf8_length_bytes(&self) -> DaftResult { + pub fn utf8_length_bytes(&self) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.length_bytes()?.into_series())) } - pub fn utf8_lower(&self) -> DaftResult { + pub fn utf8_lower(&self) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.lower()?.into_series())) } - pub fn utf8_upper(&self) -> DaftResult { + pub fn utf8_upper(&self) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.upper()?.into_series())) } - pub fn utf8_lstrip(&self) -> DaftResult { + pub fn utf8_lstrip(&self) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.lstrip()?.into_series())) } - pub fn utf8_rstrip(&self) -> DaftResult { + pub fn utf8_rstrip(&self) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.rstrip()?.into_series())) } - pub fn utf8_reverse(&self) -> DaftResult { + pub fn utf8_reverse(&self) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.reverse()?.into_series())) } - pub fn utf8_capitalize(&self) -> DaftResult { + pub fn utf8_capitalize(&self) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.capitalize()?.into_series())) } - pub fn utf8_left(&self, nchars: &Series) -> DaftResult { + pub fn utf8_left(&self, nchars: &Self) -> DaftResult { self.with_utf8_array(|arr| { if nchars.data_type().is_integer() { with_match_integer_daft_types!(nchars.data_type(), |$T| { @@ -132,7 +129,7 @@ impl Series { }) } - 
pub fn utf8_right(&self, nchars: &Series) -> DaftResult { + pub fn utf8_right(&self, nchars: &Self) -> DaftResult { self.with_utf8_array(|arr| { if nchars.data_type().is_integer() { with_match_integer_daft_types!(nchars.data_type(), |$T| { @@ -149,13 +146,13 @@ impl Series { }) } - pub fn utf8_find(&self, substr: &Series) -> DaftResult { + pub fn utf8_find(&self, substr: &Self) -> DaftResult { self.with_utf8_array(|arr| { substr.with_utf8_array(|substr_arr| Ok(arr.find(substr_arr)?.into_series())) }) } - pub fn utf8_lpad(&self, length: &Series, pad: &Series) -> DaftResult { + pub fn utf8_lpad(&self, length: &Self, pad: &Self) -> DaftResult { self.with_utf8_array(|arr| { pad.with_utf8_array(|pad_arr| { if length.data_type().is_integer() { @@ -174,7 +171,7 @@ impl Series { }) } - pub fn utf8_rpad(&self, length: &Series, pad: &Series) -> DaftResult { + pub fn utf8_rpad(&self, length: &Self, pad: &Self) -> DaftResult { self.with_utf8_array(|arr| { pad.with_utf8_array(|pad_arr| { if length.data_type().is_integer() { @@ -193,7 +190,7 @@ impl Series { }) } - pub fn utf8_repeat(&self, n: &Series) -> DaftResult { + pub fn utf8_repeat(&self, n: &Self) -> DaftResult { self.with_utf8_array(|arr| { if n.data_type().is_integer() { with_match_integer_daft_types!(n.data_type(), |$T| { @@ -210,19 +207,19 @@ impl Series { }) } - pub fn utf8_like(&self, pattern: &Series) -> DaftResult { + pub fn utf8_like(&self, pattern: &Self) -> DaftResult { self.with_utf8_array(|arr| { pattern.with_utf8_array(|pattern_arr| Ok(arr.like(pattern_arr)?.into_series())) }) } - pub fn utf8_ilike(&self, pattern: &Series) -> DaftResult { + pub fn utf8_ilike(&self, pattern: &Self) -> DaftResult { self.with_utf8_array(|arr| { pattern.with_utf8_array(|pattern_arr| Ok(arr.ilike(pattern_arr)?.into_series())) }) } - pub fn utf8_substr(&self, start: &Series, length: &Series) -> DaftResult { + pub fn utf8_substr(&self, start: &Self, length: &Self) -> DaftResult { self.with_utf8_array(|arr| { if start.data_type().is_integer() { with_match_integer_daft_types!(start.data_type(), |$T| { @@ -250,24 +247,24 @@ impl Series { }) } - pub fn utf8_to_date(&self, format: &str) -> DaftResult { + pub fn utf8_to_date(&self, format: &str) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.to_date(format)?.into_series())) } - pub fn utf8_to_datetime(&self, format: &str, timezone: Option<&str>) -> DaftResult { + pub fn utf8_to_datetime(&self, format: &str, timezone: Option<&str>) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.to_datetime(format, timezone)?.into_series())) } - pub fn utf8_normalize(&self, opts: Utf8NormalizeOptions) -> DaftResult { + pub fn utf8_normalize(&self, opts: Utf8NormalizeOptions) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.normalize(opts)?.into_series())) } pub fn utf8_count_matches( &self, - patterns: &Series, + patterns: &Self, whole_word: bool, case_sensitive: bool, - ) -> DaftResult { + ) -> DaftResult { self.with_utf8_array(|arr| { patterns.with_utf8_array(|pattern_arr| { Ok(arr diff --git a/src/daft-csv/Cargo.toml b/src/daft-csv/Cargo.toml index d672e30f35..dde511422b 100644 --- a/src/daft-csv/Cargo.toml +++ b/src/daft-csv/Cargo.toml @@ -25,6 +25,9 @@ rstest = {workspace = true} [features] python = ["dep:pyo3", "common-error/python", "common-py-serde/python", "daft-core/python", "daft-io/python", "daft-table/python", "daft-dsl/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-csv" diff --git a/src/daft-csv/src/lib.rs b/src/daft-csv/src/lib.rs index 17d254d520..b49245edab 100644 
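// The utf8.rs hunk above routes every string kernel through `with_utf8_array`,
// i.e. "downcast once, then hand the typed array to a closure". A hedged sketch
// of that shape, with a hypothetical `Column` enum standing in for Series rather
// than the real Daft types:
#[derive(Debug)]
enum Column {
    Utf8(Vec<String>),
    Int64(Vec<i64>),
    Null(usize),
}

#[derive(Debug)]
struct TypeError(String);

impl Column {
    // Utf8 columns run the closure, Null columns pass through untouched, and
    // anything else is a type error -- the same three arms as in utf8.rs.
    fn with_utf8(&self, f: impl Fn(&[String]) -> Vec<String>) -> Result<Self, TypeError> {
        match self {
            Self::Utf8(values) => Ok(Self::Utf8(f(values))),
            Self::Null(len) => Ok(Self::Null(*len)),
            other => Err(TypeError(format!("expected Utf8, got {other:?}"))),
        }
    }

    fn upper(&self) -> Result<Self, TypeError> {
        self.with_utf8(|values| values.iter().map(|v| v.to_uppercase()).collect())
    }
}

fn main() {
    println!("{:?}", Column::Utf8(vec!["a".into(), "b".into()]).upper());
    println!("{:?}", Column::Int64(vec![1, 2]).upper());
}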
--- a/src/daft-csv/src/lib.rs +++ b/src/daft-csv/src/lib.rs @@ -43,17 +43,17 @@ pub enum Error { } impl From for DaftError { - fn from(err: Error) -> DaftError { + fn from(err: Error) -> Self { match err { Error::IOError { source } => source.into(), - _ => DaftError::External(err.into()), + _ => Self::External(err.into()), } } } impl From for Error { fn from(err: daft_io::Error) -> Self { - Error::IOError { source: err } + Self::IOError { source: err } } } diff --git a/src/daft-decoding/Cargo.toml b/src/daft-decoding/Cargo.toml index 89a17866c2..7874436f04 100644 --- a/src/daft-decoding/Cargo.toml +++ b/src/daft-decoding/Cargo.toml @@ -6,6 +6,9 @@ csv-async = "1.2.6" fast-float = "0.2.0" simdutf8 = "0.1.3" +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-decoding" diff --git a/src/daft-dsl/Cargo.toml b/src/daft-dsl/Cargo.toml index e9a236d6dd..cc72281e2e 100644 --- a/src/daft-dsl/Cargo.toml +++ b/src/daft-dsl/Cargo.toml @@ -19,6 +19,9 @@ typetag = "0.2.16" python = ["dep:pyo3", "common-error/python", "daft-core/python", "common-treenode/python", "common-py-serde/python", "common-resource-request/python"] test-utils = [] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-dsl" diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index f8c5deb247..48249355fc 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -403,7 +403,7 @@ impl AggExpr { } } - pub fn from_name_and_child_expr(name: &str, child: ExprRef) -> DaftResult { + pub fn from_name_and_child_expr(name: &str, child: ExprRef) -> DaftResult { use AggExpr::*; match name { "count" => Ok(Count(child, CountMode::Valid)), @@ -422,12 +422,12 @@ impl AggExpr { impl From<&AggExpr> for ExprRef { fn from(agg_expr: &AggExpr) -> Self { - Arc::new(Expr::Agg(agg_expr.clone())) + Self::new(Expr::Agg(agg_expr.clone())) } } -impl AsRef for Expr { - fn as_ref(&self) -> &Expr { +impl AsRef for Expr { + fn as_ref(&self) -> &Self { self } } @@ -438,11 +438,11 @@ impl Expr { } pub fn alias>>(self: &ExprRef, name: S) -> ExprRef { - Expr::Alias(self.clone(), name.into()).into() + Self::Alias(self.clone(), name.into()).into() } pub fn if_else(self: ExprRef, if_true: ExprRef, if_false: ExprRef) -> ExprRef { - Expr::IfElse { + Self::IfElse { if_true, if_false, predicate: self, @@ -451,19 +451,19 @@ impl Expr { } pub fn cast(self: ExprRef, dtype: &DataType) -> ExprRef { - Expr::Cast(self, dtype.clone()).into() + Self::Cast(self, dtype.clone()).into() } pub fn count(self: ExprRef, mode: CountMode) -> ExprRef { - Expr::Agg(AggExpr::Count(self, mode)).into() + Self::Agg(AggExpr::Count(self, mode)).into() } pub fn sum(self: ExprRef) -> ExprRef { - Expr::Agg(AggExpr::Sum(self)).into() + Self::Agg(AggExpr::Sum(self)).into() } pub fn approx_count_distinct(self: ExprRef) -> ExprRef { - Expr::Agg(AggExpr::ApproxCountDistinct(self)).into() + Self::Agg(AggExpr::ApproxCountDistinct(self)).into() } pub fn approx_percentiles( @@ -471,7 +471,7 @@ impl Expr { percentiles: &[f64], force_list_output: bool, ) -> ExprRef { - Expr::Agg(AggExpr::ApproxPercentile(ApproxPercentileParams { + Self::Agg(AggExpr::ApproxPercentile(ApproxPercentileParams { child: self, percentiles: percentiles.iter().map(|f| FloatWrapper(*f)).collect(), force_list_output, @@ -484,7 +484,7 @@ impl Expr { percentiles: &[f64], force_list_output: bool, ) -> ExprRef { - Expr::Function { + Self::Function { func: FunctionExpr::Sketch(SketchExpr::Percentile { percentiles: HashableVecPercentiles(percentiles.to_vec()), 
force_list_output, @@ -495,52 +495,52 @@ impl Expr { } pub fn mean(self: ExprRef) -> ExprRef { - Expr::Agg(AggExpr::Mean(self)).into() + Self::Agg(AggExpr::Mean(self)).into() } pub fn min(self: ExprRef) -> ExprRef { - Expr::Agg(AggExpr::Min(self)).into() + Self::Agg(AggExpr::Min(self)).into() } pub fn max(self: ExprRef) -> ExprRef { - Expr::Agg(AggExpr::Max(self)).into() + Self::Agg(AggExpr::Max(self)).into() } pub fn any_value(self: ExprRef, ignore_nulls: bool) -> ExprRef { - Expr::Agg(AggExpr::AnyValue(self, ignore_nulls)).into() + Self::Agg(AggExpr::AnyValue(self, ignore_nulls)).into() } pub fn agg_list(self: ExprRef) -> ExprRef { - Expr::Agg(AggExpr::List(self)).into() + Self::Agg(AggExpr::List(self)).into() } pub fn agg_concat(self: ExprRef) -> ExprRef { - Expr::Agg(AggExpr::Concat(self)).into() + Self::Agg(AggExpr::Concat(self)).into() } #[allow(clippy::should_implement_trait)] pub fn not(self: ExprRef) -> ExprRef { - Expr::Not(self).into() + Self::Not(self).into() } pub fn is_null(self: ExprRef) -> ExprRef { - Expr::IsNull(self).into() + Self::IsNull(self).into() } pub fn not_null(self: ExprRef) -> ExprRef { - Expr::NotNull(self).into() + Self::NotNull(self).into() } pub fn fill_null(self: ExprRef, fill_value: ExprRef) -> ExprRef { - Expr::FillNull(self, fill_value).into() + Self::FillNull(self, fill_value).into() } pub fn is_in(self: ExprRef, items: ExprRef) -> ExprRef { - Expr::IsIn(self, items).into() + Self::IsIn(self, items).into() } pub fn between(self: ExprRef, lower: ExprRef, upper: ExprRef) -> ExprRef { - Expr::Between(self, lower, upper).into() + Self::Between(self, lower, upper).into() } pub fn eq(self: ExprRef, other: ExprRef) -> ExprRef { @@ -678,7 +678,7 @@ impl Expr { } } - pub fn with_new_children(&self, children: Vec) -> Expr { + pub fn with_new_children(&self, children: Vec) -> Self { use Expr::*; match self { // no children @@ -885,8 +885,8 @@ impl Expr { ))); } match predicate.as_ref() { - Expr::Literal(lit::LiteralValue::Boolean(true)) => if_true.to_field(schema), - Expr::Literal(lit::LiteralValue::Boolean(false)) => { + Self::Literal(lit::LiteralValue::Boolean(true)) => if_true.to_field(schema), + Self::Literal(lit::LiteralValue::Boolean(false)) => { Ok(if_false.to_field(schema)?.rename(if_true.name())) } _ => { @@ -1030,7 +1030,7 @@ impl Expr { /// If the expression is a literal, return it. Otherwise, return None. pub fn as_literal(&self) -> Option<&lit::LiteralValue> { match self { - Expr::Literal(lit) => Some(lit), + Self::Literal(lit) => Some(lit), _ => None, } } diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs index d4690e3ec6..378611851a 100644 --- a/src/daft-dsl/src/functions/python/mod.rs +++ b/src/daft-dsl/src/functions/python/mod.rs @@ -26,8 +26,8 @@ impl PythonUDF { #[inline] pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { match self { - PythonUDF::Stateless(stateless_python_udf) => stateless_python_udf, - PythonUDF::Stateful(stateful_python_udf) => stateful_python_udf, + Self::Stateless(stateless_python_udf) => stateless_python_udf, + Self::Stateful(stateful_python_udf) => stateful_python_udf, } } } diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 065c2a4d4e..5db0f05a3d 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -241,7 +241,7 @@ impl LiteralValue { /// If the liter is a boolean, return it. Otherwise, return None. 
pub fn as_bool(&self) -> Option { match self { - LiteralValue::Boolean(b) => Some(*b), + Self::Boolean(b) => Some(*b), _ => None, } } @@ -249,56 +249,56 @@ impl LiteralValue { /// If the literal is a string, return it. Otherwise, return None. pub fn as_str(&self) -> Option<&str> { match self { - LiteralValue::Utf8(s) => Some(s), + Self::Utf8(s) => Some(s), _ => None, } } /// If the literal is `Binary`, return it. Otherwise, return None. pub fn as_binary(&self) -> Option<&[u8]> { match self { - LiteralValue::Binary(b) => Some(b), + Self::Binary(b) => Some(b), _ => None, } } /// If the literal is `Int32`, return it. Otherwise, return None. pub fn as_i32(&self) -> Option { match self { - LiteralValue::Int32(i) => Some(*i), + Self::Int32(i) => Some(*i), _ => None, } } /// If the literal is `UInt32`, return it. Otherwise, return None. pub fn as_u32(&self) -> Option { match self { - LiteralValue::UInt32(i) => Some(*i), + Self::UInt32(i) => Some(*i), _ => None, } } /// If the literal is `Int64`, return it. Otherwise, return None. pub fn as_i64(&self) -> Option { match self { - LiteralValue::Int64(i) => Some(*i), + Self::Int64(i) => Some(*i), _ => None, } } /// If the literal is `UInt64`, return it. Otherwise, return None. pub fn as_u64(&self) -> Option { match self { - LiteralValue::UInt64(i) => Some(*i), + Self::UInt64(i) => Some(*i), _ => None, } } /// If the literal is `Float64`, return it. Otherwise, return None. pub fn as_f64(&self) -> Option { match self { - LiteralValue::Float64(f) => Some(*f), + Self::Float64(f) => Some(*f), _ => None, } } /// If the literal is a series, return it. Otherwise, return None. pub fn as_series(&self) -> Option<&Series> { match self { - LiteralValue::Series(series) => Some(series), + Self::Series(series) => Some(series), _ => None, } } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index af56dc68d8..a4a54e74c9 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -651,13 +651,13 @@ impl_bincode_py_state_serialization!(PyExpr); impl From for PyExpr { fn from(value: crate::ExprRef) -> Self { - PyExpr { expr: value } + Self { expr: value } } } impl From for PyExpr { fn from(value: crate::Expr) -> Self { - PyExpr { + Self { expr: Arc::new(value), } } diff --git a/src/daft-functions-json/Cargo.toml b/src/daft-functions-json/Cargo.toml index e27077b497..1f7547bcc3 100644 --- a/src/daft-functions-json/Cargo.toml +++ b/src/daft-functions-json/Cargo.toml @@ -21,6 +21,9 @@ python = [ "daft-dsl/python" ] +[lints] +workspace = true + [package] name = "daft-functions-json" edition.workspace = true diff --git a/src/daft-functions/Cargo.toml b/src/daft-functions/Cargo.toml index 9be2a86dc2..febb241e13 100644 --- a/src/daft-functions/Cargo.toml +++ b/src/daft-functions/Cargo.toml @@ -30,6 +30,9 @@ python = [ "common-io-config/python" ] +[lints] +workspace = true + [package] name = "daft-functions" edition.workspace = true diff --git a/src/daft-functions/src/distance/cosine.rs b/src/daft-functions/src/distance/cosine.rs index c7065ff655..170587c1bb 100644 --- a/src/daft-functions/src/distance/cosine.rs +++ b/src/daft-functions/src/distance/cosine.rs @@ -14,9 +14,9 @@ trait SpatialSimilarity { impl SpatialSimilarity for f64 { fn cosine(a: &[Self], b: &[Self]) -> Option { - let xy = a.iter().zip(b).map(|(a, b)| a * b).sum::(); - let x_sq = a.iter().map(|x| x.powi(2)).sum::().sqrt(); - let y_sq = b.iter().map(|x| x.powi(2)).sum::().sqrt(); + let xy = a.iter().zip(b).map(|(a, b)| a * b).sum::(); + let x_sq = a.iter().map(|x| 
x.powi(2)).sum::().sqrt(); + let y_sq = b.iter().map(|x| x.powi(2)).sum::().sqrt(); Some(1.0 - xy / (x_sq * y_sq)) } } diff --git a/src/daft-functions/src/lib.rs b/src/daft-functions/src/lib.rs index 0a8486864e..a8b1c8d0cd 100644 --- a/src/daft-functions/src/lib.rs +++ b/src/daft-functions/src/lib.rs @@ -60,13 +60,13 @@ pub enum Error { } impl From for std::io::Error { - fn from(err: Error) -> std::io::Error { - std::io::Error::new(std::io::ErrorKind::Other, err) + fn from(err: Error) -> Self { + Self::new(std::io::ErrorKind::Other, err) } } impl From for DaftError { - fn from(err: Error) -> DaftError { - DaftError::External(err.into()) + fn from(err: Error) -> Self { + Self::External(err.into()) } } diff --git a/src/daft-functions/src/tokenize/bpe.rs b/src/daft-functions/src/tokenize/bpe.rs index 18307092b4..c35e41771e 100644 --- a/src/daft-functions/src/tokenize/bpe.rs +++ b/src/daft-functions/src/tokenize/bpe.rs @@ -62,16 +62,16 @@ impl From for DaftError { fn from(err: Error) -> Self { use Error::*; match err { - Base64Decode { .. } => DaftError::ValueError(err.to_string()), - RankNumberParse { .. } => DaftError::ValueError(err.to_string()), - InvalidUtf8Sequence { .. } => DaftError::ValueError(err.to_string()), - InvalidTokenLine { .. } => DaftError::ValueError(err.to_string()), - EmptyTokenFile {} => DaftError::ValueError(err.to_string()), - BPECreation { .. } => DaftError::ComputeError(err.to_string()), - BadToken { .. } => DaftError::ValueError(err.to_string()), - Decode { .. } => DaftError::ComputeError(err.to_string()), - MissingPattern {} => DaftError::ValueError(err.to_string()), - UnsupportedSpecialTokens { .. } => DaftError::ValueError(err.to_string()), + Base64Decode { .. } => Self::ValueError(err.to_string()), + RankNumberParse { .. } => Self::ValueError(err.to_string()), + InvalidUtf8Sequence { .. } => Self::ValueError(err.to_string()), + InvalidTokenLine { .. } => Self::ValueError(err.to_string()), + EmptyTokenFile {} => Self::ValueError(err.to_string()), + BPECreation { .. } => Self::ComputeError(err.to_string()), + BadToken { .. } => Self::ValueError(err.to_string()), + Decode { .. } => Self::ComputeError(err.to_string()), + MissingPattern {} => Self::ValueError(err.to_string()), + UnsupportedSpecialTokens { .. 
} => Self::ValueError(err.to_string()), } } } diff --git a/src/daft-functions/src/uri/download.rs b/src/daft-functions/src/uri/download.rs index 4bacbdb3a9..9f107e95c1 100644 --- a/src/daft-functions/src/uri/download.rs +++ b/src/daft-functions/src/uri/download.rs @@ -29,7 +29,7 @@ impl ScalarUDF for DownloadFunction { } fn evaluate(&self, inputs: &[Series]) -> DaftResult { - let DownloadFunction { + let Self { max_connections, raise_error_on_failure, multi_thread, diff --git a/src/daft-functions/src/uri/upload.rs b/src/daft-functions/src/uri/upload.rs index 14f7a3721c..d4c606955f 100644 --- a/src/daft-functions/src/uri/upload.rs +++ b/src/daft-functions/src/uri/upload.rs @@ -26,7 +26,7 @@ impl ScalarUDF for UploadFunction { } fn evaluate(&self, inputs: &[Series]) -> DaftResult { - let UploadFunction { + let Self { location, config, max_connections, diff --git a/src/daft-image/Cargo.toml b/src/daft-image/Cargo.toml index 6dab7b3b8f..205ad8b9a0 100644 --- a/src/daft-image/Cargo.toml +++ b/src/daft-image/Cargo.toml @@ -18,6 +18,9 @@ python = [ "common-error/python" ] +[lints] +workspace = true + [package] name = "daft-image" edition.workspace = true diff --git a/src/daft-image/src/image_buffer.rs b/src/daft-image/src/image_buffer.rs index 18db67e0df..f1595aaf1f 100644 --- a/src/daft-image/src/image_buffer.rs +++ b/src/daft-image/src/image_buffer.rs @@ -45,12 +45,7 @@ macro_rules! with_method_on_image_buffer { } impl<'a> DaftImageBuffer<'a> { - pub fn from_raw( - mode: &ImageMode, - width: u32, - height: u32, - data: Cow<'a, [u8]>, - ) -> DaftImageBuffer<'a> { + pub fn from_raw(mode: &ImageMode, width: u32, height: u32, data: Cow<'a, [u8]>) -> Self { use DaftImageBuffer::*; match mode { ImageMode::L => L(ImageBuffer::from_raw(width, height, data).unwrap()), diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index 42cc984204..433090370e 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -54,6 +54,9 @@ python = [ "common-file-formats/python" ] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-io" diff --git a/src/daft-io/src/azure_blob.rs b/src/daft-io/src/azure_blob.rs index de2d05bd5d..a52092bd4e 100644 --- a/src/daft-io/src/azure_blob.rs +++ b/src/daft-io/src/azure_blob.rs @@ -110,27 +110,27 @@ impl From for super::Error { match error { UnableToReadBytes { path, source } | UnableToOpenFile { path, source } => { match source.as_http_error().map(|v| v.status().into()) { - Some(404) | Some(410) => super::Error::NotFound { + Some(404) | Some(410) => Self::NotFound { path, source: source.into(), }, - Some(401) => super::Error::Unauthorized { + Some(401) => Self::Unauthorized { store: super::SourceType::AzureBlob, path, source: source.into(), }, - None | Some(_) => super::Error::UnableToOpenFile { + None | Some(_) => Self::UnableToOpenFile { path, source: source.into(), }, } } - NotFound { ref path } => super::Error::NotFound { + NotFound { ref path } => Self::NotFound { path: path.into(), source: error.into(), }, - NotAFile { path } => super::Error::NotAFile { path }, - _ => super::Error::Generic { + NotAFile { path } => Self::NotAFile { path }, + _ => Self::Generic { store: super::SourceType::AzureBlob, source: error.into(), }, @@ -225,7 +225,7 @@ impl AzureBlobSource { BlobServiceClient::new(storage_account, storage_credentials) }; - Ok(AzureBlobSource { + Ok(Self { blob_client: blob_client.into(), } .into()) diff --git a/src/daft-io/src/google_cloud.rs b/src/daft-io/src/google_cloud.rs index 5435fedeb2..fe399ab3ec 100644 --- 
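// The object-store hunks in this patch (azure_blob above; google_cloud, http,
// huggingface and s3_like below) all translate a provider status code into the
// crate's own error enum, now written with `Self::...` inside the conversion.
// A minimal sketch of that mapping, with a hypothetical `StoreError` rather
// than the daft-io error type:
#[derive(Debug)]
enum StoreError {
    NotFound { path: String },
    Unauthorized { path: String },
    UnableToOpenFile { path: String },
}

impl StoreError {
    fn from_status(path: &str, status: Option<u16>) -> Self {
        match status {
            // Treat both "gone" and "not found" as a missing object.
            Some(404) | Some(410) => Self::NotFound { path: path.into() },
            Some(401) => Self::Unauthorized { path: path.into() },
            // Anything else, including no status at all, is a generic failure.
            None | Some(_) => Self::UnableToOpenFile { path: path.into() },
        }
    }
}

fn main() {
    println!("{:?}", StoreError::from_status("s3://bucket/key", Some(404)));
    println!("{:?}", StoreError::from_status("s3://bucket/key", None));
}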
a/src/daft-io/src/google_cloud.rs +++ b/src/daft-io/src/google_cloud.rs @@ -58,50 +58,50 @@ impl From for super::Error { | UnableToOpenFile { path, source } | UnableToListObjects { path, source } => match source { GError::HttpClient(err) => match err.status().map(|s| s.as_u16()) { - Some(404) | Some(410) => super::Error::NotFound { + Some(404) | Some(410) => Self::NotFound { path, source: err.into(), }, - Some(401) => super::Error::Unauthorized { + Some(401) => Self::Unauthorized { store: super::SourceType::GCS, path, source: err.into(), }, - _ => super::Error::UnableToOpenFile { + _ => Self::UnableToOpenFile { path, source: err.into(), }, }, GError::Response(err) => match err.code { - 404 | 410 => super::Error::NotFound { + 404 | 410 => Self::NotFound { path, source: err.into(), }, - 401 => super::Error::Unauthorized { + 401 => Self::Unauthorized { store: super::SourceType::GCS, path, source: err.into(), }, - _ => super::Error::UnableToOpenFile { + _ => Self::UnableToOpenFile { path, source: err.into(), }, }, - GError::TokenSource(err) => super::Error::UnableToLoadCredentials { + GError::TokenSource(err) => Self::UnableToLoadCredentials { store: super::SourceType::GCS, source: err, }, }, - NotFound { ref path } => super::Error::NotFound { + NotFound { ref path } => Self::NotFound { path: path.into(), source: error.into(), }, - InvalidUrl { path, source } => super::Error::InvalidUrl { path, source }, - UnableToLoadCredentials { source } => super::Error::UnableToLoadCredentials { + InvalidUrl { path, source } => Self::InvalidUrl { path, source }, + UnableToLoadCredentials { source } => Self::UnableToLoadCredentials { store: super::SourceType::GCS, source: source.into(), }, - NotAFile { path } => super::Error::NotAFile { path }, + NotAFile { path } => Self::NotAFile { path }, } } } @@ -392,7 +392,7 @@ impl GCSSource { } let client = Client::new(client_config); - Ok(GCSSource { + Ok(Self { client: GCSClientWrapper(client), } .into()) diff --git a/src/daft-io/src/http.rs b/src/daft-io/src/http.rs index 1f3a3bef11..14571fd79f 100644 --- a/src/daft-io/src/http.rs +++ b/src/daft-io/src/http.rs @@ -147,17 +147,17 @@ impl From for super::Error { use Error::*; match error { UnableToOpenFile { path, source } => match source.status().map(|v| v.as_u16()) { - Some(404) | Some(410) => super::Error::NotFound { + Some(404) | Some(410) => Self::NotFound { path, source: source.into(), }, - None | Some(_) => super::Error::UnableToOpenFile { + None | Some(_) => Self::UnableToOpenFile { path, source: source.into(), }, }, - UnableToDetermineSize { path } => super::Error::UnableToDetermineSize { path }, - _ => super::Error::Generic { + UnableToDetermineSize { path } => Self::UnableToDetermineSize { path }, + _ => Self::Generic { store: super::SourceType::Http, source: error.into(), }, @@ -174,7 +174,7 @@ impl HttpSource { .context(UnableToCreateHeaderSnafu)?, ); - Ok(HttpSource { + Ok(Self { client: reqwest::ClientBuilder::default() .pool_max_idle_per_host(70) .default_headers(default_headers) diff --git a/src/daft-io/src/huggingface.rs b/src/daft-io/src/huggingface.rs index 3aacea186d..f10f2d8de3 100644 --- a/src/daft-io/src/huggingface.rs +++ b/src/daft-io/src/huggingface.rs @@ -128,7 +128,7 @@ impl FromStr for HFPathParts { let (repository, uri) = if let Some((repo, uri)) = uri.split_once('/') { (repo, uri) } else { - return Some(HFPathParts { + return Some(Self { bucket: bucket.to_string(), repository: format!("{}/{}", username, uri), revision: "main".to_string(), @@ -150,7 +150,7 @@ impl FromStr for 
HFPathParts { // ^--------------^ let path = uri.to_string().trim_end_matches('/').to_string(); - Some(HFPathParts { + Some(Self { bucket: bucket.to_string(), repository, revision, @@ -221,17 +221,17 @@ impl From for super::Error { use Error::*; match error { UnableToOpenFile { path, source } => match source.status().map(|v| v.as_u16()) { - Some(404) | Some(410) => super::Error::NotFound { + Some(404) | Some(410) => Self::NotFound { path, source: source.into(), }, - None | Some(_) => super::Error::UnableToOpenFile { + None | Some(_) => Self::UnableToOpenFile { path, source: source.into(), }, }, - UnableToDetermineSize { path } => super::Error::UnableToDetermineSize { path }, - _ => super::Error::Generic { + UnableToDetermineSize { path } => Self::UnableToDetermineSize { path }, + _ => Self::Generic { store: super::SourceType::Http, source: error.into(), }, @@ -256,7 +256,7 @@ impl HFSource { ); } - Ok(HFSource { + Ok(Self { http_source: HttpSource { client: reqwest::ClientBuilder::default() .pool_max_idle_per_host(70) diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index f5c5834541..6fdaac2368 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -142,34 +142,34 @@ pub enum Error { } impl From for DaftError { - fn from(err: Error) -> DaftError { + fn from(err: Error) -> Self { use Error::*; match err { - NotFound { path, source } => DaftError::FileNotFound { path, source }, - ConnectTimeout { .. } => DaftError::ConnectTimeout(err.into()), - ReadTimeout { .. } => DaftError::ReadTimeout(err.into()), - UnableToReadBytes { .. } => DaftError::ByteStreamError(err.into()), - SocketError { .. } => DaftError::SocketError(err.into()), + NotFound { path, source } => Self::FileNotFound { path, source }, + ConnectTimeout { .. } => Self::ConnectTimeout(err.into()), + ReadTimeout { .. } => Self::ReadTimeout(err.into()), + UnableToReadBytes { .. } => Self::ByteStreamError(err.into()), + SocketError { .. } => Self::SocketError(err.into()), // We have to repeat everything above for the case we have an Arc since we can't move the error. CachedError { ref source } => match source.as_ref() { - NotFound { path, source: _ } => DaftError::FileNotFound { + NotFound { path, source: _ } => Self::FileNotFound { path: path.clone(), source: err.into(), }, - ConnectTimeout { .. } => DaftError::ConnectTimeout(err.into()), - ReadTimeout { .. } => DaftError::ReadTimeout(err.into()), - UnableToReadBytes { .. } => DaftError::ByteStreamError(err.into()), - SocketError { .. } => DaftError::SocketError(err.into()), - _ => DaftError::External(err.into()), + ConnectTimeout { .. } => Self::ConnectTimeout(err.into()), + ReadTimeout { .. } => Self::ReadTimeout(err.into()), + UnableToReadBytes { .. } => Self::ByteStreamError(err.into()), + SocketError { .. 
} => Self::SocketError(err.into()), + _ => Self::External(err.into()), }, - _ => DaftError::External(err.into()), + _ => Self::External(err.into()), } } } impl From for std::io::Error { - fn from(err: Error) -> std::io::Error { - std::io::Error::new(std::io::ErrorKind::Other, err) + fn from(err: Error) -> Self { + Self::new(std::io::ErrorKind::Other, err) } } @@ -183,7 +183,7 @@ pub struct IOClient { impl IOClient { pub fn new(config: Arc) -> Result { - Ok(IOClient { + Ok(Self { source_type_to_store: tokio::sync::RwLock::new(HashMap::new()), config, }) @@ -361,12 +361,12 @@ pub enum SourceType { impl std::fmt::Display for SourceType { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { - SourceType::File => write!(f, "file"), - SourceType::Http => write!(f, "http"), - SourceType::S3 => write!(f, "s3"), - SourceType::AzureBlob => write!(f, "AzureBlob"), - SourceType::GCS => write!(f, "gcs"), - SourceType::HF => write!(f, "hf"), + Self::File => write!(f, "file"), + Self::Http => write!(f, "http"), + Self::S3 => write!(f, "s3"), + Self::AzureBlob => write!(f, "AzureBlob"), + Self::GCS => write!(f, "gcs"), + Self::HF => write!(f, "hf"), } } } diff --git a/src/daft-io/src/local.rs b/src/daft-io/src/local.rs index fb32d65b27..4ed9eaa54b 100644 --- a/src/daft-io/src/local.rs +++ b/src/daft-io/src/local.rs @@ -87,11 +87,11 @@ impl From for super::Error { UnableToOpenFile { path, source } | UnableToFetchDirectoryEntries { path, source } => { use std::io::ErrorKind::*; match source.kind() { - NotFound => super::Error::NotFound { + NotFound => Self::NotFound { path, source: source.into(), }, - _ => super::Error::UnableToOpenFile { + _ => Self::UnableToOpenFile { path, source: source.into(), }, @@ -100,21 +100,21 @@ impl From for super::Error { UnableToFetchFileMetadata { path, source } => { use std::io::ErrorKind::*; match source.kind() { - NotFound | IsADirectory => super::Error::NotFound { + NotFound | IsADirectory => Self::NotFound { path, source: source.into(), }, - _ => super::Error::UnableToOpenFile { + _ => Self::UnableToOpenFile { path, source: source.into(), }, } } - UnableToReadBytes { path, source } => super::Error::UnableToReadBytes { path, source }, + UnableToReadBytes { path, source } => Self::UnableToReadBytes { path, source }, UnableToWriteToFile { path, source } | UnableToOpenFileForWriting { path, source } => { - super::Error::UnableToWriteToFile { path, source } + Self::UnableToWriteToFile { path, source } } - _ => super::Error::Generic { + _ => Self::Generic { store: super::SourceType::File, source: error.into(), }, @@ -124,7 +124,7 @@ impl From for super::Error { impl LocalSource { pub async fn get_client() -> super::Result> { - Ok(LocalSource {}.into()) + Ok(Self {}.into()) } } diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index d3fa97601a..32bf328f17 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -109,7 +109,7 @@ impl GetResult { .source .get(&rp.input, rp.range.clone(), rp.io_stats.clone()) .await?; - if let GetResult::Stream(stream, size, permit, _) = get_result { + if let Self::Stream(stream, size, permit, _) = get_result { result = collect_bytes(stream, size, permit).await; } else { unreachable!("Retrying a stream should always be a stream"); @@ -125,9 +125,9 @@ impl GetResult { pub fn with_retry(self, params: StreamingRetryParams) -> Self { match self { - GetResult::File(..) 
=> self, - GetResult::Stream(s, size, permit, _) => { - GetResult::Stream(s, size, permit, Some(Box::new(params))) + Self::File(..) => self, + Self::Stream(s, size, permit, _) => { + Self::Stream(s, size, permit, Some(Box::new(params))) } } } diff --git a/src/daft-io/src/object_store_glob.rs b/src/daft-io/src/object_store_glob.rs index 58261d5fbf..13b43f773c 100644 --- a/src/daft-io/src/object_store_glob.rs +++ b/src/daft-io/src/object_store_glob.rs @@ -58,7 +58,7 @@ impl GlobState { } pub fn advance(self, path: String, idx: usize, fanout_factor: usize) -> Self { - GlobState { + Self { current_path: path, current_fragment_idx: idx, current_fanout: self.current_fanout * fanout_factor, @@ -67,7 +67,7 @@ impl GlobState { } pub fn with_wildcard_mode(self) -> Self { - GlobState { + Self { wildcard_mode: true, ..self } @@ -126,7 +126,7 @@ impl GlobFragment { } } - GlobFragment { + Self { data: data.to_string(), first_wildcard_idx, escaped_data, @@ -139,11 +139,11 @@ impl GlobFragment { } /// Joins a slice of GlobFragments together with a separator - pub fn join(fragments: &[GlobFragment], sep: &str) -> Self { - GlobFragment::new( + pub fn join(fragments: &[Self], sep: &str) -> Self { + Self::new( fragments .iter() - .map(|frag: &GlobFragment| frag.data.as_str()) + .map(|frag: &Self| frag.data.as_str()) .join(sep) .as_str(), ) diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 97c9641e58..2766011ae7 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -124,120 +124,120 @@ impl From for super::Error { match error { UnableToOpenFile { path, source } => match source { - SdkError::TimeoutError(_) => super::Error::ReadTimeout { + SdkError::TimeoutError(_) => Self::ReadTimeout { path, source: source.into(), }, SdkError::DispatchFailure(ref dispatch) => { if dispatch.is_timeout() { - super::Error::ConnectTimeout { + Self::ConnectTimeout { path, source: source.into(), } } else if dispatch.is_io() { - super::Error::SocketError { + Self::SocketError { path, source: source.into(), } } else { - super::Error::UnableToOpenFile { + Self::UnableToOpenFile { path, source: source.into(), } } } _ => match source.into_service_error() { - GetObjectError::NoSuchKey(no_such_key) => super::Error::NotFound { + GetObjectError::NoSuchKey(no_such_key) => Self::NotFound { path, source: no_such_key.into(), }, - GetObjectError::Unhandled(v) => super::Error::Unhandled { + GetObjectError::Unhandled(v) => Self::Unhandled { path, msg: DisplayErrorContext(v).to_string(), }, - err => super::Error::UnableToOpenFile { + err => Self::UnableToOpenFile { path, source: err.into(), }, }, }, UnableToHeadFile { path, source } => match source { - SdkError::TimeoutError(_) => super::Error::ReadTimeout { + SdkError::TimeoutError(_) => Self::ReadTimeout { path, source: source.into(), }, SdkError::DispatchFailure(ref dispatch) => { if dispatch.is_timeout() { - super::Error::ConnectTimeout { + Self::ConnectTimeout { path, source: source.into(), } } else if dispatch.is_io() { - super::Error::SocketError { + Self::SocketError { path, source: source.into(), } } else { - super::Error::UnableToOpenFile { + Self::UnableToOpenFile { path, source: source.into(), } } } _ => match source.into_service_error() { - HeadObjectError::NotFound(no_such_key) => super::Error::NotFound { + HeadObjectError::NotFound(no_such_key) => Self::NotFound { path, source: no_such_key.into(), }, - HeadObjectError::Unhandled(v) => super::Error::Unhandled { + HeadObjectError::Unhandled(v) => Self::Unhandled { path, msg: 
DisplayErrorContext(v).to_string(), }, - err => super::Error::UnableToOpenFile { + err => Self::UnableToOpenFile { path, source: err.into(), }, }, }, UnableToListObjects { path, source } => match source { - SdkError::TimeoutError(_) => super::Error::ReadTimeout { + SdkError::TimeoutError(_) => Self::ReadTimeout { path, source: source.into(), }, SdkError::DispatchFailure(ref dispatch) => { if dispatch.is_timeout() { - super::Error::ConnectTimeout { + Self::ConnectTimeout { path, source: source.into(), } } else if dispatch.is_io() { - super::Error::SocketError { + Self::SocketError { path, source: source.into(), } } else { - super::Error::UnableToOpenFile { + Self::UnableToOpenFile { path, source: source.into(), } } } _ => match source.into_service_error() { - ListObjectsV2Error::NoSuchBucket(no_such_key) => super::Error::NotFound { + ListObjectsV2Error::NoSuchBucket(no_such_key) => Self::NotFound { path, source: no_such_key.into(), }, - ListObjectsV2Error::Unhandled(v) => super::Error::Unhandled { + ListObjectsV2Error::Unhandled(v) => Self::Unhandled { path, msg: DisplayErrorContext(v).to_string(), }, - err => super::Error::UnableToOpenFile { + err => Self::UnableToOpenFile { path, source: err.into(), }, }, }, - InvalidUrl { path, source } => super::Error::InvalidUrl { path, source }, + InvalidUrl { path, source } => Self::InvalidUrl { path, source }, UnableToReadBytes { path, source } => { use std::error::Error; let io_error = if let Some(source) = source.source() { @@ -247,21 +247,21 @@ impl From for super::Error { } else { std::io::Error::new(io::ErrorKind::Other, source) }; - super::Error::UnableToReadBytes { + Self::UnableToReadBytes { path, source: io_error, } } - NotAFile { path } => super::Error::NotAFile { path }, - UnableToLoadCredentials { source } => super::Error::UnableToLoadCredentials { + NotAFile { path } => Self::NotAFile { path }, + UnableToLoadCredentials { source } => Self::UnableToLoadCredentials { store: SourceType::S3, source: source.into(), }, - NotFound { ref path } => super::Error::NotFound { + NotFound { ref path } => Self::NotFound { path: path.into(), source: error.into(), }, - err => super::Error::Generic { + err => Self::Generic { store: SourceType::S3, source: err.into(), }, @@ -553,7 +553,7 @@ async fn build_client(config: &S3Config) -> super::Result { const REGION_HEADER: &str = "x-amz-bucket-region"; impl S3LikeSource { - pub async fn get_client(config: &S3Config) -> super::Result> { + pub async fn get_client(config: &S3Config) -> super::Result> { Ok(build_client(config).await?.into()) } diff --git a/src/daft-io/src/stats.rs b/src/daft-io/src/stats.rs index 0fc15be4d3..32aabd1b90 100644 --- a/src/daft-io/src/stats.rs +++ b/src/daft-io/src/stats.rs @@ -49,7 +49,7 @@ pub(crate) struct IOStatsByteStreamContextHandle { impl IOStatsContext { pub fn new>>(name: S) -> IOStatsRef { - Arc::new(IOStatsContext { + Arc::new(Self { name: name.into(), num_get_requests: atomic::AtomicUsize::new(0), num_head_requests: atomic::AtomicUsize::new(0), diff --git a/src/daft-json/Cargo.toml b/src/daft-json/Cargo.toml index 241156403c..1cf8308a7a 100644 --- a/src/daft-json/Cargo.toml +++ b/src/daft-json/Cargo.toml @@ -38,6 +38,9 @@ python = [ "daft-dsl/python" ] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-json" diff --git a/src/daft-json/src/lib.rs b/src/daft-json/src/lib.rs index 1c56ba7167..6f935b8e4c 100644 --- a/src/daft-json/src/lib.rs +++ b/src/daft-json/src/lib.rs @@ -50,17 +50,17 @@ pub enum Error { } impl From for DaftError { - fn 
from(err: Error) -> DaftError { + fn from(err: Error) -> Self { match err { Error::IOError { source } => source.into(), - _ => DaftError::External(err.into()), + _ => Self::External(err.into()), } } } impl From for Error { fn from(err: daft_io::Error) -> Self { - Error::IOError { source: err } + Self::IOError { source: err } } } diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index 2dac516672..cd061c1c35 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -29,6 +29,9 @@ tracing = {workspace = true} [features] python = ["dep:pyo3", "common-daft-config/python", "common-file-formats/python", "common-error/python", "daft-dsl/python", "daft-io/python", "daft-micropartition/python", "daft-plan/python", "daft-scan/python", "common-display/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-local-execution" diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index bb22b9d4ea..4bc6fd1f5c 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -91,8 +91,8 @@ pub enum PipelineReceiver { impl PipelineReceiver { pub async fn recv(&mut self) -> Option { match self { - PipelineReceiver::InOrder(rr) => rr.recv().await, - PipelineReceiver::OutOfOrder(r) => r.recv().await, + Self::InOrder(rr) => rr.recv().await, + Self::OutOfOrder(r) => r.recv().await, } } } diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index 14bc949ff0..13e79b5ede 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -19,15 +19,15 @@ enum AntiSemiProbeState { impl AntiSemiProbeState { fn set_table(&mut self, table: &Arc) { - if let AntiSemiProbeState::Building = self { - *self = AntiSemiProbeState::ReadyToProbe(table.clone()); + if let Self::Building = self { + *self = Self::ReadyToProbe(table.clone()); } else { panic!("AntiSemiProbeState should only be in Building state when setting table") } } fn get_probeable(&self) -> &Arc { - if let AntiSemiProbeState::ReadyToProbe(probeable) = self { + if let Self::ReadyToProbe(probeable) = self { probeable } else { panic!("AntiSemiProbeState should only be in ReadyToProbe state when getting probeable") diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs index f849064964..0a037dc6bb 100644 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs @@ -21,15 +21,15 @@ enum HashJoinProbeState { impl HashJoinProbeState { fn set_table(&mut self, table: &Arc, tables: &Arc>) { - if let HashJoinProbeState::Building = self { - *self = HashJoinProbeState::ReadyToProbe(table.clone(), tables.clone()); + if let Self::Building = self { + *self = Self::ReadyToProbe(table.clone(), tables.clone()); } else { panic!("HashJoinProbeState should only be in Building state when setting table") } } fn get_probeable_and_table(&self) -> (&Arc, &Arc>) { - if let HashJoinProbeState::ReadyToProbe(probe_table, tables) = self { + if let Self::ReadyToProbe(probe_table, tables) = self { (probe_table, tables) } else { panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") diff --git 
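// The probe-state hunks above (AntiSemiProbeState and HashJoinProbeState) use
// the same two-phase enum pattern, now written with `Self::Variant` inside the
// impl. A minimal sketch of that state machine, using a hypothetical
// `ProbeState` over a plain Vec instead of the real probe table types:
use std::sync::Arc;

enum ProbeState {
    Building,
    ReadyToProbe(Arc<Vec<u64>>),
}

impl ProbeState {
    // Only legal while still building; afterwards the state is frozen.
    fn set_table(&mut self, table: &Arc<Vec<u64>>) {
        if let Self::Building = self {
            *self = Self::ReadyToProbe(table.clone());
        } else {
            panic!("set_table may only be called while still Building");
        }
    }

    // Only legal once the table has been installed.
    fn get_probeable(&self) -> &Arc<Vec<u64>> {
        if let Self::ReadyToProbe(table) = self {
            table
        } else {
            panic!("get_probeable may only be called after the table is set");
        }
    }
}

fn main() {
    let mut state = ProbeState::Building;
    state.set_table(&Arc::new(vec![1, 2, 3]));
    assert_eq!(state.get_probeable().len(), 3);
}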
a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index 732f306768..1356689a08 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -76,7 +76,7 @@ pub enum Error { } impl From for DaftError { - fn from(err: Error) -> DaftError { + fn from(err: Error) -> Self { match err { Error::PipelineCreationError { source, plan_name } => { log::error!("Error creating pipeline from {}", plan_name); @@ -86,7 +86,7 @@ impl From for DaftError { log::error!("Error when running pipeline node {}", node_name); source } - _ => DaftError::External(err.into()), + _ => Self::External(err.into()), } } } diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 62935dfff0..9f7da9b915 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -42,33 +42,33 @@ pub enum PipelineResultType { impl From> for PipelineResultType { fn from(data: Arc) -> Self { - PipelineResultType::Data(data) + Self::Data(data) } } impl From<(Arc, Arc>)> for PipelineResultType { fn from((probe_table, tables): (Arc, Arc>)) -> Self { - PipelineResultType::ProbeTable(probe_table, tables) + Self::ProbeTable(probe_table, tables) } } impl PipelineResultType { pub fn as_data(&self) -> &Arc { match self { - PipelineResultType::Data(data) => data, + Self::Data(data) => data, _ => panic!("Expected data"), } } pub fn as_probe_table(&self) -> (&Arc, &Arc>) { match self { - PipelineResultType::ProbeTable(probe_table, tables) => (probe_table, tables), + Self::ProbeTable(probe_table, tables) => (probe_table, tables), _ => panic!("Expected probe table"), } } pub fn should_broadcast(&self) -> bool { - matches!(self, PipelineResultType::ProbeTable(_, _)) + matches!(self, Self::ProbeTable(_, _)) } } diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 09e42ae81f..8894db503d 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -34,7 +34,7 @@ pub(crate) struct BlockingSinkNode { impl BlockingSinkNode { pub(crate) fn new(op: Box, child: Box) -> Self { let name = op.name(); - BlockingSinkNode { + Self { op: Arc::new(tokio::sync::Mutex::new(op)), name, child, diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 1804a3e07e..5b188c4ad8 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -38,7 +38,7 @@ pub(crate) struct StreamingSinkNode { impl StreamingSinkNode { pub(crate) fn new(op: Box, children: Vec>) -> Self { let name = op.name(); - StreamingSinkNode { + Self { op: Arc::new(tokio::sync::Mutex::new(op)), name, children, diff --git a/src/daft-micropartition/Cargo.toml b/src/daft-micropartition/Cargo.toml index ee2322d8c9..2b8a405ef8 100644 --- a/src/daft-micropartition/Cargo.toml +++ b/src/daft-micropartition/Cargo.toml @@ -19,6 +19,9 @@ snafu = {workspace = true} [features] python = ["dep:pyo3", "common-error/python", "common-file-formats/python", "daft-core/python", "daft-dsl/python", "daft-table/python", "daft-io/python", "daft-parquet/python", "daft-scan/python", "daft-stats/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-micropartition" diff --git a/src/daft-micropartition/src/lib.rs b/src/daft-micropartition/src/lib.rs index 1933ab3b78..1a01f4e933 100644 --- 
a/src/daft-micropartition/src/lib.rs +++ b/src/daft-micropartition/src/lib.rs @@ -47,7 +47,7 @@ impl From for DaftError { fn from(value: Error) -> Self { match value { Error::DaftCoreCompute { source } => source, - _ => DaftError::External(value.into()), + _ => Self::External(value.into()), } } } diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index c3059626fa..5b518419d1 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -38,7 +38,7 @@ pub(crate) enum TableState { impl Display for TableState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - TableState::Unloaded(scan_task) => { + Self::Unloaded(scan_task) => { write!( f, "TableState: Unloaded. To load from: {:#?}", @@ -49,7 +49,7 @@ impl Display for TableState { .collect::>() ) } - TableState::Loaded(tables) => { + Self::Loaded(tables) => { writeln!(f, "TableState: Loaded. {} tables", tables.len())?; for tab in tables.iter() { writeln!(f, "{}", tab)?; @@ -524,7 +524,7 @@ impl MicroPartition { let statistics = statistics .cast_to_schema_with_fill(schema.clone(), fill_map.as_ref()) .expect("Statistics cannot be casted to schema"); - MicroPartition { + Self { schema, state: Mutex::new(TableState::Unloaded(scan_task)), metadata, @@ -557,7 +557,7 @@ impl MicroPartition { }); let tables_len_sum = tables.iter().map(|t| t.len()).sum(); - MicroPartition { + Self { schema, state: Mutex::new(TableState::Loaded(tables)), metadata: TableMetadata { diff --git a/src/daft-micropartition/src/ops/agg.rs b/src/daft-micropartition/src/ops/agg.rs index 108ea3e09d..18f265c3eb 100644 --- a/src/daft-micropartition/src/ops/agg.rs +++ b/src/daft-micropartition/src/ops/agg.rs @@ -15,7 +15,7 @@ impl MicroPartition { [] => { let empty_table = Table::empty(Some(self.schema.clone()))?; let agged = empty_table.agg(to_agg, group_by)?; - Ok(MicroPartition::new_loaded( + Ok(Self::new_loaded( agged.schema.clone(), vec![agged].into(), None, @@ -23,7 +23,7 @@ impl MicroPartition { } [t] => { let agged = t.agg(to_agg, group_by)?; - Ok(MicroPartition::new_loaded( + Ok(Self::new_loaded( agged.schema.clone(), vec![agged].into(), None, diff --git a/src/daft-micropartition/src/ops/cast_to_schema.rs b/src/daft-micropartition/src/ops/cast_to_schema.rs index e83c0774e1..1612a83eae 100644 --- a/src/daft-micropartition/src/ops/cast_to_schema.rs +++ b/src/daft-micropartition/src/ops/cast_to_schema.rs @@ -30,14 +30,14 @@ impl MicroPartition { scan_task.pushdowns.clone(), )) }; - Ok(MicroPartition::new_unloaded( + Ok(Self::new_unloaded( maybe_new_scan_task, self.metadata.clone(), pruned_statistics.expect("Unloaded MicroPartition should have statistics"), )) } // If Tables are already loaded, we map `Table::cast_to_schema` on each Table - TableState::Loaded(tables) => Ok(MicroPartition::new_loaded( + TableState::Loaded(tables) => Ok(Self::new_loaded( schema.clone(), Arc::new( tables diff --git a/src/daft-micropartition/src/ops/concat.rs b/src/daft-micropartition/src/ops/concat.rs index 904e68324a..682f75f4ce 100644 --- a/src/daft-micropartition/src/ops/concat.rs +++ b/src/daft-micropartition/src/ops/concat.rs @@ -47,7 +47,7 @@ impl MicroPartition { } let new_len = all_tables.iter().map(|t| t.len()).sum(); - Ok(MicroPartition { + Ok(Self { schema: mps.first().unwrap().schema.clone(), state: Mutex::new(TableState::Loaded(all_tables.into())), metadata: TableMetadata { length: new_len }, diff --git 
a/src/daft-micropartition/src/ops/eval_expressions.rs b/src/daft-micropartition/src/ops/eval_expressions.rs index 7a9b8bed0e..8ac5966a2e 100644 --- a/src/daft-micropartition/src/ops/eval_expressions.rs +++ b/src/daft-micropartition/src/ops/eval_expressions.rs @@ -45,7 +45,7 @@ impl MicroPartition { .map(|s| s.eval_expression_list(exprs, &expected_schema)) .transpose()?; - Ok(MicroPartition::new_loaded( + Ok(Self::new_loaded( expected_schema.into(), Arc::new(evaluated_tables), eval_stats, @@ -85,7 +85,7 @@ impl MicroPartition { } } - Ok(MicroPartition::new_loaded( + Ok(Self::new_loaded( Arc::new(expected_schema), Arc::new(evaluated_tables), eval_stats, diff --git a/src/daft-micropartition/src/ops/join.rs b/src/daft-micropartition/src/ops/join.rs index aef268d669..bac67f12db 100644 --- a/src/daft-micropartition/src/ops/join.rs +++ b/src/daft-micropartition/src/ops/join.rs @@ -70,7 +70,7 @@ impl MicroPartition { ([], _) | (_, []) => Ok(Self::empty(Some(join_schema))), ([lt], [rt]) => { let joined_table = table_join(lt, rt, left_on, right_on, how)?; - Ok(MicroPartition::new_loaded( + Ok(Self::new_loaded( join_schema, vec![joined_table].into(), None, diff --git a/src/daft-micropartition/src/ops/partition.rs b/src/daft-micropartition/src/ops/partition.rs index d0358ff178..8ca24e276f 100644 --- a/src/daft-micropartition/src/ops/partition.rs +++ b/src/daft-micropartition/src/ops/partition.rs @@ -24,20 +24,11 @@ fn transpose2(v: Vec>) -> Vec> { } impl MicroPartition { - fn vec_part_tables_to_mps( - &self, - part_tables: Vec>, - ) -> DaftResult> { + fn vec_part_tables_to_mps(&self, part_tables: Vec>) -> DaftResult> { let part_tables = transpose2(part_tables); Ok(part_tables .into_iter() - .map(|v| { - MicroPartition::new_loaded( - self.schema.clone(), - Arc::new(v), - self.statistics.clone(), - ) - }) + .map(|v| Self::new_loaded(self.schema.clone(), Arc::new(v), self.statistics.clone())) .collect()) } @@ -127,11 +118,10 @@ impl MicroPartition { let mps = tables .into_iter() - .map(|t| MicroPartition::new_loaded(self.schema.clone(), Arc::new(vec![t]), None)) + .map(|t| Self::new_loaded(self.schema.clone(), Arc::new(vec![t]), None)) .collect::>(); - let values = - MicroPartition::new_loaded(values.schema.clone(), Arc::new(vec![values]), None); + let values = Self::new_loaded(values.schema.clone(), Arc::new(vec![values]), None); Ok((mps, values)) } diff --git a/src/daft-micropartition/src/ops/pivot.rs b/src/daft-micropartition/src/ops/pivot.rs index 90e6fa110e..3a4ad964b9 100644 --- a/src/daft-micropartition/src/ops/pivot.rs +++ b/src/daft-micropartition/src/ops/pivot.rs @@ -25,7 +25,7 @@ impl MicroPartition { } [t] => { let pivoted = t.pivot(group_by, pivot_col, values_col, names)?; - Ok(MicroPartition::new_loaded( + Ok(Self::new_loaded( pivoted.schema.clone(), vec![pivoted].into(), None, diff --git a/src/daft-micropartition/src/ops/slice.rs b/src/daft-micropartition/src/ops/slice.rs index fa6cb858f1..11c5f37dec 100644 --- a/src/daft-micropartition/src/ops/slice.rs +++ b/src/daft-micropartition/src/ops/slice.rs @@ -44,7 +44,7 @@ impl MicroPartition { } } - Ok(MicroPartition::new_loaded( + Ok(Self::new_loaded( self.schema.clone(), slices_tables.into(), self.statistics.clone(), diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 03423d1a37..2060b3de30 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -890,7 +890,7 @@ impl From for PyMicroPartition { impl From> for PyMicroPartition { fn from(value: Arc) -> Self { 
- PyMicroPartition { inner: value } + Self { inner: value } } } diff --git a/src/daft-minhash/Cargo.toml b/src/daft-minhash/Cargo.toml index d058339686..b902171b03 100644 --- a/src/daft-minhash/Cargo.toml +++ b/src/daft-minhash/Cargo.toml @@ -3,6 +3,9 @@ common-error = {path = "../common/error", default-features = false} fastrand = "2.1.0" mur3 = "0.1.0" +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-minhash" diff --git a/src/daft-parquet/Cargo.toml b/src/daft-parquet/Cargo.toml index 1ec75b3cb1..217ac9c96c 100644 --- a/src/daft-parquet/Cargo.toml +++ b/src/daft-parquet/Cargo.toml @@ -31,6 +31,9 @@ path_macro = {workspace = true} [features] python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-io/python", "daft-table/python", "daft-stats/python", "daft-dsl/python", "common-arrow-ffi/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-parquet" diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs index 8ba9f4ed25..2bcf080f0b 100644 --- a/src/daft-parquet/src/file.rs +++ b/src/daft-parquet/src/file.rs @@ -67,7 +67,7 @@ where S: futures::Stream> + std::marker::Unpin, { pub fn new(src: S, handle: tokio::runtime::Handle) -> Self { - StreamIterator { + Self { curr: None, src: tokio::sync::Mutex::new(src), handle, @@ -204,7 +204,7 @@ impl ParquetReaderBuilder { .await?; let metadata = read_parquet_metadata(uri, size, io_client, io_stats, field_id_mapping).await?; - Ok(ParquetReaderBuilder { + Ok(Self { uri: uri.into(), metadata, selected_columns: None, @@ -325,7 +325,7 @@ impl ParquetFileReader { row_ranges: Vec, chunk_size: Option, ) -> super::Result { - Ok(ParquetFileReader { + Ok(Self { uri, metadata: Arc::new(metadata), arrow_schema: arrow_schema.into(), diff --git a/src/daft-parquet/src/lib.rs b/src/daft-parquet/src/lib.rs index e907fec22e..039124e4a2 100644 --- a/src/daft-parquet/src/lib.rs +++ b/src/daft-parquet/src/lib.rs @@ -206,18 +206,18 @@ pub enum Error { } impl From for DaftError { - fn from(err: Error) -> DaftError { + fn from(err: Error) -> Self { match err { Error::DaftIOError { source } => source.into(), - Error::FileReadTimeout { .. } => DaftError::ReadTimeout(err.into()), - _ => DaftError::External(err.into()), + Error::FileReadTimeout { .. } => Self::ReadTimeout(err.into()), + _ => Self::External(err.into()), } } } impl From for Error { fn from(err: daft_io::Error) -> Self { - Error::DaftIOError { source: err } + Self::DaftIOError { source: err } } } diff --git a/src/daft-parquet/src/metadata.rs b/src/daft-parquet/src/metadata.rs index dd7c186145..c769262d8e 100644 --- a/src/daft-parquet/src/metadata.rs +++ b/src/daft-parquet/src/metadata.rs @@ -26,7 +26,7 @@ impl TreeNode for ParquetTypeWrapper { ParquetType::GroupType { fields, .. } => { for child in fields.iter() { // TODO: Expensive clone here because of ParquetTypeWrapper type, can we get rid of this? - match op(&ParquetTypeWrapper(child.clone()))? { + match op(&Self(child.clone()))? 
{ TreeNodeRecursion::Continue => {} TreeNodeRecursion::Jump => return Ok(TreeNodeRecursion::Continue), TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), @@ -50,19 +50,15 @@ impl TreeNode for ParquetTypeWrapper { logical_type, converted_type, fields, - } => Ok(Transformed::yes(ParquetTypeWrapper( - ParquetType::GroupType { - fields: fields - .into_iter() - .map(|child| { - transform(ParquetTypeWrapper(child)).map(|wrapper| wrapper.data.0) - }) - .collect::>>()?, - field_info, - logical_type, - converted_type, - }, - ))), + } => Ok(Transformed::yes(Self(ParquetType::GroupType { + fields: fields + .into_iter() + .map(|child| transform(Self(child)).map(|wrapper| wrapper.data.0)) + .collect::>>()?, + field_info, + logical_type, + converted_type, + }))), } } } diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs index eed16ae5b9..3b6c498cf6 100644 --- a/src/daft-parquet/src/read.rs +++ b/src/daft-parquet/src/read.rs @@ -48,7 +48,7 @@ impl TryFrom for ParquetSchemaInferenceOpt type Error = crate::Error; fn try_from(value: ParquetSchemaInferenceOptionsBuilder) -> crate::Result { - Ok(ParquetSchemaInferenceOptions { + Ok(Self { coerce_int96_timestamp_unit: value .coerce_int96_timestamp_unit .map_or(TimeUnit::Nanoseconds, From::from), @@ -77,7 +77,7 @@ impl ParquetSchemaInferenceOptions { pub fn new(coerce_int96_timestamp_unit: Option) -> Self { let coerce_int96_timestamp_unit = coerce_int96_timestamp_unit.unwrap_or(TimeUnit::Nanoseconds); - ParquetSchemaInferenceOptions { + Self { coerce_int96_timestamp_unit, ..Default::default() } @@ -86,7 +86,7 @@ impl ParquetSchemaInferenceOptions { impl Default for ParquetSchemaInferenceOptions { fn default() -> Self { - ParquetSchemaInferenceOptions { + Self { coerce_int96_timestamp_unit: TimeUnit::Nanoseconds, string_encoding: StringEncoding::Utf8, } @@ -95,7 +95,7 @@ impl Default for ParquetSchemaInferenceOptions { impl From for SchemaInferenceOptions { fn from(value: ParquetSchemaInferenceOptions) -> Self { - SchemaInferenceOptions { + Self { int96_coerce_to_timeunit: value.coerce_int96_timestamp_unit.to_arrow(), string_encoding: value.string_encoding, } diff --git a/src/daft-parquet/src/read_planner.rs b/src/daft-parquet/src/read_planner.rs index 6337cd3330..aca3b3c870 100644 --- a/src/daft-parquet/src/read_planner.rs +++ b/src/daft-parquet/src/read_planner.rs @@ -132,7 +132,7 @@ pub(crate) struct ReadPlanner { impl ReadPlanner { pub fn new(source: &str) -> Self { - ReadPlanner { + Self { source: source.into(), ranges: vec![], passes: vec![], diff --git a/src/daft-parquet/src/statistics/mod.rs b/src/daft-parquet/src/statistics/mod.rs index 82dce1a5d8..2827c84355 100644 --- a/src/daft-parquet/src/statistics/mod.rs +++ b/src/daft-parquet/src/statistics/mod.rs @@ -26,7 +26,7 @@ pub(super) enum Error { impl From for Error { fn from(value: daft_stats::Error) -> Self { match value { - daft_stats::Error::DaftCoreCompute { source } => Error::DaftCoreCompute { source }, + daft_stats::Error::DaftCoreCompute { source } => Self::DaftCoreCompute { source }, _ => Self::DaftStats { source: value }, } } @@ -38,7 +38,7 @@ impl From for DaftError { fn from(value: Error) -> Self { match value { Error::DaftCoreCompute { source } => source, - _ => DaftError::External(value.into()), + _ => Self::External(value.into()), } } } @@ -47,6 +47,6 @@ pub(super) struct Wrap(T); impl From for Wrap { fn from(value: T) -> Self { - Wrap(value) + Self(value) } } diff --git a/src/daft-physical-plan/Cargo.toml b/src/daft-physical-plan/Cargo.toml index 
9ba603ab52..778b8b8560 100644 --- a/src/daft-physical-plan/Cargo.toml +++ b/src/daft-physical-plan/Cargo.toml @@ -8,6 +8,9 @@ daft-scan = {path = "../daft-scan", default-features = false} log = {workspace = true} strum = {version = "0.26", features = ["derive"]} +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-physical-plan" diff --git a/src/daft-physical-plan/src/local_plan.rs b/src/daft-physical-plan/src/local_plan.rs index 30bc879a2f..548e6505d8 100644 --- a/src/daft-physical-plan/src/local_plan.rs +++ b/src/daft-physical-plan/src/local_plan.rs @@ -56,7 +56,7 @@ impl LocalPhysicalPlan { } pub(crate) fn in_memory_scan(in_memory_info: InMemoryInfo) -> LocalPhysicalPlanRef { - LocalPhysicalPlan::InMemoryScan(InMemoryScan { + Self::InMemoryScan(InMemoryScan { info: in_memory_info, plan_stats: PlanStats {}, }) @@ -67,7 +67,7 @@ impl LocalPhysicalPlan { scan_tasks: Vec, schema: SchemaRef, ) -> LocalPhysicalPlanRef { - LocalPhysicalPlan::PhysicalScan(PhysicalScan { + Self::PhysicalScan(PhysicalScan { scan_tasks, schema, plan_stats: PlanStats {}, @@ -77,7 +77,7 @@ impl LocalPhysicalPlan { pub(crate) fn filter(input: LocalPhysicalPlanRef, predicate: ExprRef) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); - LocalPhysicalPlan::Filter(Filter { + Self::Filter(Filter { input, predicate, schema, @@ -88,7 +88,7 @@ impl LocalPhysicalPlan { pub(crate) fn limit(input: LocalPhysicalPlanRef, num_rows: i64) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); - LocalPhysicalPlan::Limit(Limit { + Self::Limit(Limit { input, num_rows, schema, @@ -102,7 +102,7 @@ impl LocalPhysicalPlan { projection: Vec, schema: SchemaRef, ) -> LocalPhysicalPlanRef { - LocalPhysicalPlan::Project(Project { + Self::Project(Project { input, projection, schema, @@ -116,7 +116,7 @@ impl LocalPhysicalPlan { aggregations: Vec, schema: SchemaRef, ) -> LocalPhysicalPlanRef { - LocalPhysicalPlan::UnGroupedAggregate(UnGroupedAggregate { + Self::UnGroupedAggregate(UnGroupedAggregate { input, aggregations, schema, @@ -131,7 +131,7 @@ impl LocalPhysicalPlan { group_by: Vec, schema: SchemaRef, ) -> LocalPhysicalPlanRef { - LocalPhysicalPlan::HashAggregate(HashAggregate { + Self::HashAggregate(HashAggregate { input, aggregations, group_by, @@ -147,7 +147,7 @@ impl LocalPhysicalPlan { descending: Vec, ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); - LocalPhysicalPlan::Sort(Sort { + Self::Sort(Sort { input, sort_by, descending, @@ -165,7 +165,7 @@ impl LocalPhysicalPlan { join_type: JoinType, schema: SchemaRef, ) -> LocalPhysicalPlanRef { - LocalPhysicalPlan::HashJoin(HashJoin { + Self::HashJoin(HashJoin { left, right, left_on, @@ -181,7 +181,7 @@ impl LocalPhysicalPlan { other: LocalPhysicalPlanRef, ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); - LocalPhysicalPlan::Concat(Concat { + Self::Concat(Concat { input, other, schema, @@ -192,16 +192,16 @@ impl LocalPhysicalPlan { pub fn schema(&self) -> &SchemaRef { match self { - LocalPhysicalPlan::PhysicalScan(PhysicalScan { schema, .. }) - | LocalPhysicalPlan::Filter(Filter { schema, .. }) - | LocalPhysicalPlan::Limit(Limit { schema, .. }) - | LocalPhysicalPlan::Project(Project { schema, .. }) - | LocalPhysicalPlan::UnGroupedAggregate(UnGroupedAggregate { schema, .. }) - | LocalPhysicalPlan::HashAggregate(HashAggregate { schema, .. }) - | LocalPhysicalPlan::Sort(Sort { schema, .. }) - | LocalPhysicalPlan::HashJoin(HashJoin { schema, .. }) - | LocalPhysicalPlan::Concat(Concat { schema, .. 
}) => schema, - LocalPhysicalPlan::InMemoryScan(InMemoryScan { info, .. }) => &info.source_schema, + Self::PhysicalScan(PhysicalScan { schema, .. }) + | Self::Filter(Filter { schema, .. }) + | Self::Limit(Limit { schema, .. }) + | Self::Project(Project { schema, .. }) + | Self::UnGroupedAggregate(UnGroupedAggregate { schema, .. }) + | Self::HashAggregate(HashAggregate { schema, .. }) + | Self::Sort(Sort { schema, .. }) + | Self::HashJoin(HashJoin { schema, .. }) + | Self::Concat(Concat { schema, .. }) => schema, + Self::InMemoryScan(InMemoryScan { info, .. }) => &info.source_schema, _ => todo!("{:?}", self), } } diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index a8306394f1..7de191fe81 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -58,6 +58,9 @@ python = [ "daft-schema/python" ] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-plan" diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 98740a6a53..982a3634a9 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -53,8 +53,8 @@ impl LogicalPlanBuilder { } } -impl From<&LogicalPlanBuilder> for LogicalPlanBuilder { - fn from(builder: &LogicalPlanBuilder) -> Self { +impl From<&Self> for LogicalPlanBuilder { + fn from(builder: &Self) -> Self { Self { plan: builder.plan.clone(), config: builder.config.clone(), @@ -105,7 +105,7 @@ impl LogicalPlanBuilder { )); let logical_plan: LogicalPlan = logical_ops::Source::new(schema.clone(), source_info.into()).into(); - Ok(LogicalPlanBuilder::new(logical_plan.into(), None)) + Ok(Self::new(logical_plan.into(), None)) } pub fn table_scan( @@ -139,7 +139,7 @@ impl LogicalPlanBuilder { }; let logical_plan: LogicalPlan = logical_ops::Source::new(output_schema, source_info.into()).into(); - Ok(LogicalPlanBuilder::new(logical_plan.into(), None)) + Ok(Self::new(logical_plan.into(), None)) } pub fn select(&self, to_select: Vec) -> DaftResult { diff --git a/src/daft-plan/src/display.rs b/src/daft-plan/src/display.rs index a89276cc04..76ba0e599a 100644 --- a/src/daft-plan/src/display.rs +++ b/src/daft-plan/src/display.rs @@ -22,40 +22,40 @@ impl TreeDisplay for crate::LogicalPlan { impl TreeDisplay for crate::physical_plan::PhysicalPlan { fn display_as(&self, level: DisplayLevel) -> String { match self { - crate::PhysicalPlan::InMemoryScan(scan) => scan.display_as(level), - crate::PhysicalPlan::TabularScan(scan) => scan.display_as(level), - crate::PhysicalPlan::EmptyScan(scan) => scan.display_as(level), - crate::PhysicalPlan::Project(p) => p.display_as(level), - crate::PhysicalPlan::ActorPoolProject(p) => p.display_as(level), - crate::PhysicalPlan::Filter(f) => f.display_as(level), - crate::PhysicalPlan::Limit(limit) => limit.display_as(level), - crate::PhysicalPlan::Explode(explode) => explode.display_as(level), - crate::PhysicalPlan::Unpivot(unpivot) => unpivot.display_as(level), - crate::PhysicalPlan::Sort(sort) => sort.display_as(level), - crate::PhysicalPlan::Split(split) => split.display_as(level), - crate::PhysicalPlan::Sample(sample) => sample.display_as(level), - crate::PhysicalPlan::MonotonicallyIncreasingId(id) => id.display_as(level), - crate::PhysicalPlan::Coalesce(coalesce) => coalesce.display_as(level), - crate::PhysicalPlan::Flatten(flatten) => flatten.display_as(level), - crate::PhysicalPlan::FanoutRandom(fanout) => fanout.display_as(level), - crate::PhysicalPlan::FanoutByHash(fanout) => fanout.display_as(level), - crate::PhysicalPlan::FanoutByRange(fanout) => 
fanout.display_as(level), - crate::PhysicalPlan::ReduceMerge(reduce) => reduce.display_as(level), - crate::PhysicalPlan::Aggregate(aggr) => aggr.display_as(level), - crate::PhysicalPlan::Pivot(pivot) => pivot.display_as(level), - crate::PhysicalPlan::Concat(concat) => concat.display_as(level), - crate::PhysicalPlan::HashJoin(join) => join.display_as(level), - crate::PhysicalPlan::SortMergeJoin(join) => join.display_as(level), - crate::PhysicalPlan::BroadcastJoin(join) => join.display_as(level), - crate::PhysicalPlan::TabularWriteParquet(write) => write.display_as(level), - crate::PhysicalPlan::TabularWriteJson(write) => write.display_as(level), - crate::PhysicalPlan::TabularWriteCsv(write) => write.display_as(level), + Self::InMemoryScan(scan) => scan.display_as(level), + Self::TabularScan(scan) => scan.display_as(level), + Self::EmptyScan(scan) => scan.display_as(level), + Self::Project(p) => p.display_as(level), + Self::ActorPoolProject(p) => p.display_as(level), + Self::Filter(f) => f.display_as(level), + Self::Limit(limit) => limit.display_as(level), + Self::Explode(explode) => explode.display_as(level), + Self::Unpivot(unpivot) => unpivot.display_as(level), + Self::Sort(sort) => sort.display_as(level), + Self::Split(split) => split.display_as(level), + Self::Sample(sample) => sample.display_as(level), + Self::MonotonicallyIncreasingId(id) => id.display_as(level), + Self::Coalesce(coalesce) => coalesce.display_as(level), + Self::Flatten(flatten) => flatten.display_as(level), + Self::FanoutRandom(fanout) => fanout.display_as(level), + Self::FanoutByHash(fanout) => fanout.display_as(level), + Self::FanoutByRange(fanout) => fanout.display_as(level), + Self::ReduceMerge(reduce) => reduce.display_as(level), + Self::Aggregate(aggr) => aggr.display_as(level), + Self::Pivot(pivot) => pivot.display_as(level), + Self::Concat(concat) => concat.display_as(level), + Self::HashJoin(join) => join.display_as(level), + Self::SortMergeJoin(join) => join.display_as(level), + Self::BroadcastJoin(join) => join.display_as(level), + Self::TabularWriteParquet(write) => write.display_as(level), + Self::TabularWriteJson(write) => write.display_as(level), + Self::TabularWriteCsv(write) => write.display_as(level), #[cfg(feature = "python")] - crate::PhysicalPlan::IcebergWrite(write) => write.display_as(level), + Self::IcebergWrite(write) => write.display_as(level), #[cfg(feature = "python")] - crate::PhysicalPlan::DeltaLakeWrite(write) => write.display_as(level), + Self::DeltaLakeWrite(write) => write.display_as(level), #[cfg(feature = "python")] - crate::PhysicalPlan::LanceWrite(write) => write.display_as(level), + Self::LanceWrite(write) => write.display_as(level), } } diff --git a/src/daft-plan/src/logical_ops/actor_pool_project.rs b/src/daft-plan/src/logical_ops/actor_pool_project.rs index b353636240..97b511b238 100644 --- a/src/daft-plan/src/logical_ops/actor_pool_project.rs +++ b/src/daft-plan/src/logical_ops/actor_pool_project.rs @@ -58,7 +58,7 @@ impl ActorPoolProject { let projected_schema = Schema::new(fields).context(CreationSnafu)?.into(); - Ok(ActorPoolProject { + Ok(Self { input, projection, projected_schema, diff --git a/src/daft-plan/src/logical_optimization/optimizer.rs b/src/daft-plan/src/logical_optimization/optimizer.rs index 5e065e9213..535eb16448 100644 --- a/src/daft-plan/src/logical_optimization/optimizer.rs +++ b/src/daft-plan/src/logical_optimization/optimizer.rs @@ -23,7 +23,7 @@ pub struct OptimizerConfig { impl OptimizerConfig { fn new(max_optimizer_passes: usize, 
enable_actor_pool_projections: bool) -> Self { - OptimizerConfig { + Self { default_max_optimizer_passes: max_optimizer_passes, enable_actor_pool_projections, } @@ -33,7 +33,7 @@ impl OptimizerConfig { impl Default for OptimizerConfig { fn default() -> Self { // Default to a max of 5 optimizer passes for a given batch. - OptimizerConfig::new(5, false) + Self::new(5, false) } } diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 2cb2ca4834..849a0402f8 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -214,7 +214,7 @@ impl LogicalPlan { } } - pub fn children(&self) -> Vec<&LogicalPlan> { + pub fn children(&self) -> Vec<&Self> { match self { Self::Source(..) => vec![], Self::Project(Project { input, .. }) => vec![input], @@ -238,7 +238,7 @@ impl LogicalPlan { } } - pub fn with_new_children(&self, children: &[Arc]) -> LogicalPlan { + pub fn with_new_children(&self, children: &[Arc]) -> Self { match children { [input] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), @@ -309,8 +309,8 @@ pub(crate) enum Error { pub(crate) type Result = std::result::Result; impl From for DaftError { - fn from(err: Error) -> DaftError { - DaftError::External(err.into()) + fn from(err: Error) -> Self { + Self::External(err.into()) } } diff --git a/src/daft-plan/src/partitioning.rs b/src/daft-plan/src/partitioning.rs index a460437a5a..9f5be356be 100644 --- a/src/daft-plan/src/partitioning.rs +++ b/src/daft-plan/src/partitioning.rs @@ -428,6 +428,6 @@ impl UnknownClusteringConfig { impl Default for UnknownClusteringConfig { fn default() -> Self { - UnknownClusteringConfig::new(1) + Self::new(1) } } diff --git a/src/daft-plan/src/physical_ops/actor_pool_project.rs b/src/daft-plan/src/physical_ops/actor_pool_project.rs index ae55af05b6..bca86ebf08 100644 --- a/src/daft-plan/src/physical_ops/actor_pool_project.rs +++ b/src/daft-plan/src/physical_ops/actor_pool_project.rs @@ -53,7 +53,7 @@ impl ActorPoolProject { return Err(DaftError::InternalError(format!("Expected ActorPoolProject to have exactly 1 stateful UDF expression but found: {num_stateful_udf_exprs}"))); } - Ok(ActorPoolProject { + Ok(Self { input, projection, clustering_spec, diff --git a/src/daft-plan/src/physical_optimization/optimizer.rs b/src/daft-plan/src/physical_optimization/optimizer.rs index 51e2b7ade3..58e2d43e53 100644 --- a/src/daft-plan/src/physical_optimization/optimizer.rs +++ b/src/daft-plan/src/physical_optimization/optimizer.rs @@ -16,13 +16,13 @@ pub struct PhysicalOptimizerConfig { impl PhysicalOptimizerConfig { #[allow(dead_code)] // used in test pub fn new(max_passes: usize) -> Self { - PhysicalOptimizerConfig { max_passes } + Self { max_passes } } } impl Default for PhysicalOptimizerConfig { fn default() -> Self { - PhysicalOptimizerConfig { max_passes: 5 } + Self { max_passes: 5 } } } @@ -37,7 +37,7 @@ impl PhysicalOptimizer { rule_batches: Vec, config: PhysicalOptimizerConfig, ) -> Self { - PhysicalOptimizer { + Self { rule_batches, config, } @@ -53,7 +53,7 @@ impl PhysicalOptimizer { impl Default for PhysicalOptimizer { fn default() -> Self { - PhysicalOptimizer { + Self { rule_batches: vec![PhysicalOptimizerRuleBatch::new( vec![ Box::new(ReorderPartitionKeys {}), diff --git a/src/daft-plan/src/physical_optimization/rules/rule.rs b/src/daft-plan/src/physical_optimization/rules/rule.rs index 8d20891424..ec49563d04 100644 --- 
a/src/daft-plan/src/physical_optimization/rules/rule.rs +++ b/src/daft-plan/src/physical_optimization/rules/rule.rs @@ -30,7 +30,7 @@ impl PhysicalOptimizerRuleBatch { rules: Vec>, strategy: PhysicalRuleExecutionStrategy, ) -> Self { - PhysicalOptimizerRuleBatch { rules, strategy } + Self { rules, strategy } } fn optimize_once(&self, plan: PhysicalPlanRef) -> DaftResult> { diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 6302612e3d..615d656b92 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -62,7 +62,7 @@ pub struct ApproxStats { impl ApproxStats { fn empty() -> Self { - ApproxStats { + Self { lower_bound_rows: 0, upper_bound_rows: None, lower_bound_bytes: 0, @@ -70,7 +70,7 @@ impl ApproxStats { } } fn apply usize>(&self, f: F) -> Self { - ApproxStats { + Self { lower_bound_rows: f(self.lower_bound_rows), upper_bound_rows: self.upper_bound_rows.map(&f), lower_bound_bytes: f(self.lower_bound_rows), @@ -411,7 +411,7 @@ impl PhysicalPlan { } } - pub fn children(&self) -> Vec<&PhysicalPlan> { + pub fn children(&self) -> Vec<&Self> { match self { Self::InMemoryScan(..) => vec![], Self::TabularScan(..) | Self::EmptyScan(..) => vec![], @@ -457,7 +457,7 @@ impl PhysicalPlan { } } - pub fn with_new_children(&self, children: &[PhysicalPlanRef]) -> PhysicalPlan { + pub fn with_new_children(&self, children: &[PhysicalPlanRef]) -> Self { match children { [input] => match self { Self::InMemoryScan(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), diff --git a/src/daft-plan/src/physical_planner/planner.rs b/src/daft-plan/src/physical_planner/planner.rs index 685f8dd028..5071c1bce2 100644 --- a/src/daft-plan/src/physical_planner/planner.rs +++ b/src/daft-plan/src/physical_planner/planner.rs @@ -258,18 +258,18 @@ pub enum QueryStageOutput { impl QueryStageOutput { pub fn unwrap(self) -> (Option, PhysicalPlanRef) { match self { - QueryStageOutput::Partial { + Self::Partial { physical_plan, source_id, } => (Some(source_id), physical_plan), - QueryStageOutput::Final { physical_plan } => (None, physical_plan), + Self::Final { physical_plan } => (None, physical_plan), } } pub fn source_id(&self) -> Option { match self { - QueryStageOutput::Partial { source_id, .. } => Some(*source_id), - QueryStageOutput::Final { .. } => None, + Self::Partial { source_id, .. } => Some(*source_id), + Self::Final { .. 
} => None, } } } @@ -293,7 +293,7 @@ pub struct AdaptivePlanner { impl AdaptivePlanner { pub fn new(logical_plan: LogicalPlanRef, cfg: Arc) -> Self { - AdaptivePlanner { + Self { logical_plan, cfg, status: AdaptivePlannerStatus::Ready, diff --git a/src/daft-plan/src/source_info/mod.rs b/src/daft-plan/src/source_info/mod.rs index ef5582ca3a..360b4f7d7c 100644 --- a/src/daft-plan/src/source_info/mod.rs +++ b/src/daft-plan/src/source_info/mod.rs @@ -87,7 +87,7 @@ pub struct PlaceHolderInfo { impl PlaceHolderInfo { pub fn new(source_schema: SchemaRef, clustering_spec: ClusteringSpecRef) -> Self { - PlaceHolderInfo { + Self { source_schema, clustering_spec, source_id: PLACEHOLDER_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst), diff --git a/src/daft-scan/Cargo.toml b/src/daft-scan/Cargo.toml index 7f9added6d..d4c5e5a230 100644 --- a/src/daft-scan/Cargo.toml +++ b/src/daft-scan/Cargo.toml @@ -24,6 +24,9 @@ snafu = {workspace = true} [features] python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-dsl/python", "daft-table/python", "daft-stats/python", "common-file-formats/python", "common-io-config/python", "common-daft-config/python", "daft-schema/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-scan" diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 90621e510d..376548f7a7 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -49,7 +49,7 @@ enum Error { impl From for DaftError { fn from(value: Error) -> Self { match &value { - Error::GlobNoMatch { glob_path } => DaftError::FileNotFound { + Error::GlobNoMatch { glob_path } => Self::FileNotFound { path: glob_path.clone(), source: Box::new(value), }, diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index f8c9781be4..10cc0c6804 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -85,7 +85,7 @@ pub enum Error { impl From for DaftError { fn from(value: Error) -> Self { - DaftError::External(value.into()) + Self::External(value.into()) } } @@ -409,7 +409,7 @@ impl ScanTask { } } - pub fn merge(sc1: &ScanTask, sc2: &ScanTask) -> Result { + pub fn merge(sc1: &Self, sc2: &Self) -> Result { if sc1.partition_spec() != sc2.partition_spec() { return Err(Error::DifferingPartitionSpecsInScanTaskMerge { ps1: sc1.partition_spec().cloned(), @@ -440,7 +440,7 @@ impl ScanTask { p2: sc2.pushdowns.clone(), }); } - Ok(ScanTask::new( + Ok(Self::new( sc1.sources .clone() .into_iter() @@ -676,7 +676,7 @@ impl PartitionField { match (&source_field, &transform) { (Some(_), Some(_)) => { // TODO ADD VALIDATION OF TRANSFORM based on types - Ok(PartitionField { + Ok(Self { field, source_field, transform, @@ -686,7 +686,7 @@ impl PartitionField { "transform set in PartitionField: {} but source_field not set", tfm ))), - _ => Ok(PartitionField { + _ => Ok(Self { field, source_field, transform, @@ -787,8 +787,8 @@ impl Hash for ScanOperatorRef { } } -impl PartialEq for ScanOperatorRef { - fn eq(&self, other: &ScanOperatorRef) -> bool { +impl PartialEq for ScanOperatorRef { + fn eq(&self, other: &Self) -> bool { Arc::ptr_eq(&self.0, &other.0) } } @@ -1014,7 +1014,7 @@ mod test { let mut sources: Vec = Vec::new(); for _ in 0..num_sources { - sources.push(format!("../../tests/assets/parquet-data/mvp.parquet")); + sources.push("../../tests/assets/parquet-data/mvp.parquet".to_string()); } let glob_scan_operator: GlobScanOperator = GlobScanOperator::try_new( diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index 
af5b23a1db..fac37ccb48 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -93,7 +93,7 @@ pub mod pylib { file_format_config.into(), storage_config.into(), )); - Ok(ScanOperatorHandle { + Ok(Self { scan_op: ScanOperatorRef(operator), }) }) @@ -116,7 +116,7 @@ pub mod pylib { infer_schema, schema.map(|s| s.schema), )?); - Ok(ScanOperatorHandle { + Ok(Self { scan_op: ScanOperatorRef(operator), }) }) @@ -127,7 +127,7 @@ pub mod pylib { let scan_op = ScanOperatorRef(Arc::new(PythonScanOperatorBridge::from_python_abc( py_scan, py, )?)); - Ok(ScanOperatorHandle { scan_op }) + Ok(Self { scan_op }) } } #[pyclass(module = "daft.daft")] @@ -349,7 +349,7 @@ pub mod pylib { storage_config.into(), pushdowns.map(|p| p.0.as_ref().clone()).unwrap_or_default(), ); - Ok(Some(PyScanTask(scan_task.into()))) + Ok(Some(Self(scan_task.into()))) } #[allow(clippy::too_many_arguments)] @@ -381,7 +381,7 @@ pub mod pylib { storage_config.into(), pushdowns.map(|p| p.0.as_ref().clone()).unwrap_or_default(), ); - Ok(PyScanTask(scan_task.into())) + Ok(Self(scan_task.into())) } #[allow(clippy::too_many_arguments)] @@ -425,7 +425,7 @@ pub mod pylib { ))), pushdowns.map(|p| p.0.as_ref().clone()).unwrap_or_default(), ); - Ok(PyScanTask(scan_task.into())) + Ok(Self(scan_task.into())) } pub fn __repr__(&self) -> PyResult { @@ -464,7 +464,7 @@ pub mod pylib { source_field.map(|f| f.into()), transform.map(|e| e.0), )?; - Ok(PyPartitionField(Arc::new(p_field))) + Ok(Self(Arc::new(p_field))) } pub fn __repr__(&self) -> PyResult { diff --git a/src/daft-scan/src/storage_config.rs b/src/daft-scan/src/storage_config.rs index ced4c95315..d169e06510 100644 --- a/src/daft-scan/src/storage_config.rs +++ b/src/daft-scan/src/storage_config.rs @@ -26,7 +26,7 @@ impl StorageConfig { // Grab an IOClient and Runtime // TODO: This should be cleaned up and hidden behind a better API from daft-io match self { - StorageConfig::Native(cfg) => { + Self::Native(cfg) => { let multithreaded_io = cfg.multithreaded_io; Ok(( get_runtime(multithreaded_io)?, @@ -37,7 +37,7 @@ impl StorageConfig { )) } #[cfg(feature = "python")] - StorageConfig::Python(cfg) => { + Self::Python(cfg) => { let multithreaded_io = true; // Hardcode to use multithreaded IO if Python storage config is used for data fetches Ok(( get_runtime(multithreaded_io)?, diff --git a/src/daft-scheduler/Cargo.toml b/src/daft-scheduler/Cargo.toml index bfd4ca7985..14228f79b9 100644 --- a/src/daft-scheduler/Cargo.toml +++ b/src/daft-scheduler/Cargo.toml @@ -29,6 +29,9 @@ python = [ "daft-dsl/python" ] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-scheduler" diff --git a/src/daft-scheduler/src/adaptive.rs b/src/daft-scheduler/src/adaptive.rs index 0e4a2bbc77..e701bd0e73 100644 --- a/src/daft-scheduler/src/adaptive.rs +++ b/src/daft-scheduler/src/adaptive.rs @@ -17,7 +17,7 @@ pub struct AdaptivePhysicalPlanScheduler { impl AdaptivePhysicalPlanScheduler { pub fn new(logical_plan: Arc, cfg: Arc) -> Self { - AdaptivePhysicalPlanScheduler { + Self { planner: AdaptivePlanner::new(logical_plan, cfg), } } @@ -34,10 +34,7 @@ impl AdaptivePhysicalPlanScheduler { ) -> PyResult { py.allow_threads(|| { let logical_plan = logical_plan_builder.builder.build(); - Ok(AdaptivePhysicalPlanScheduler::new( - logical_plan, - cfg.config.clone(), - )) + Ok(Self::new(logical_plan, cfg.config.clone())) }) } pub fn next(&mut self, py: Python) -> PyResult<(Option, PhysicalPlanScheduler)> { diff --git a/src/daft-schema/Cargo.toml b/src/daft-schema/Cargo.toml index 
df6db1718c..ed6ecde2b7 100644 --- a/src/daft-schema/Cargo.toml +++ b/src/daft-schema/Cargo.toml @@ -23,6 +23,9 @@ python = [ "common-arrow-ffi/python" ] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-schema" diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index c7418b5549..65cf8f808e 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -171,7 +171,7 @@ struct DataTypePayload { impl DataTypePayload { pub fn new(datatype: &DataType) -> Self { - DataTypePayload { + Self { datatype: datatype.clone(), daft_version: common_version::VERSION.into(), daft_build_type: common_version::DAFT_BUILD_TYPE.into(), @@ -181,48 +181,48 @@ impl DataTypePayload { const DAFT_SUPER_EXTENSION_NAME: &str = "daft.super_extension"; impl DataType { - pub fn new_null() -> DataType { - DataType::Null + pub fn new_null() -> Self { + Self::Null } - pub fn new_list(datatype: DataType) -> DataType { - DataType::List(Box::new(datatype)) + pub fn new_list(datatype: Self) -> Self { + Self::List(Box::new(datatype)) } - pub fn new_fixed_size_list(datatype: DataType, size: usize) -> DataType { - DataType::FixedSizeList(Box::new(datatype), size) + pub fn new_fixed_size_list(datatype: Self, size: usize) -> Self { + Self::FixedSizeList(Box::new(datatype), size) } pub fn to_arrow(&self) -> DaftResult { match self { - DataType::Null => Ok(ArrowType::Null), - DataType::Boolean => Ok(ArrowType::Boolean), - DataType::Int8 => Ok(ArrowType::Int8), - DataType::Int16 => Ok(ArrowType::Int16), - DataType::Int32 => Ok(ArrowType::Int32), - DataType::Int64 => Ok(ArrowType::Int64), + Self::Null => Ok(ArrowType::Null), + Self::Boolean => Ok(ArrowType::Boolean), + Self::Int8 => Ok(ArrowType::Int8), + Self::Int16 => Ok(ArrowType::Int16), + Self::Int32 => Ok(ArrowType::Int32), + Self::Int64 => Ok(ArrowType::Int64), // Must maintain same default mapping as Arrow2, otherwise this will throw errors in // DataArray::new() which makes strong assumptions about the arrow/Daft types // https://github.com/jorgecarleitao/arrow2/blob/b0734542c2fef5d2d0c7b6ffce5d094de371168a/src/datatypes/mod.rs#L493 - DataType::Int128 => Ok(ArrowType::Decimal(32, 32)), - DataType::UInt8 => Ok(ArrowType::UInt8), - DataType::UInt16 => Ok(ArrowType::UInt16), - DataType::UInt32 => Ok(ArrowType::UInt32), - DataType::UInt64 => Ok(ArrowType::UInt64), + Self::Int128 => Ok(ArrowType::Decimal(32, 32)), + Self::UInt8 => Ok(ArrowType::UInt8), + Self::UInt16 => Ok(ArrowType::UInt16), + Self::UInt32 => Ok(ArrowType::UInt32), + Self::UInt64 => Ok(ArrowType::UInt64), // DataType::Float16 => Ok(ArrowType::Float16), - DataType::Float32 => Ok(ArrowType::Float32), - DataType::Float64 => Ok(ArrowType::Float64), - DataType::Decimal128(precision, scale) => Ok(ArrowType::Decimal(*precision, *scale)), - DataType::Timestamp(unit, timezone) => { + Self::Float32 => Ok(ArrowType::Float32), + Self::Float64 => Ok(ArrowType::Float64), + Self::Decimal128(precision, scale) => Ok(ArrowType::Decimal(*precision, *scale)), + Self::Timestamp(unit, timezone) => { Ok(ArrowType::Timestamp(unit.to_arrow(), timezone.clone())) } - DataType::Date => Ok(ArrowType::Date32), - DataType::Time(unit) => Ok(ArrowType::Time64(unit.to_arrow())), - DataType::Duration(unit) => Ok(ArrowType::Duration(unit.to_arrow())), - DataType::Binary => Ok(ArrowType::LargeBinary), - DataType::FixedSizeBinary(size) => Ok(ArrowType::FixedSizeBinary(*size)), - DataType::Utf8 => Ok(ArrowType::LargeUtf8), - DataType::FixedSizeList(child_dtype, size) => 
Ok(ArrowType::FixedSizeList( + Self::Date => Ok(ArrowType::Date32), + Self::Time(unit) => Ok(ArrowType::Time64(unit.to_arrow())), + Self::Duration(unit) => Ok(ArrowType::Duration(unit.to_arrow())), + Self::Binary => Ok(ArrowType::LargeBinary), + Self::FixedSizeBinary(size) => Ok(ArrowType::FixedSizeBinary(*size)), + Self::Utf8 => Ok(ArrowType::LargeUtf8), + Self::FixedSizeList(child_dtype, size) => Ok(ArrowType::FixedSizeList( Box::new(arrow2::datatypes::Field::new( "item", child_dtype.to_arrow()?, @@ -230,10 +230,10 @@ impl DataType { )), *size, )), - DataType::List(field) => Ok(ArrowType::LargeList(Box::new( + Self::List(field) => Ok(ArrowType::LargeList(Box::new( arrow2::datatypes::Field::new("item", field.to_arrow()?, true), ))), - DataType::Map(field) => Ok(ArrowType::Map( + Self::Map(field) => Ok(ArrowType::Map( Box::new(arrow2::datatypes::Field::new( "item", field.to_arrow()?, @@ -241,27 +241,27 @@ impl DataType { )), false, )), - DataType::Struct(fields) => Ok({ + Self::Struct(fields) => Ok({ let fields = fields .iter() .map(|f| f.to_arrow()) .collect::>>()?; ArrowType::Struct(fields) }), - DataType::Extension(name, dtype, metadata) => Ok(ArrowType::Extension( + Self::Extension(name, dtype, metadata) => Ok(ArrowType::Extension( name.clone(), Box::new(dtype.to_arrow()?), metadata.clone(), )), - DataType::Embedding(..) - | DataType::Image(..) - | DataType::FixedShapeImage(..) - | DataType::Tensor(..) - | DataType::FixedShapeTensor(..) - | DataType::SparseTensor(..) - | DataType::FixedShapeSparseTensor(..) => { + Self::Embedding(..) + | Self::Image(..) + | Self::FixedShapeImage(..) + | Self::Tensor(..) + | Self::FixedShapeTensor(..) + | Self::SparseTensor(..) + | Self::FixedShapeSparseTensor(..) => { let physical = Box::new(self.to_physical()); - let logical_extension = DataType::Extension( + let logical_extension = Self::Extension( DAFT_SUPER_EXTENSION_NAME.into(), physical, Some(self.to_json()?), @@ -269,16 +269,16 @@ impl DataType { logical_extension.to_arrow() } #[cfg(feature = "python")] - DataType::Python => Err(DaftError::TypeError(format!( + Self::Python => Err(DaftError::TypeError(format!( "Can not convert {self:?} into arrow type" ))), - DataType::Unknown => Err(DaftError::TypeError(format!( + Self::Unknown => Err(DaftError::TypeError(format!( "Can not convert {self:?} into arrow type" ))), } } - pub fn to_physical(&self) -> DataType { + pub fn to_physical(&self) -> Self { use DataType::*; match self { Decimal128(..) 
=> Int128, @@ -293,7 +293,7 @@ impl DataType { Image(mode) => Struct(vec![ Field::new( "data", - List(Box::new(mode.map_or(DataType::UInt8, |m| m.get_dtype()))), + List(Box::new(mode.map_or(Self::UInt8, |m| m.get_dtype()))), ), Field::new("channel", UInt16), Field::new("height", UInt32), @@ -306,7 +306,7 @@ impl DataType { ), Tensor(dtype) => Struct(vec![ Field::new("data", List(Box::new(*dtype.clone()))), - Field::new("shape", List(Box::new(DataType::UInt64))), + Field::new("shape", List(Box::new(Self::UInt64))), ]), FixedShapeTensor(dtype, shape) => FixedSizeList( Box::new(*dtype.clone()), @@ -314,12 +314,12 @@ impl DataType { ), SparseTensor(dtype) => Struct(vec![ Field::new("values", List(Box::new(*dtype.clone()))), - Field::new("indices", List(Box::new(DataType::UInt64))), - Field::new("shape", List(Box::new(DataType::UInt64))), + Field::new("indices", List(Box::new(Self::UInt64))), + Field::new("shape", List(Box::new(Self::UInt64))), ]), FixedShapeSparseTensor(dtype, _) => Struct(vec![ Field::new("values", List(Box::new(*dtype.clone()))), - Field::new("indices", List(Box::new(DataType::UInt64))), + Field::new("indices", List(Box::new(Self::UInt64))), ]), _ => { assert!(self.is_physical()); @@ -329,15 +329,15 @@ impl DataType { } #[inline] - pub fn nested_dtype(&self) -> Option<&DataType> { + pub fn nested_dtype(&self) -> Option<&Self> { match self { - DataType::Map(dtype) - | DataType::List(dtype) - | DataType::FixedSizeList(dtype, _) - | DataType::FixedShapeTensor(dtype, _) - | DataType::SparseTensor(dtype) - | DataType::FixedShapeSparseTensor(dtype, _) - | DataType::Tensor(dtype) => Some(dtype), + Self::Map(dtype) + | Self::List(dtype) + | Self::FixedSizeList(dtype, _) + | Self::FixedShapeTensor(dtype, _) + | Self::SparseTensor(dtype) + | Self::FixedShapeSparseTensor(dtype, _) + | Self::Tensor(dtype) => Some(dtype), _ => None, } } @@ -350,19 +350,19 @@ impl DataType { #[inline] pub fn is_numeric(&self) -> bool { match self { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Int128 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 + Self::Int8 + | Self::Int16 + | Self::Int32 + | Self::Int64 + | Self::Int128 + | Self::UInt8 + | Self::UInt16 + | Self::UInt32 + | Self::UInt64 // DataType::Float16 - | DataType::Float32 - | DataType::Float64 => true, - DataType::Extension(_, inner, _) => inner.is_numeric(), + | Self::Float32 + | Self::Float64 => true, + Self::Extension(_, inner, _) => inner.is_numeric(), _ => false } } @@ -370,10 +370,10 @@ impl DataType { #[inline] pub fn is_fixed_size_numeric(&self) -> bool { match self { - DataType::FixedSizeList(dtype, ..) - | DataType::Embedding(dtype, ..) - | DataType::FixedShapeTensor(dtype, ..) - | DataType::FixedShapeSparseTensor(dtype, ..) => dtype.is_numeric(), + Self::FixedSizeList(dtype, ..) + | Self::Embedding(dtype, ..) + | Self::FixedShapeTensor(dtype, ..) + | Self::FixedShapeSparseTensor(dtype, ..) 
=> dtype.is_numeric(), _ => false, } } @@ -381,8 +381,8 @@ impl DataType { #[inline] pub fn fixed_size(&self) -> Option { match self { - DataType::FixedSizeList(_, size) => Some(*size), - DataType::Embedding(_, size) => Some(*size), + Self::FixedSizeList(_, size) => Some(*size), + Self::Embedding(_, size) => Some(*size), _ => None, } } @@ -391,15 +391,15 @@ impl DataType { pub fn is_integer(&self) -> bool { matches!( self, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Int128 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 + Self::Int8 + | Self::Int16 + | Self::Int32 + | Self::Int64 + | Self::Int128 + | Self::UInt8 + | Self::UInt16 + | Self::UInt32 + | Self::UInt64 ) } @@ -408,89 +408,89 @@ impl DataType { matches!( self, // DataType::Float16 | - DataType::Float32 | DataType::Float64 + Self::Float32 | Self::Float64 ) } #[inline] pub fn is_temporal(&self) -> bool { match self { - DataType::Date | DataType::Timestamp(..) => true, - DataType::Extension(_, inner, _) => inner.is_temporal(), + Self::Date | Self::Timestamp(..) => true, + Self::Extension(_, inner, _) => inner.is_temporal(), _ => false, } } #[inline] pub fn is_tensor(&self) -> bool { - matches!(self, DataType::Tensor(..)) + matches!(self, Self::Tensor(..)) } #[inline] pub fn is_sparse_tensor(&self) -> bool { - matches!(self, DataType::SparseTensor(..)) + matches!(self, Self::SparseTensor(..)) } #[inline] pub fn is_fixed_shape_tensor(&self) -> bool { - matches!(self, DataType::FixedShapeTensor(..)) + matches!(self, Self::FixedShapeTensor(..)) } #[inline] pub fn is_fixed_shape_sparse_tensor(&self) -> bool { - matches!(self, DataType::FixedShapeSparseTensor(..)) + matches!(self, Self::FixedShapeSparseTensor(..)) } #[inline] pub fn is_image(&self) -> bool { - matches!(self, DataType::Image(..)) + matches!(self, Self::Image(..)) } #[inline] pub fn is_fixed_shape_image(&self) -> bool { - matches!(self, DataType::FixedShapeImage(..)) + matches!(self, Self::FixedShapeImage(..)) } #[inline] pub fn is_map(&self) -> bool { - matches!(self, DataType::Map(..)) + matches!(self, Self::Map(..)) } #[inline] pub fn is_list(&self) -> bool { - matches!(self, DataType::List(..)) + matches!(self, Self::List(..)) } #[inline] pub fn is_string(&self) -> bool { - matches!(self, DataType::Utf8) + matches!(self, Self::Utf8) } #[inline] pub fn is_boolean(&self) -> bool { - matches!(self, DataType::Boolean) + matches!(self, Self::Boolean) } #[inline] pub fn is_null(&self) -> bool { match self { - DataType::Null => true, - DataType::Extension(_, inner, _) => inner.is_null(), + Self::Null => true, + Self::Extension(_, inner, _) => inner.is_null(), _ => false, } } #[inline] pub fn is_extension(&self) -> bool { - matches!(self, DataType::Extension(..)) + matches!(self, Self::Extension(..)) } #[inline] pub fn is_python(&self) -> bool { match self { #[cfg(feature = "python")] - DataType::Python => true, - DataType::Extension(_, inner, _) => inner.is_python(), + Self::Python => true, + Self::Extension(_, inner, _) => inner.is_python(), _ => false, } } @@ -499,18 +499,18 @@ impl DataType { pub fn to_floating_representation(&self) -> DaftResult { let data_type = match self { // All numeric types that coerce to `f32` - DataType::Int8 => DataType::Float32, - DataType::Int16 => DataType::Float32, - DataType::UInt8 => DataType::Float32, - DataType::UInt16 => DataType::Float32, - DataType::Float32 => DataType::Float32, + Self::Int8 => Self::Float32, + Self::Int16 => Self::Float32, + Self::UInt8 => 
Self::Float32, + Self::UInt16 => Self::Float32, + Self::Float32 => Self::Float32, // All numeric types that coerce to `f64` - DataType::Int32 => DataType::Float64, - DataType::Int64 => DataType::Float64, - DataType::UInt32 => DataType::Float64, - DataType::UInt64 => DataType::Float64, - DataType::Float64 => DataType::Float64, + Self::Int32 => Self::Float64, + Self::Int64 => Self::Float64, + Self::UInt32 => Self::Float64, + Self::UInt64 => Self::Float64, + Self::Float64 => Self::Float64, _ => { return Err(DaftError::TypeError(format!( @@ -527,33 +527,33 @@ impl DataType { const DEFAULT_LIST_LEN: f64 = 4.; let elem_size = match self.to_physical() { - DataType::Null => Some(0.), - DataType::Boolean => Some(0.125), - DataType::Int8 => Some(1.), - DataType::Int16 => Some(2.), - DataType::Int32 => Some(4.), - DataType::Int64 => Some(8.), - DataType::Int128 => Some(16.), - DataType::UInt8 => Some(1.), - DataType::UInt16 => Some(2.), - DataType::UInt32 => Some(4.), - DataType::UInt64 => Some(8.), - DataType::Float32 => Some(4.), - DataType::Float64 => Some(8.), - DataType::Utf8 => Some(VARIABLE_TYPE_SIZE), - DataType::Binary => Some(VARIABLE_TYPE_SIZE), - DataType::FixedSizeBinary(size) => Some(size as f64), - DataType::FixedSizeList(dtype, len) => { + Self::Null => Some(0.), + Self::Boolean => Some(0.125), + Self::Int8 => Some(1.), + Self::Int16 => Some(2.), + Self::Int32 => Some(4.), + Self::Int64 => Some(8.), + Self::Int128 => Some(16.), + Self::UInt8 => Some(1.), + Self::UInt16 => Some(2.), + Self::UInt32 => Some(4.), + Self::UInt64 => Some(8.), + Self::Float32 => Some(4.), + Self::Float64 => Some(8.), + Self::Utf8 => Some(VARIABLE_TYPE_SIZE), + Self::Binary => Some(VARIABLE_TYPE_SIZE), + Self::FixedSizeBinary(size) => Some(size as f64), + Self::FixedSizeList(dtype, len) => { dtype.estimate_size_bytes().map(|b| b * (len as f64)) } - DataType::List(dtype) => dtype.estimate_size_bytes().map(|b| b * DEFAULT_LIST_LEN), - DataType::Struct(fields) => Some( + Self::List(dtype) => dtype.estimate_size_bytes().map(|b| b * DEFAULT_LIST_LEN), + Self::Struct(fields) => Some( fields .iter() .map(|f| f.dtype.estimate_size_bytes().unwrap_or(0f64)) .sum(), ), - DataType::Extension(_, dtype, _) => dtype.estimate_size_bytes(), + Self::Extension(_, dtype, _) => dtype.estimate_size_bytes(), _ => None, }; // add bitmap @@ -564,19 +564,19 @@ impl DataType { pub fn is_logical(&self) -> bool { matches!( self, - DataType::Decimal128(..) - | DataType::Date - | DataType::Time(..) - | DataType::Timestamp(..) - | DataType::Duration(..) - | DataType::Embedding(..) - | DataType::Image(..) - | DataType::FixedShapeImage(..) - | DataType::Tensor(..) - | DataType::FixedShapeTensor(..) - | DataType::SparseTensor(..) - | DataType::FixedShapeSparseTensor(..) - | DataType::Map(..) + Self::Decimal128(..) + | Self::Date + | Self::Time(..) + | Self::Timestamp(..) + | Self::Duration(..) + | Self::Embedding(..) + | Self::Image(..) + | Self::FixedShapeImage(..) + | Self::Tensor(..) + | Self::FixedShapeTensor(..) + | Self::SparseTensor(..) + | Self::FixedShapeSparseTensor(..) + | Self::Map(..) ) } @@ -587,13 +587,10 @@ impl DataType { #[inline] pub fn is_nested(&self) -> bool { - let p: DataType = self.to_physical(); + let p: Self = self.to_physical(); matches!( p, - DataType::List(..) - | DataType::FixedSizeList(..) - | DataType::Struct(..) - | DataType::Map(..) + Self::List(..) | Self::FixedSizeList(..) | Self::Struct(..) | Self::Map(..) 
) } @@ -611,42 +608,40 @@ impl DataType { impl From<&ArrowType> for DataType { fn from(item: &ArrowType) -> Self { match item { - ArrowType::Null => DataType::Null, - ArrowType::Boolean => DataType::Boolean, - ArrowType::Int8 => DataType::Int8, - ArrowType::Int16 => DataType::Int16, - ArrowType::Int32 => DataType::Int32, - ArrowType::Int64 => DataType::Int64, - ArrowType::UInt8 => DataType::UInt8, - ArrowType::UInt16 => DataType::UInt16, - ArrowType::UInt32 => DataType::UInt32, - ArrowType::UInt64 => DataType::UInt64, + ArrowType::Null => Self::Null, + ArrowType::Boolean => Self::Boolean, + ArrowType::Int8 => Self::Int8, + ArrowType::Int16 => Self::Int16, + ArrowType::Int32 => Self::Int32, + ArrowType::Int64 => Self::Int64, + ArrowType::UInt8 => Self::UInt8, + ArrowType::UInt16 => Self::UInt16, + ArrowType::UInt32 => Self::UInt32, + ArrowType::UInt64 => Self::UInt64, // ArrowType::Float16 => DataType::Float16, - ArrowType::Float32 => DataType::Float32, - ArrowType::Float64 => DataType::Float64, - ArrowType::Timestamp(unit, timezone) => { - DataType::Timestamp(unit.into(), timezone.clone()) - } - ArrowType::Date32 => DataType::Date, - ArrowType::Date64 => DataType::Timestamp(TimeUnit::Milliseconds, None), + ArrowType::Float32 => Self::Float32, + ArrowType::Float64 => Self::Float64, + ArrowType::Timestamp(unit, timezone) => Self::Timestamp(unit.into(), timezone.clone()), + ArrowType::Date32 => Self::Date, + ArrowType::Date64 => Self::Timestamp(TimeUnit::Milliseconds, None), ArrowType::Time32(timeunit) | ArrowType::Time64(timeunit) => { - DataType::Time(timeunit.into()) + Self::Time(timeunit.into()) } - ArrowType::Duration(timeunit) => DataType::Duration(timeunit.into()), - ArrowType::FixedSizeBinary(size) => DataType::FixedSizeBinary(*size), - ArrowType::Binary | ArrowType::LargeBinary => DataType::Binary, - ArrowType::Utf8 | ArrowType::LargeUtf8 => DataType::Utf8, - ArrowType::Decimal(precision, scale) => DataType::Decimal128(*precision, *scale), + ArrowType::Duration(timeunit) => Self::Duration(timeunit.into()), + ArrowType::FixedSizeBinary(size) => Self::FixedSizeBinary(*size), + ArrowType::Binary | ArrowType::LargeBinary => Self::Binary, + ArrowType::Utf8 | ArrowType::LargeUtf8 => Self::Utf8, + ArrowType::Decimal(precision, scale) => Self::Decimal128(*precision, *scale), ArrowType::List(field) | ArrowType::LargeList(field) => { - DataType::List(Box::new(field.as_ref().data_type().into())) + Self::List(Box::new(field.as_ref().data_type().into())) } ArrowType::FixedSizeList(field, size) => { - DataType::FixedSizeList(Box::new(field.as_ref().data_type().into()), *size) + Self::FixedSizeList(Box::new(field.as_ref().data_type().into()), *size) } - ArrowType::Map(field, ..) => DataType::Map(Box::new(field.as_ref().data_type().into())), + ArrowType::Map(field, ..) 
=> Self::Map(Box::new(field.as_ref().data_type().into())), ArrowType::Struct(fields) => { let fields: Vec = fields.iter().map(|fld| fld.into()).collect(); - DataType::Struct(fields) + Self::Struct(fields) } ArrowType::Extension(name, dtype, metadata) => { if name == DAFT_SUPER_EXTENSION_NAME { @@ -656,7 +651,7 @@ impl From<&ArrowType> for DataType { } } } - DataType::Extension( + Self::Extension( name.clone(), Box::new(dtype.as_ref().into()), metadata.clone(), @@ -673,9 +668,9 @@ impl From<&ImageMode> for DataType { use ImageMode::*; match mode { - L16 | LA16 | RGB16 | RGBA16 => DataType::UInt16, - RGB32F | RGBA32F => DataType::Float32, - _ => DataType::UInt8, + L16 | LA16 | RGB16 | RGBA16 => Self::UInt16, + RGB32F | RGBA32F => Self::Float32, + _ => Self::UInt8, } } } diff --git a/src/daft-schema/src/image_format.rs b/src/daft-schema/src/image_format.rs index f7f41a516f..93ec40963e 100644 --- a/src/daft-schema/src/image_format.rs +++ b/src/daft-schema/src/image_format.rs @@ -38,7 +38,7 @@ impl ImageFormat { } impl ImageFormat { - pub fn iterator() -> std::slice::Iter<'static, ImageFormat> { + pub fn iterator() -> std::slice::Iter<'static, Self> { use ImageFormat::*; static FORMATS: [ImageFormat; 5] = [PNG, JPEG, TIFF, GIF, BMP]; @@ -61,7 +61,7 @@ impl FromStr for ImageFormat { _ => Err(DaftError::TypeError(format!( "Image format {} is not supported; only the following formats are supported: {:?}", format, - ImageFormat::iterator().as_slice() + Self::iterator().as_slice() ))), } } diff --git a/src/daft-schema/src/image_mode.rs b/src/daft-schema/src/image_mode.rs index be2eebd8d3..9b41875ff0 100644 --- a/src/daft-schema/src/image_mode.rs +++ b/src/daft-schema/src/image_mode.rs @@ -75,12 +75,12 @@ impl ImageMode { "1" | "P" | "CMYK" | "YCbCr" | "LAB" | "HSV" | "I" | "F" | "PA" | "RGBX" | "RGBa" | "La" | "I;16" | "I;16L" | "I;16B" | "I;16N" | "BGR;15" | "BGR;16" | "BGR;24" => Err(DaftError::TypeError(format!( "PIL image mode {} is not supported; only the following modes are supported: {:?}", mode, - ImageMode::iterator().as_slice() + Self::iterator().as_slice() ))), _ => Err(DaftError::TypeError(format!( "Image mode {} is not a valid PIL image mode; see https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes for valid PIL image modes. 
Of these, only the following modes are supported by Daft: {:?}", mode, - ImageMode::iterator().as_slice() + Self::iterator().as_slice() ))), } } @@ -114,7 +114,7 @@ impl ImageMode { RGBA | RGBA16 | RGBA32F => 4, } } - pub fn iterator() -> std::slice::Iter<'static, ImageMode> { + pub fn iterator() -> std::slice::Iter<'static, Self> { use ImageMode::*; static MODES: [ImageMode; 10] = @@ -146,7 +146,7 @@ impl FromStr for ImageMode { _ => Err(DaftError::TypeError(format!( "Image mode {} is not supported; only the following modes are supported: {:?}", mode, - ImageMode::iterator().as_slice() + Self::iterator().as_slice() ))), } } diff --git a/src/daft-schema/src/python/datatype.rs b/src/daft-schema/src/python/datatype.rs index 9128eceec2..ceff5e18f3 100644 --- a/src/daft-schema/src/python/datatype.rs +++ b/src/daft-schema/src/python/datatype.rs @@ -14,7 +14,7 @@ pub struct PyTimeUnit { impl From for PyTimeUnit { fn from(value: TimeUnit) -> Self { - PyTimeUnit { timeunit: value } + Self { timeunit: value } } } @@ -217,7 +217,7 @@ impl PyDataType { } #[staticmethod] - pub fn r#struct(fields: IndexMap) -> Self { + pub fn r#struct(fields: IndexMap) -> Self { DataType::Struct( fields .into_iter() @@ -395,8 +395,8 @@ impl PyDataType { } pub fn is_equal(&self, other: Bound) -> PyResult { - if other.is_instance_of::() { - let other = other.extract::()?; + if other.is_instance_of::() { + let other = other.extract::()?; Ok(self.dtype == other.dtype) } else { Ok(false) @@ -423,7 +423,7 @@ impl_bincode_py_state_serialization!(PyDataType); impl From for PyDataType { fn from(value: DataType) -> Self { - PyDataType { dtype: value } + Self { dtype: value } } } diff --git a/src/daft-schema/src/python/field.rs b/src/daft-schema/src/python/field.rs index 2e39915843..8360d233dc 100644 --- a/src/daft-schema/src/python/field.rs +++ b/src/daft-schema/src/python/field.rs @@ -26,7 +26,7 @@ impl PyField { Ok(self.field.dtype.clone().into()) } - pub fn eq(&self, other: &PyField) -> PyResult { + pub fn eq(&self, other: &Self) -> PyResult { Ok(self.field.eq(&other.field)) } } @@ -35,7 +35,7 @@ impl_bincode_py_state_serialization!(PyField); impl From for PyField { fn from(field: Field) -> Self { - PyField { field } + Self { field } } } diff --git a/src/daft-schema/src/python/schema.rs b/src/daft-schema/src/python/schema.rs index 2aa39c9abc..3a13583ba8 100644 --- a/src/daft-schema/src/python/schema.rs +++ b/src/daft-schema/src/python/schema.rs @@ -46,12 +46,12 @@ impl PySchema { self.schema.names() } - pub fn union(&self, other: &PySchema) -> PyResult { + pub fn union(&self, other: &Self) -> PyResult { let new_schema = Arc::new(self.schema.union(&other.schema)?); Ok(new_schema.into()) } - pub fn eq(&self, other: &PySchema) -> PyResult { + pub fn eq(&self, other: &Self) -> PyResult { Ok(self.schema.fields.eq(&other.schema.fields)) } @@ -60,22 +60,20 @@ impl PySchema { } #[staticmethod] - pub fn from_field_name_and_types( - names_and_types: Vec<(String, PyDataType)>, - ) -> PyResult { + pub fn from_field_name_and_types(names_and_types: Vec<(String, PyDataType)>) -> PyResult { let fields = names_and_types .iter() .map(|(name, pydtype)| Field::new(name, pydtype.clone().into())) .collect(); let schema = schema::Schema::new(fields)?; - Ok(PySchema { + Ok(Self { schema: schema.into(), }) } #[staticmethod] - pub fn from_fields(fields: Vec) -> PyResult { - Ok(PySchema { + pub fn from_fields(fields: Vec) -> PyResult { + Ok(Self { schema: schema::Schema::new(fields.iter().map(|f| f.field.clone()).collect())?.into(), }) } @@ -96,7 +94,7 
@@ impl PySchema { Ok(self.schema.truncated_table_string()) } - pub fn apply_hints(&self, hints: &PySchema) -> PyResult { + pub fn apply_hints(&self, hints: &Self) -> PyResult { let new_schema = Arc::new(self.schema.apply_hints(&hints.schema)?); Ok(new_schema.into()) } @@ -106,7 +104,7 @@ impl_bincode_py_state_serialization!(PySchema); impl From for PySchema { fn from(schema: schema::SchemaRef) -> Self { - PySchema { schema } + Self { schema } } } diff --git a/src/daft-schema/src/schema.rs b/src/daft-schema/src/schema.rs index 7b0328be15..04c0d88c71 100644 --- a/src/daft-schema/src/schema.rs +++ b/src/daft-schema/src/schema.rs @@ -41,10 +41,10 @@ impl Schema { } } - Ok(Schema { fields: map }) + Ok(Self { fields: map }) } - pub fn exclude>(&self, names: &[S]) -> DaftResult { + pub fn exclude>(&self, names: &[S]) -> DaftResult { let mut fields = IndexMap::new(); let names = names.iter().map(|s| s.as_ref()).collect::>(); for (name, field) in self.fields.iter() { @@ -53,11 +53,11 @@ impl Schema { } } - Ok(Schema { fields }) + Ok(Self { fields }) } pub fn empty() -> Self { - Schema { + Self { fields: indexmap::IndexMap::new(), } } @@ -96,7 +96,7 @@ impl Schema { self.fields.is_empty() } - pub fn union(&self, other: &Schema) -> DaftResult { + pub fn union(&self, other: &Self) -> DaftResult { let self_keys: HashSet<&String> = HashSet::from_iter(self.fields.keys()); let other_keys: HashSet<&String> = HashSet::from_iter(self.fields.keys()); match self_keys.difference(&other_keys).count() { @@ -105,7 +105,7 @@ impl Schema { for (k, v) in self.fields.iter().chain(other.fields.iter()) { fields.insert(k.clone(), v.clone()); } - Ok(Schema { fields }) + Ok(Self { fields }) } _ => Err(DaftError::ValueError( "Cannot union two schemas with overlapping keys".to_string(), @@ -113,7 +113,7 @@ impl Schema { } } - pub fn apply_hints(&self, hints: &Schema) -> DaftResult { + pub fn apply_hints(&self, hints: &Self) -> DaftResult { let applied_fields = self .fields .iter() @@ -123,7 +123,7 @@ impl Schema { }) .collect::>(); - Ok(Schema { + Ok(Self { fields: applied_fields, }) } @@ -238,7 +238,7 @@ impl Schema { } /// Returns a new schema with only the specified columns in the new schema - pub fn project>(self: Arc, columns: &[S]) -> DaftResult { + pub fn project>(self: Arc, columns: &[S]) -> DaftResult { let new_fields = columns .iter() .map(|i| { diff --git a/src/daft-schema/src/time_unit.rs b/src/daft-schema/src/time_unit.rs index 50cdcb1e57..d4b17b0e7c 100644 --- a/src/daft-schema/src/time_unit.rs +++ b/src/daft-schema/src/time_unit.rs @@ -16,19 +16,19 @@ impl TimeUnit { #![allow(clippy::wrong_self_convention)] pub fn to_arrow(&self) -> ArrowTimeUnit { match self { - TimeUnit::Nanoseconds => ArrowTimeUnit::Nanosecond, - TimeUnit::Microseconds => ArrowTimeUnit::Microsecond, - TimeUnit::Milliseconds => ArrowTimeUnit::Millisecond, - TimeUnit::Seconds => ArrowTimeUnit::Second, + Self::Nanoseconds => ArrowTimeUnit::Nanosecond, + Self::Microseconds => ArrowTimeUnit::Microsecond, + Self::Milliseconds => ArrowTimeUnit::Millisecond, + Self::Seconds => ArrowTimeUnit::Second, } } pub fn to_scale_factor(&self) -> i64 { match self { - TimeUnit::Seconds => 1, - TimeUnit::Milliseconds => 1000, - TimeUnit::Microseconds => 1_000_000, - TimeUnit::Nanoseconds => 1_000_000_000, + Self::Seconds => 1, + Self::Milliseconds => 1000, + Self::Microseconds => 1_000_000, + Self::Nanoseconds => 1_000_000_000, } } } @@ -36,10 +36,10 @@ impl TimeUnit { impl From<&ArrowTimeUnit> for TimeUnit { fn from(tu: &ArrowTimeUnit) -> Self { match tu { - 
ArrowTimeUnit::Nanosecond => TimeUnit::Nanoseconds, - ArrowTimeUnit::Microsecond => TimeUnit::Microseconds, - ArrowTimeUnit::Millisecond => TimeUnit::Milliseconds, - ArrowTimeUnit::Second => TimeUnit::Seconds, + ArrowTimeUnit::Nanosecond => Self::Nanoseconds, + ArrowTimeUnit::Microsecond => Self::Microseconds, + ArrowTimeUnit::Millisecond => Self::Milliseconds, + ArrowTimeUnit::Second => Self::Seconds, } } } diff --git a/src/daft-sketch/Cargo.toml b/src/daft-sketch/Cargo.toml index a1cb65528a..5612624cf3 100644 --- a/src/daft-sketch/Cargo.toml +++ b/src/daft-sketch/Cargo.toml @@ -8,6 +8,9 @@ serde_arrow = {version = "0.11.0", features = ["arrow2-0-17"]} sketches-ddsketch = {workspace = true} snafu = {workspace = true} +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-sketch" diff --git a/src/daft-sketch/src/arrow2_serde.rs b/src/daft-sketch/src/arrow2_serde.rs index cb36653e36..4213d0180c 100644 --- a/src/daft-sketch/src/arrow2_serde.rs +++ b/src/daft-sketch/src/arrow2_serde.rs @@ -19,7 +19,7 @@ impl From for DaftError { use Error::*; match value { DeserializationError { source } => { - DaftError::ComputeError(format!("Deserialization error: {}", source)) + Self::ComputeError(format!("Deserialization error: {}", source)) } } } diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index 86f3baa11c..2b80dda42c 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -17,6 +17,9 @@ rstest = {workspace = true} [features] python = ["dep:pyo3", "common-error/python", "daft-functions/python", "daft-functions-json/python"] +[lints] +workspace = true + [package] name = "daft-sql" edition.workspace = true diff --git a/src/daft-sql/src/catalog.rs b/src/daft-sql/src/catalog.rs index 4da8ca6c8a..3495d32703 100644 --- a/src/daft-sql/src/catalog.rs +++ b/src/daft-sql/src/catalog.rs @@ -11,7 +11,7 @@ pub struct SQLCatalog { impl SQLCatalog { /// Create an empty catalog pub fn new() -> Self { - SQLCatalog { + Self { tables: HashMap::new(), } } @@ -27,7 +27,7 @@ impl SQLCatalog { } /// Copy from another catalog, using tables from other in case of conflict - pub fn copy_from(&mut self, other: &SQLCatalog) { + pub fn copy_from(&mut self, other: &Self) { for (name, plan) in other.tables.iter() { self.tables.insert(name.clone(), plan.clone()); } diff --git a/src/daft-sql/src/error.rs b/src/daft-sql/src/error.rs index d948c2bdb3..31f8a400ed 100644 --- a/src/daft-sql/src/error.rs +++ b/src/daft-sql/src/error.rs @@ -27,42 +27,42 @@ pub enum PlannerError { impl From for PlannerError { fn from(value: DaftError) -> Self { - PlannerError::DaftError { source: value } + Self::DaftError { source: value } } } impl From for PlannerError { fn from(value: TokenizerError) -> Self { - PlannerError::TokenizeError { source: value } + Self::TokenizeError { source: value } } } impl From for PlannerError { fn from(value: ParserError) -> Self { - PlannerError::SQLParserError { source: value } + Self::SQLParserError { source: value } } } impl PlannerError { pub fn column_not_found, B: Into>(column_name: A, relation: B) -> Self { - PlannerError::ColumnNotFound { + Self::ColumnNotFound { column_name: column_name.into(), relation: relation.into(), } } pub fn table_not_found>(table_name: S) -> Self { - PlannerError::TableNotFound { + Self::TableNotFound { message: table_name.into(), } } pub fn unsupported_sql(sql: String) -> Self { - PlannerError::UnsupportedSQL { message: sql } + Self::UnsupportedSQL { message: sql } } pub fn invalid_operation>(message: S) -> Self { - 
PlannerError::InvalidOperation { + Self::InvalidOperation { message: message.into(), } } @@ -112,7 +112,7 @@ impl From for DaftError { if let PlannerError::DaftError { source } = value { source } else { - DaftError::External(Box::new(value)) + Self::External(Box::new(value)) } } } diff --git a/src/daft-sql/src/modules/aggs.rs b/src/daft-sql/src/modules/aggs.rs index 74ee294fbc..695d3c9c79 100644 --- a/src/daft-sql/src/modules/aggs.rs +++ b/src/daft-sql/src/modules/aggs.rs @@ -34,7 +34,7 @@ impl SQLModule for SQLModuleAggs { impl SQLFunction for AggExpr { fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult { // COUNT(*) needs a bit of extra handling, so we process that outside of `to_expr` - if let AggExpr::Count(_, _) = self { + if let Self::Count(_, _) = self { handle_count(inputs, planner) } else { let inputs = self.args_to_expr_unnamed(inputs, planner)?; diff --git a/src/daft-sql/src/modules/partitioning.rs b/src/daft-sql/src/modules/partitioning.rs index 589c298e2f..e833edd51d 100644 --- a/src/daft-sql/src/modules/partitioning.rs +++ b/src/daft-sql/src/modules/partitioning.rs @@ -32,19 +32,11 @@ impl SQLFunction for PartitioningExpr { planner: &crate::planner::SQLPlanner, ) -> crate::error::SQLPlannerResult { match self { - PartitioningExpr::Years => { - partitioning_helper(args, planner, "years", partitioning::years) - } - PartitioningExpr::Months => { - partitioning_helper(args, planner, "months", partitioning::months) - } - PartitioningExpr::Days => { - partitioning_helper(args, planner, "days", partitioning::days) - } - PartitioningExpr::Hours => { - partitioning_helper(args, planner, "hours", partitioning::hours) - } - PartitioningExpr::IcebergBucket(_) => { + Self::Years => partitioning_helper(args, planner, "years", partitioning::years), + Self::Months => partitioning_helper(args, planner, "months", partitioning::months), + Self::Days => partitioning_helper(args, planner, "days", partitioning::days), + Self::Hours => partitioning_helper(args, planner, "hours", partitioning::hours), + Self::IcebergBucket(_) => { ensure!(args.len() == 2, "iceberg_bucket takes exactly 2 arguments"); let input = planner.plan_function_arg(&args[0])?; let n = planner @@ -68,7 +60,7 @@ impl SQLFunction for PartitioningExpr { Ok(partitioning::iceberg_bucket(input, n)) } - PartitioningExpr::IcebergTruncate(_) => { + Self::IcebergTruncate(_) => { ensure!( args.len() == 2, "iceberg_truncate takes exactly 2 arguments" diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index b58f637783..2aefab9f96 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -36,7 +36,7 @@ pub(crate) struct Relation { impl Relation { pub fn new(inner: LogicalPlanBuilder, name: String) -> Self { - Relation { inner, name } + Self { inner, name } } pub(crate) fn schema(&self) -> SchemaRef { self.inner.schema() @@ -50,7 +50,7 @@ pub struct SQLPlanner { impl Default for SQLPlanner { fn default() -> Self { - SQLPlanner { + Self { catalog: SQLCatalog::new(), current_relation: None, } @@ -59,7 +59,7 @@ impl Default for SQLPlanner { impl SQLPlanner { pub fn new(context: SQLCatalog) -> Self { - SQLPlanner { + Self { catalog: context, current_relation: None, } diff --git a/src/daft-sql/src/python.rs b/src/daft-sql/src/python.rs index 216aaba3d8..9201b7fccc 100644 --- a/src/daft-sql/src/python.rs +++ b/src/daft-sql/src/python.rs @@ -34,7 +34,7 @@ impl PyCatalog { /// Construct an empty PyCatalog. 
#[staticmethod] pub fn new() -> Self { - PyCatalog { + Self { catalog: SQLCatalog::new(), } } @@ -46,7 +46,7 @@ impl PyCatalog { } /// Copy from another catalog, using tables from other in case of conflict - pub fn copy_from(&mut self, other: &PyCatalog) { + pub fn copy_from(&mut self, other: &Self) { self.catalog.copy_from(&other.catalog); } diff --git a/src/daft-stats/Cargo.toml b/src/daft-stats/Cargo.toml index f3de4bbec1..7cea2fe37d 100644 --- a/src/daft-stats/Cargo.toml +++ b/src/daft-stats/Cargo.toml @@ -10,6 +10,9 @@ snafu = {workspace = true} [features] python = ["common-error/python", "daft-core/python", "daft-dsl/python", "daft-table/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-stats" diff --git a/src/daft-stats/src/column_stats/comparison.rs b/src/daft-stats/src/column_stats/comparison.rs index 1d3d923666..7e2021744c 100644 --- a/src/daft-stats/src/column_stats/comparison.rs +++ b/src/daft-stats/src/column_stats/comparison.rs @@ -6,20 +6,15 @@ use snafu::ResultExt; use super::ColumnRangeStatistics; use crate::DaftCoreComputeSnafu; -impl DaftCompare<&ColumnRangeStatistics> for ColumnRangeStatistics { - type Output = crate::Result; - fn equal(&self, rhs: &ColumnRangeStatistics) -> Self::Output { +impl DaftCompare<&Self> for ColumnRangeStatistics { + type Output = crate::Result; + fn equal(&self, rhs: &Self) -> Self::Output { // lower_bound: do they exactly overlap // upper_bound: is there any overlap match (self, rhs) { - (ColumnRangeStatistics::Missing, _) | (_, ColumnRangeStatistics::Missing) => { - Ok(ColumnRangeStatistics::Missing) - } - ( - ColumnRangeStatistics::Loaded(s_lower, s_upper), - ColumnRangeStatistics::Loaded(r_lower, r_upper), - ) => { + (Self::Missing, _) | (_, Self::Missing) => Ok(Self::Missing), + (Self::Loaded(s_lower, s_upper), Self::Loaded(r_lower, r_upper)) => { let exactly_overlap = (s_lower.equal(r_lower).context(DaftCoreComputeSnafu)?) .and(&s_upper.equal(r_upper).context(DaftCoreComputeSnafu)?) .context(DaftCoreComputeSnafu)? @@ -40,27 +35,22 @@ impl DaftCompare<&ColumnRangeStatistics> for ColumnRangeStatistics { .or(&rhs_lower_in_self_bounds) .context(DaftCoreComputeSnafu)? .into_series(); - Ok(ColumnRangeStatistics::Loaded(exactly_overlap, any_overlap)) + Ok(Self::Loaded(exactly_overlap, any_overlap)) } } } - fn not_equal(&self, rhs: &ColumnRangeStatistics) -> Self::Output { + fn not_equal(&self, rhs: &Self) -> Self::Output { // invert of equal self.equal(rhs)?.not() } - fn gt(&self, rhs: &ColumnRangeStatistics) -> Self::Output { + fn gt(&self, rhs: &Self) -> Self::Output { // lower_bound: True greater (self.lower > rhs.upper) // upper_bound: some value that can be greater (self.upper > rhs.lower) match (self, rhs) { - (ColumnRangeStatistics::Missing, _) | (_, ColumnRangeStatistics::Missing) => { - Ok(ColumnRangeStatistics::Missing) - } - ( - ColumnRangeStatistics::Loaded(s_lower, s_upper), - ColumnRangeStatistics::Loaded(r_lower, r_upper), - ) => { + (Self::Missing, _) | (_, Self::Missing) => Ok(Self::Missing), + (Self::Loaded(s_lower, s_upper), Self::Loaded(r_lower, r_upper)) => { let maybe_greater = s_upper .gt(r_lower) .context(DaftCoreComputeSnafu)? @@ -69,20 +59,15 @@ impl DaftCompare<&ColumnRangeStatistics> for ColumnRangeStatistics { .gt(r_upper) .context(DaftCoreComputeSnafu)? 
.into_series(); - Ok(ColumnRangeStatistics::Loaded(always_greater, maybe_greater)) + Ok(Self::Loaded(always_greater, maybe_greater)) } } } - fn gte(&self, rhs: &ColumnRangeStatistics) -> Self::Output { + fn gte(&self, rhs: &Self) -> Self::Output { match (self, rhs) { - (ColumnRangeStatistics::Missing, _) | (_, ColumnRangeStatistics::Missing) => { - Ok(ColumnRangeStatistics::Missing) - } - ( - ColumnRangeStatistics::Loaded(s_lower, s_upper), - ColumnRangeStatistics::Loaded(r_lower, r_upper), - ) => { + (Self::Missing, _) | (_, Self::Missing) => Ok(Self::Missing), + (Self::Loaded(s_lower, s_upper), Self::Loaded(r_lower, r_upper)) => { let maybe_gte = s_upper .gte(r_lower) .context(DaftCoreComputeSnafu)? @@ -91,23 +76,18 @@ impl DaftCompare<&ColumnRangeStatistics> for ColumnRangeStatistics { .gte(r_upper) .context(DaftCoreComputeSnafu)? .into_series(); - Ok(ColumnRangeStatistics::Loaded(always_gte, maybe_gte)) + Ok(Self::Loaded(always_gte, maybe_gte)) } } } - fn lt(&self, rhs: &ColumnRangeStatistics) -> Self::Output { + fn lt(&self, rhs: &Self) -> Self::Output { // lower_bound: True less than (self.upper < rhs.lower) // upper_bound: some value that can be less than (self.lower < rhs.upper) match (self, rhs) { - (ColumnRangeStatistics::Missing, _) | (_, ColumnRangeStatistics::Missing) => { - Ok(ColumnRangeStatistics::Missing) - } - ( - ColumnRangeStatistics::Loaded(s_lower, s_upper), - ColumnRangeStatistics::Loaded(r_lower, r_upper), - ) => { + (Self::Missing, _) | (_, Self::Missing) => Ok(Self::Missing), + (Self::Loaded(s_lower, s_upper), Self::Loaded(r_lower, r_upper)) => { let maybe_lt = s_lower .lt(r_upper) .context(DaftCoreComputeSnafu)? @@ -116,20 +96,15 @@ impl DaftCompare<&ColumnRangeStatistics> for ColumnRangeStatistics { .lt(r_lower) .context(DaftCoreComputeSnafu)? .into_series(); - Ok(ColumnRangeStatistics::Loaded(always_lt, maybe_lt)) + Ok(Self::Loaded(always_lt, maybe_lt)) } } } - fn lte(&self, rhs: &ColumnRangeStatistics) -> Self::Output { + fn lte(&self, rhs: &Self) -> Self::Output { match (self, rhs) { - (ColumnRangeStatistics::Missing, _) | (_, ColumnRangeStatistics::Missing) => { - Ok(ColumnRangeStatistics::Missing) - } - ( - ColumnRangeStatistics::Loaded(s_lower, s_upper), - ColumnRangeStatistics::Loaded(r_lower, r_upper), - ) => { + (Self::Missing, _) | (_, Self::Missing) => Ok(Self::Missing), + (Self::Loaded(s_lower, s_upper), Self::Loaded(r_lower, r_upper)) => { let maybe_lte = s_lower .lte(r_upper) .context(DaftCoreComputeSnafu)? @@ -138,7 +113,7 @@ impl DaftCompare<&ColumnRangeStatistics> for ColumnRangeStatistics { .lte(r_lower) .context(DaftCoreComputeSnafu)? 
.into_series(); - Ok(ColumnRangeStatistics::Loaded(always_lte, maybe_lte)) + Ok(Self::Loaded(always_lte, maybe_lte)) } } } @@ -147,13 +122,8 @@ impl DaftCompare<&ColumnRangeStatistics> for ColumnRangeStatistics { impl ColumnRangeStatistics { pub fn union(&self, rhs: &Self) -> crate::Result { match (self, rhs) { - (ColumnRangeStatistics::Missing, _) | (_, ColumnRangeStatistics::Missing) => { - Ok(ColumnRangeStatistics::Missing) - } - ( - ColumnRangeStatistics::Loaded(s_lower, s_upper), - ColumnRangeStatistics::Loaded(r_lower, r_upper), - ) => { + (Self::Missing, _) | (_, Self::Missing) => Ok(Self::Missing), + (Self::Loaded(s_lower, s_upper), Self::Loaded(r_lower, r_upper)) => { let new_min = s_lower.if_else( r_lower, &(s_lower.lt(r_lower)) @@ -167,7 +137,7 @@ impl ColumnRangeStatistics { .into_series(), ); - Ok(ColumnRangeStatistics::Loaded( + Ok(Self::Loaded( new_min.context(DaftCoreComputeSnafu)?, new_max.context(DaftCoreComputeSnafu)?, )) diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs index e8dc82f2f8..df96daa373 100644 --- a/src/daft-stats/src/column_stats/mod.rs +++ b/src/daft-stats/src/column_stats/mod.rs @@ -42,13 +42,13 @@ impl ColumnRangeStatistics { assert_eq!(l.data_type(), u.data_type(), ""); // If creating on incompatible types, default to `Missing` - if !ColumnRangeStatistics::supports_dtype(l.data_type()) { - return Ok(ColumnRangeStatistics::Missing); + if !Self::supports_dtype(l.data_type()) { + return Ok(Self::Missing); } - Ok(ColumnRangeStatistics::Loaded(l, u)) + Ok(Self::Loaded(l, u)) } - _ => Ok(ColumnRangeStatistics::Missing), + _ => Ok(Self::Missing), } } @@ -148,18 +148,16 @@ impl ColumnRangeStatistics { pub fn cast(&self, dtype: &DataType) -> crate::Result { match self { // `Missing` is casted to `Missing` - ColumnRangeStatistics::Missing => Ok(ColumnRangeStatistics::Missing), + Self::Missing => Ok(Self::Missing), // If the type to cast to matches the current type exactly, short-circuit the logic here. This should be the // most common case (e.g. parsing a Parquet file with the same types as the inferred types) - ColumnRangeStatistics::Loaded(l, r) if l.data_type() == dtype => { - Ok(ColumnRangeStatistics::Loaded(l.clone(), r.clone())) - } + Self::Loaded(l, r) if l.data_type() == dtype => Ok(Self::Loaded(l.clone(), r.clone())), // Only certain types are allowed to be casted in the context of ColumnRangeStatistics // as casting may not correctly preserve ordering of elements. We allow-list some type combinations // but for most combinations, we will default to `ColumnRangeStatistics::Missing`. 
- ColumnRangeStatistics::Loaded(l, r) => { + Self::Loaded(l, r) => { match (l.data_type(), dtype) { // Int casting to higher bitwidths (DataType::Int8, DataType::Int16) | @@ -187,11 +185,11 @@ impl ColumnRangeStatistics { (DataType::Int64, DataType::Timestamp(..)) | // Binary to Utf8 (DataType::Binary, DataType::Utf8) - => Ok(ColumnRangeStatistics::Loaded( + => Ok(Self::Loaded( l.cast(dtype).context(DaftCoreComputeSnafu)?, r.cast(dtype).context(DaftCoreComputeSnafu)?, )), - _ => Ok(ColumnRangeStatistics::Missing) + _ => Ok(Self::Missing) } } } @@ -240,7 +238,7 @@ pub enum Error { impl From for crate::Error { fn from(value: Error) -> Self { - crate::Error::MissingStatistics { source: value } + Self::MissingStatistics { source: value } } } diff --git a/src/daft-stats/src/lib.rs b/src/daft-stats/src/lib.rs index 3bf362782f..f73a05101a 100644 --- a/src/daft-stats/src/lib.rs +++ b/src/daft-stats/src/lib.rs @@ -40,7 +40,7 @@ impl From for DaftError { fn from(value: Error) -> Self { match value { Error::DaftCoreCompute { source } => source, - _ => DaftError::External(value.into()), + _ => Self::External(value.into()), } } } diff --git a/src/daft-stats/src/table_stats.rs b/src/daft-stats/src/table_stats.rs index 40b2e220c2..0fff747c98 100644 --- a/src/daft-stats/src/table_stats.rs +++ b/src/daft-stats/src/table_stats.rs @@ -31,7 +31,7 @@ impl TableStatistics { let stats = ColumnRangeStatistics::new(Some(col.slice(0, 1)?), Some(col.slice(1, 2)?))?; columns.insert(name, stats); } - Ok(TableStatistics { columns }) + Ok(Self { columns }) } pub fn from_table(table: &Table) -> Self { @@ -41,7 +41,7 @@ impl TableStatistics { let stats = ColumnRangeStatistics::from_series(col); columns.insert(name, stats); } - TableStatistics { columns } + Self { columns } } pub fn union(&self, other: &Self) -> crate::Result { @@ -61,7 +61,7 @@ impl TableStatistics { }?; columns.insert(col.clone(), res_col); } - Ok(TableStatistics { columns }) + Ok(Self { columns }) } pub fn eval_expression_list( @@ -151,7 +151,7 @@ impl TableStatistics { } } - pub fn cast_to_schema(&self, schema: SchemaRef) -> crate::Result { + pub fn cast_to_schema(&self, schema: SchemaRef) -> crate::Result { self.cast_to_schema_with_fill(schema, None) } @@ -159,7 +159,7 @@ impl TableStatistics { &self, schema: SchemaRef, fill_map: Option<&HashMap<&str, ExprRef>>, - ) -> crate::Result { + ) -> crate::Result { let mut columns = IndexMap::new(); for (field_name, field) in schema.fields.iter() { let crs = match self.columns.get(field_name) { @@ -175,7 +175,7 @@ impl TableStatistics { }; columns.insert(field_name.clone(), crs); } - Ok(TableStatistics { columns }) + Ok(Self { columns }) } } diff --git a/src/daft-table/Cargo.toml b/src/daft-table/Cargo.toml index 56931e04d5..5682fa41a2 100644 --- a/src/daft-table/Cargo.toml +++ b/src/daft-table/Cargo.toml @@ -17,6 +17,9 @@ serde = {workspace = true} [features] python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-dsl/python", "common-arrow-ffi/python", "common-display/python", "daft-image/python"] +[lints] +workspace = true + [package] edition = {workspace = true} name = "daft-table" diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 100ca31726..3669fda3f5 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -94,7 +94,7 @@ impl Table { }) .collect(); - Ok(Table::new_unchecked(schema, columns?, num_rows)) + Ok(Self::new_unchecked(schema, columns?, num_rows)) } /// Create a new [`Table`] and validate against `num_rows` @@ -121,7 +121,7 @@ impl Table { 
} } - Ok(Table::new_unchecked(schema, columns, num_rows)) + Ok(Self::new_unchecked(schema, columns, num_rows)) } /// Create a new [`Table`] without any validations @@ -135,7 +135,7 @@ impl Table { columns: Vec, num_rows: usize, ) -> Self { - Table { + Self { schema: schema.into(), columns, num_rows, @@ -149,7 +149,7 @@ impl Table { let series = Series::empty(field_name, &field.dtype); columns.push(series) } - Ok(Table::new_unchecked(schema, columns, 0)) + Ok(Self::new_unchecked(schema, columns, 0)) } /// Create a Table from a set of columns. @@ -179,7 +179,7 @@ impl Table { } } - Ok(Table::new_unchecked(schema, columns, num_rows)) + Ok(Self::new_unchecked(schema, columns, num_rows)) } pub fn num_columns(&self) -> usize { @@ -202,12 +202,12 @@ impl Table { let new_series: DaftResult> = self.columns.iter().map(|s| s.slice(start, end)).collect(); let new_num_rows = self.len().min(end - start); - Table::new_with_size(self.schema.clone(), new_series?, new_num_rows) + Self::new_with_size(self.schema.clone(), new_series?, new_num_rows) } pub fn head(&self, num: usize) -> DaftResult { if num >= self.len() { - return Ok(Table::new_unchecked( + return Ok(Self::new_unchecked( self.schema.clone(), self.columns.clone(), self.len(), @@ -346,15 +346,15 @@ impl Table { mask.len() - num_filtered }; - Table::new_with_size(self.schema.clone(), new_series?, num_rows) + Self::new_with_size(self.schema.clone(), new_series?, num_rows) } pub fn take(&self, idx: &Series) -> DaftResult { let new_series: DaftResult> = self.columns.iter().map(|s| s.take(idx)).collect(); - Table::new_with_size(self.schema.clone(), new_series?, idx.len()) + Self::new_with_size(self.schema.clone(), new_series?, idx.len()) } - pub fn concat>(tables: &[T]) -> DaftResult { + pub fn concat>(tables: &[T]) -> DaftResult { if tables.is_empty() { return Err(DaftError::ValueError( "Need at least 1 Table to perform concat".to_string(), @@ -384,14 +384,14 @@ impl Table { new_series.push(Series::concat(series_to_cat.as_slice())?); } - Table::new_with_size( + Self::new_with_size( first_table.schema.clone(), new_series, tables.iter().map(|t| t.as_ref().len()).sum(), ) } - pub fn union(&self, other: &Table) -> DaftResult { + pub fn union(&self, other: &Self) -> DaftResult { if self.num_rows != other.num_rows { return Err(DaftError::ValueError(format!( "Cannot union tables of length {} and {}", @@ -625,7 +625,7 @@ impl Table { (true, _) => result_series.iter().map(|s| s.len()).max().unwrap(), }; - Table::new_with_broadcast(new_schema, result_series, num_rows) + Self::new_with_broadcast(new_schema, result_series, num_rows) } pub fn as_physical(&self) -> DaftResult { @@ -635,7 +635,7 @@ impl Table { .map(|s| s.as_physical()) .collect::>>()?; let new_schema = Schema::new(new_series.iter().map(|s| s.field().clone()).collect())?; - Table::new_with_size(new_schema, new_series, self.len()) + Self::new_with_size(new_schema, new_series, self.len()) } pub fn cast_to_schema(&self, schema: &Schema) -> DaftResult { @@ -781,8 +781,8 @@ impl Display for Table { } } -impl AsRef for Table { - fn as_ref(&self) -> &Table { +impl AsRef for Table { + fn as_ref(&self) -> &Self { self } } diff --git a/src/daft-table/src/ops/agg.rs b/src/daft-table/src/ops/agg.rs index 33bbf635b6..70abdf69f4 100644 --- a/src/daft-table/src/ops/agg.rs +++ b/src/daft-table/src/ops/agg.rs @@ -5,7 +5,7 @@ use daft_dsl::{functions::FunctionExpr, AggExpr, Expr, ExprRef}; use crate::Table; impl Table { - pub fn agg(&self, to_agg: &[ExprRef], group_by: &[ExprRef]) -> DaftResult
<Table> {
+    pub fn agg(&self, to_agg: &[ExprRef], group_by: &[ExprRef]) -> DaftResult<Self> {
         // Dispatch depending on whether we're doing groupby or just a global agg.
         match group_by.len() {
             0 => self.agg_global(to_agg),
@@ -13,11 +13,11 @@ impl Table {
-    pub fn agg_global(&self, to_agg: &[ExprRef]) -> DaftResult<Table> {
+    pub fn agg_global(&self, to_agg: &[ExprRef]) -> DaftResult<Self> {
         self.eval_expression_list(to_agg)
     }
 
-    pub fn agg_groupby(&self, to_agg: &[ExprRef], group_by: &[ExprRef]) -> DaftResult<Table> {
+    pub fn agg_groupby(&self, to_agg: &[ExprRef], group_by: &[ExprRef]) -> DaftResult<Self> {
         let agg_exprs = to_agg
             .iter()
             .map(|e| match e.as_ref() {
@@ -68,7 +68,7 @@ impl Table {
         func: &FunctionExpr,
         inputs: &[ExprRef],
         group_by: &[ExprRef],
-    ) -> DaftResult<Table>
{ + ) -> DaftResult { use daft_core::array::ops::IntoGroups; use daft_dsl::functions::python::PythonUDF; @@ -100,7 +100,7 @@ impl Table { // Take fast path short circuit if there is only 1 group let (groupkeys_table, grouped_col) = if groupvals_indices.is_empty() { - let empty_groupkeys_table = Table::empty(Some(groupby_table.schema.clone()))?; + let empty_groupkeys_table = Self::empty(Some(groupby_table.schema.clone()))?; let empty_udf_output_col = Series::empty( evaluated_inputs .first() @@ -151,7 +151,7 @@ impl Table { .collect::>>()?; // Combine the broadcasted group keys into a Table - Table::from_nonempty_columns(broadcasted_groupkeys)? + Self::from_nonempty_columns(broadcasted_groupkeys)? }; Ok((broadcasted_groupkeys_table, evaluated_grouped_col)) @@ -162,7 +162,7 @@ impl Table { let concatenated_grouped_col = Series::concat(series_refs.as_slice())?; let table_refs = grouped_results.iter().map(|(t, _)| t).collect::>(); - let concatenated_groupkeys_table = Table::concat(table_refs.as_slice())?; + let concatenated_groupkeys_table = Self::concat(table_refs.as_slice())?; (concatenated_groupkeys_table, concatenated_grouped_col) }; diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs index 2d5e80dae5..5bb4f77d4b 100644 --- a/src/daft-table/src/ops/joins/mod.rs +++ b/src/daft-table/src/ops/joins/mod.rs @@ -178,6 +178,6 @@ impl Table { let num_rows = lidx.len(); join_series = add_non_join_key_columns(self, right, lidx, ridx, join_series)?; - Table::new_with_size(join_schema, join_series, num_rows) + Self::new_with_size(join_schema, join_series, num_rows) } } diff --git a/src/daft-table/src/ops/pivot.rs b/src/daft-table/src/ops/pivot.rs index a0b07d20cd..4418d4365e 100644 --- a/src/daft-table/src/ops/pivot.rs +++ b/src/daft-table/src/ops/pivot.rs @@ -73,7 +73,7 @@ impl Table { pivot_col: ExprRef, values_col: ExprRef, names: Vec, - ) -> DaftResult
<Table> {
+    ) -> DaftResult<Self> {
         // This function pivots the table based on the given group_by, pivot, and values column.
         //
         // At a high level this function does two things:
diff --git a/src/daft-table/src/ops/sort.rs b/src/daft-table/src/ops/sort.rs
index de082b1970..20e56c3b0b 100644
--- a/src/daft-table/src/ops/sort.rs
+++ b/src/daft-table/src/ops/sort.rs
@@ -5,7 +5,7 @@ use daft_dsl::ExprRef;
 use crate::Table;
 
 impl Table {
-    pub fn sort(&self, sort_keys: &[ExprRef], descending: &[bool]) -> DaftResult<Table>
{ + pub fn sort(&self, sort_keys: &[ExprRef], descending: &[bool]) -> DaftResult { let argsort = self.argsort(sort_keys, descending)?; self.take(&argsort) } diff --git a/src/daft-table/src/ops/unpivot.rs b/src/daft-table/src/ops/unpivot.rs index 37fba415fc..5b43777417 100644 --- a/src/daft-table/src/ops/unpivot.rs +++ b/src/daft-table/src/ops/unpivot.rs @@ -54,6 +54,6 @@ impl Table { ])?)?; let unpivot_series = [ids_series, vec![variable_series, value_series]].concat(); - Table::new_with_size(unpivot_schema, unpivot_series, unpivoted_len) + Self::new_with_size(unpivot_schema, unpivot_series, unpivoted_len) } } diff --git a/src/daft-table/src/python.rs b/src/daft-table/src/python.rs index 728383d23c..3bacbcf019 100644 --- a/src/daft-table/src/python.rs +++ b/src/daft-table/src/python.rs @@ -261,7 +261,7 @@ impl PyTable { .partition_by_hash(exprs.as_slice(), num_partitions as usize)? .into_iter() .map(|t| t.into()) - .collect::>()) + .collect::>()) }) } @@ -288,7 +288,7 @@ impl PyTable { .partition_by_random(num_partitions as usize, seed as u64)? .into_iter() .map(|t| t.into()) - .collect::>()) + .collect::>()) }) } @@ -306,7 +306,7 @@ impl PyTable { .partition_by_range(exprs.as_slice(), &boundaries.table, descending.as_slice())? .into_iter() .map(|t| t.into()) - .collect::>()) + .collect::>()) }) } @@ -318,10 +318,7 @@ impl PyTable { let exprs: Vec = partition_keys.into_iter().map(|e| e.into()).collect(); py.allow_threads(|| { let (tables, values) = self.table.partition_by_value(exprs.as_slice())?; - let pytables = tables - .into_iter() - .map(|t| t.into()) - .collect::>(); + let pytables = tables.into_iter().map(|t| t.into()).collect::>(); let values = values.into(); Ok((pytables, values)) }) @@ -407,7 +404,7 @@ impl PyTable { ) -> PyResult { let table = ffi::record_batches_to_table(py, record_batches.as_slice(), schema.schema.clone())?; - Ok(PyTable { table }) + Ok(Self { table }) } #[staticmethod] @@ -438,7 +435,7 @@ impl PyTable { } } - Ok(PyTable { + Ok(Self { table: Table::new_with_broadcast(Schema::new(fields)?, columns, num_rows)?, }) } @@ -462,7 +459,7 @@ impl PyTable { impl From
for PyTable { fn from(value: Table) -> Self { - PyTable { table: value } + Self { table: value } } } diff --git a/src/hyperloglog/Cargo.toml b/src/hyperloglog/Cargo.toml index fa430673f0..fa299abdea 100644 --- a/src/hyperloglog/Cargo.toml +++ b/src/hyperloglog/Cargo.toml @@ -1,5 +1,8 @@ [dependencies] +[lints] +workspace = true + [package] name = "hyperloglog" edition.workspace = true From a8602a2d0b0f5f5f35472218a49bd4dde8ceedda Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Wed, 25 Sep 2024 12:56:02 -0700 Subject: [PATCH 30/35] [CHORE] Classify throttle and internal errors as Retryable in Python (#2914) --- daft/exceptions.py | 16 +++++++ src/common/error/src/error.rs | 4 ++ src/common/error/src/python.rs | 4 ++ src/daft-io/src/lib.rs | 10 ++++ src/daft-io/src/s3_like.rs | 84 ++++++++++++++++++++++------------ 5 files changed, 89 insertions(+), 29 deletions(-) diff --git a/daft/exceptions.py b/daft/exceptions.py index 6909d9300f..121b06938b 100644 --- a/daft/exceptions.py +++ b/daft/exceptions.py @@ -52,3 +52,19 @@ class SocketError(DaftTransientError): """ pass + + +class ThrottleError(DaftTransientError): + """Daft Throttle Error + Daft client had a throttle error while making request to server. + """ + + pass + + +class MiscTransientError(DaftTransientError): + """Daft Misc Transient Error + Daft client had a Misc Transient Error while making request to server. + """ + + pass diff --git a/src/common/error/src/error.rs b/src/common/error/src/error.rs index c3ea90f5f2..0513d3e112 100644 --- a/src/common/error/src/error.rs +++ b/src/common/error/src/error.rs @@ -34,6 +34,10 @@ pub enum DaftError { ByteStreamError(GenericError), #[error("SocketError {0}")] SocketError(GenericError), + #[error("ThrottledIo {0}")] + ThrottledIo(GenericError), + #[error("MiscTransient {0}")] + MiscTransient(GenericError), #[error("DaftError::External {0}")] External(GenericError), #[error("DaftError::SerdeJsonError {0}")] diff --git a/src/common/error/src/python.rs b/src/common/error/src/python.rs index 917dafdc78..08ad13d8d6 100644 --- a/src/common/error/src/python.rs +++ b/src/common/error/src/python.rs @@ -8,6 +8,8 @@ import_exception!(daft.exceptions, ConnectTimeoutError); import_exception!(daft.exceptions, ReadTimeoutError); import_exception!(daft.exceptions, ByteStreamError); import_exception!(daft.exceptions, SocketError); +import_exception!(daft.exceptions, ThrottleError); +import_exception!(daft.exceptions, MiscTransientError); impl std::convert::From for pyo3::PyErr { fn from(err: DaftError) -> Self { @@ -21,6 +23,8 @@ impl std::convert::From for pyo3::PyErr { DaftError::ReadTimeout(err) => ReadTimeoutError::new_err(err.to_string()), DaftError::ByteStreamError(err) => ByteStreamError::new_err(err.to_string()), DaftError::SocketError(err) => SocketError::new_err(err.to_string()), + DaftError::ThrottledIo(err) => ThrottleError::new_err(err.to_string()), + DaftError::MiscTransient(err) => MiscTransientError::new_err(err.to_string()), _ => DaftCoreException::new_err(err.to_string()), } } diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 6fdaac2368..8d87f5b767 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -97,6 +97,12 @@ pub enum Error { ))] SocketError { path: String, source: DynError }, + #[snafu(display("Throttled when trying to read {}\nDetails:\n{:?}", path, source))] + Throttled { path: String, source: DynError }, + + #[snafu(display("Misc Transient error trying to read {}\nDetails:\n{:?}", path, source))] + MiscTransient { path: String, source: DynError }, 
+ #[snafu(display("Unable to convert URL \"{}\" to path", path))] InvalidUrl { path: String, @@ -150,6 +156,8 @@ impl From for DaftError { ReadTimeout { .. } => Self::ReadTimeout(err.into()), UnableToReadBytes { .. } => Self::ByteStreamError(err.into()), SocketError { .. } => Self::SocketError(err.into()), + Throttled { .. } => Self::ThrottledIo(err.into()), + MiscTransient { .. } => Self::MiscTransient(err.into()), // We have to repeat everything above for the case we have an Arc since we can't move the error. CachedError { ref source } => match source.as_ref() { NotFound { path, source: _ } => Self::FileNotFound { @@ -160,6 +168,8 @@ impl From for DaftError { ReadTimeout { .. } => Self::ReadTimeout(err.into()), UnableToReadBytes { .. } => Self::ByteStreamError(err.into()), SocketError { .. } => Self::SocketError(err.into()), + Throttled { .. } => Self::ThrottledIo(err.into()), + MiscTransient { .. } => Self::MiscTransient(err.into()), _ => Self::External(err.into()), }, _ => Self::External(err.into()), diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 2766011ae7..e6eb829a78 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -10,8 +10,10 @@ use aws_credential_types::{ cache::{CredentialsCache, ProvideCachedCredentials, SharedCredentialsCache}, provider::error::CredentialsError, }; -use aws_sdk_s3 as s3; -use aws_sdk_s3::{operation::put_object::PutObjectError, primitives::ByteStreamError}; +use aws_sdk_s3::{ + self as s3, error::ProvideErrorMetadata, operation::put_object::PutObjectError, + primitives::ByteStreamError, +}; use aws_sig_auth::signer::SigningRequirements; use aws_smithy_async::rt::sleep::TokioSleep; use common_io_config::S3Config; @@ -118,10 +120,51 @@ enum Error { UploadsCannotBeAnonymous {}, } +/// List of AWS error codes that are due to throttling +/// https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html#ErrorCodeList +const THROTTLING_ERRORS: &[&str] = &[ + "Throttling", + "ThrottlingException", + "ThrottledException", + "RequestThrottledException", + "TooManyRequestsException", + "ProvisionedThroughputExceededException", + "TransactionInProgressException", + "RequestLimitExceeded", + "BandwidthLimitExceeded", + "LimitExceededException", + "RequestThrottled", + "SlowDown", + "PriorRequestNotComplete", + "EC2ThrottledException", +]; + impl From for super::Error { fn from(error: Error) -> Self { use Error::*; + fn classify_unhandled_error< + E: std::error::Error + ProvideErrorMetadata + Send + Sync + 'static, + >( + path: String, + err: E, + ) -> super::Error { + match err.code() { + Some("InternalError") => super::Error::MiscTransient { + path, + source: err.into(), + }, + Some(code) if THROTTLING_ERRORS.contains(&code) => super::Error::Throttled { + path, + source: err.into(), + }, + _ => super::Error::Unhandled { + path, + msg: DisplayErrorContext(err).to_string(), + }, + } + } + match error { UnableToOpenFile { path, source } => match source { SdkError::TimeoutError(_) => Self::ReadTimeout { @@ -140,25 +183,20 @@ impl From for super::Error { source: source.into(), } } else { - Self::UnableToOpenFile { + // who knows what happened here during dispatch, let's just tell the user it's transient + Self::MiscTransient { path, source: source.into(), } } } + _ => match source.into_service_error() { GetObjectError::NoSuchKey(no_such_key) => Self::NotFound { path, source: no_such_key.into(), }, - GetObjectError::Unhandled(v) => Self::Unhandled { - path, - msg: DisplayErrorContext(v).to_string(), - }, - err => 
Self::UnableToOpenFile { - path, - source: err.into(), - }, + err => classify_unhandled_error(path, err), }, }, UnableToHeadFile { path, source } => match source { @@ -178,7 +216,8 @@ impl From for super::Error { source: source.into(), } } else { - Self::UnableToOpenFile { + // who knows what happened here during dispatch, let's just tell the user it's transient + Self::MiscTransient { path, source: source.into(), } @@ -189,14 +228,7 @@ impl From for super::Error { path, source: no_such_key.into(), }, - HeadObjectError::Unhandled(v) => Self::Unhandled { - path, - msg: DisplayErrorContext(v).to_string(), - }, - err => Self::UnableToOpenFile { - path, - source: err.into(), - }, + err => classify_unhandled_error(path, err), }, }, UnableToListObjects { path, source } => match source { @@ -216,7 +248,8 @@ impl From for super::Error { source: source.into(), } } else { - Self::UnableToOpenFile { + // who knows what happened here during dispatch, let's just tell the user it's transient + Self::MiscTransient { path, source: source.into(), } @@ -227,14 +260,7 @@ impl From for super::Error { path, source: no_such_key.into(), }, - ListObjectsV2Error::Unhandled(v) => Self::Unhandled { - path, - msg: DisplayErrorContext(v).to_string(), - }, - err => Self::UnableToOpenFile { - path, - source: err.into(), - }, + err => classify_unhandled_error(path, err), }, }, InvalidUrl { path, source } => Self::InvalidUrl { path, source }, From a9fdd194bd5a8b2b4ae5a6e829a653661e48259a Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 25 Sep 2024 13:10:41 -0700 Subject: [PATCH 31/35] [BUG] Use dashes for machete dependency ignores (#2919) Co-authored-by: Colin Ho --- Cargo.toml | 2 +- src/arrow2/Cargo.toml | 2 +- src/daft-io/Cargo.toml | 2 +- src/daft-schema/Cargo.toml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6af21986e4..34c6a28c80 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,7 +72,7 @@ publish = false version = "0.3.0-dev0" [package.metadata.cargo-machete] -ignored = ["lzma_sys"] +ignored = ["lzma-sys"] [patch.crates-io] arrow2 = {path = "src/arrow2"} diff --git a/src/arrow2/Cargo.toml b/src/arrow2/Cargo.toml index 0664947831..b3862e34f0 100644 --- a/src/arrow2/Cargo.toml +++ b/src/arrow2/Cargo.toml @@ -251,7 +251,7 @@ version = "0.17.4" allowlist = ["compute", "compute_sort", "compute_hash", "compute_nullif"] [package.metadata.cargo-machete] -ignored = ["arrow_array", "arrow_buffer", "avro_rs", "criterion", "crossbeam_channel", "flate2", "getrandom", "rustc_version", "sample_arrow2", "sample_std", "sample_test", "tokio", "tokio_util"] +ignored = ["arrow-array", "arrow-buffer", "avro-rs", "criterion", "crossbeam-channel", "flate2", "getrandom", "rustc_version", "sample-arrow2", "sample-std", "sample-test", "tokio", "tokio-util"] [target.wasm32-unknown-unknown.dependencies] getrandom = {version = "0.2", features = ["js"]} diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index 433090370e..486a6810ae 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -63,4 +63,4 @@ name = "daft-io" version = {workspace = true} [package.metadata.cargo-machete] -ignored = ["openssl_sys"] +ignored = ["openssl-sys"] diff --git a/src/daft-schema/Cargo.toml b/src/daft-schema/Cargo.toml index ed6ecde2b7..fa189bff24 100644 --- a/src/daft-schema/Cargo.toml +++ b/src/daft-schema/Cargo.toml @@ -32,4 +32,4 @@ name = "daft-schema" version = {workspace = true} [package.metadata.cargo-machete] -ignored = ["num_traits"] # needed by num-derive +ignored = 
["num-traits"] # needed by num-derive From b1ea3b9749e01512f48dfd45f9899a329fc9799f Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 25 Sep 2024 13:32:19 -0700 Subject: [PATCH 32/35] [CHORE] Enable sources to return empty tables (#2915) Sources in the native executor should still return empty tables if they didn't read anything. This unlocks a bunch of tests. --------- Co-authored-by: Colin Ho Co-authored-by: Cory Grinstead Co-authored-by: Colin Ho --- src/daft-local-execution/src/pipeline.rs | 4 +++- src/daft-local-execution/src/sources/in_memory.rs | 13 +++++++++---- src/daft-local-execution/src/sources/scan_task.rs | 7 +++++++ tests/cookbook/test_pandas_cookbook.py | 5 ----- tests/dataframe/test_repr.py | 6 ------ tests/dataframe/test_sort.py | 5 ----- 6 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 9f7da9b915..17146f216d 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -110,7 +110,9 @@ pub fn physical_plan_to_pipeline( } LocalPhysicalPlan::InMemoryScan(InMemoryScan { info, .. }) => { let partitions = psets.get(&info.cache_key).expect("Cache key not found"); - InMemorySource::new(partitions.clone()).boxed().into() + InMemorySource::new(partitions.clone(), info.source_schema.clone()) + .boxed() + .into() } LocalPhysicalPlan::Project(Project { input, projection, .. diff --git a/src/daft-local-execution/src/sources/in_memory.rs b/src/daft-local-execution/src/sources/in_memory.rs index 1212dd13cb..1bf08a8913 100644 --- a/src/daft-local-execution/src/sources/in_memory.rs +++ b/src/daft-local-execution/src/sources/in_memory.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use daft_core::prelude::SchemaRef; use daft_io::IOStatsRef; use daft_micropartition::MicroPartition; use tracing::instrument; @@ -9,11 +10,12 @@ use crate::{sources::source::SourceStream, ExecutionRuntimeHandle}; pub struct InMemorySource { data: Vec>, + schema: SchemaRef, } impl InMemorySource { - pub fn new(data: Vec>) -> Self { - Self { data } + pub fn new(data: Vec>, schema: SchemaRef) -> Self { + Self { data, schema } } pub fn boxed(self) -> Box { Box::new(self) as Box @@ -28,8 +30,11 @@ impl Source for InMemorySource { _runtime_handle: &mut ExecutionRuntimeHandle, _io_stats: IOStatsRef, ) -> crate::Result> { - let data = self.data.clone(); - Ok(Box::pin(futures::stream::iter(data))) + if self.data.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.schema.clone()))); + return Ok(Box::pin(futures::stream::once(async { empty }))); + } + Ok(Box::pin(futures::stream::iter(self.data.clone()))) } fn name(&self) -> &'static str { "InMemory" diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index f7374fa80a..7d36ba6a22 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -38,9 +38,16 @@ impl ScanTaskSource { maintain_order: bool, io_stats: IOStatsRef, ) -> DaftResult<()> { + let schema = scan_task.materialized_schema(); let mut stream = stream_scan_task(scan_task, Some(io_stats), maintain_order).await?; + let mut has_data = false; while let Some(partition) = stream.next().await { let _ = sender.send(partition?).await; + has_data = true; + } + if !has_data { + let empty = Arc::new(MicroPartition::empty(Some(schema.clone()))); + let _ = sender.send(empty).await; } Ok(()) } diff --git a/tests/cookbook/test_pandas_cookbook.py 
b/tests/cookbook/test_pandas_cookbook.py index 56838f7490..c852568f3f 100644 --- a/tests/cookbook/test_pandas_cookbook.py +++ b/tests/cookbook/test_pandas_cookbook.py @@ -7,15 +7,10 @@ import pytest import daft -from daft import context from daft.datatype import DataType from daft.expressions import col, lit from tests.conftest import assert_df_equals -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) ### # Idioms: if-then ### diff --git a/tests/dataframe/test_repr.py b/tests/dataframe/test_repr.py index d72082ca3b..8e04421901 100644 --- a/tests/dataframe/test_repr.py +++ b/tests/dataframe/test_repr.py @@ -8,14 +8,8 @@ from PIL import Image import daft -from daft import context from tests.utils import ANSI_ESCAPE, TD_STYLE, TH_STYLE -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - ROW_DIVIDER_REGEX = re.compile(r"╭─+┬*─*╮|├╌+┼*╌+┤") SHOWING_N_ROWS_REGEX = re.compile(r".*\(Showing first (\d+) of (\d+) rows\).*") UNMATERIALIZED_REGEX = re.compile(r".*\(No data to display: Dataframe not materialized\).*") diff --git a/tests/dataframe/test_sort.py b/tests/dataframe/test_sort.py index e972c13831..8a831a2bcf 100644 --- a/tests/dataframe/test_sort.py +++ b/tests/dataframe/test_sort.py @@ -5,14 +5,9 @@ import pyarrow as pa import pytest -from daft import context from daft.datatype import DataType from daft.errors import ExpressionTypeError -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) ### # Validation tests ### From 6460a0e464e6336a7a2b03fbbea92545ded8478b Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Wed, 25 Sep 2024 14:16:42 -0700 Subject: [PATCH 33/35] [CHORE] Fix issues from nightly tests (#2926) notebook checker test failure: https://github.com/Eventual-Inc/Daft/actions/runs/11040289879 property-based test failure: https://github.com/Eventual-Inc/Daft/actions/runs/11040292121 --- src/daft-dsl/src/resolve_expr.rs | 2 +- .../embeddings/daft_tutorial_embeddings_stackexchange.ipynb | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/daft-dsl/src/resolve_expr.rs b/src/daft-dsl/src/resolve_expr.rs index 7c8b1820c7..df686f60a0 100644 --- a/src/daft-dsl/src/resolve_expr.rs +++ b/src/daft-dsl/src/resolve_expr.rs @@ -388,7 +388,7 @@ pub fn resolve_aggexprs( pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> { let struct_expr_map = calculate_struct_expr_map(schema); - let names = if name.contains('*') { + let names = if name == "*" || name.ends_with(".*") { if let Ok(names) = get_wildcard_matches(name, schema, &struct_expr_map) { names } else { diff --git a/tutorials/embeddings/daft_tutorial_embeddings_stackexchange.ipynb b/tutorials/embeddings/daft_tutorial_embeddings_stackexchange.ipynb index 8ca22809c9..012cbdab4e 100644 --- a/tutorials/embeddings/daft_tutorial_embeddings_stackexchange.ipynb +++ b/tutorials/embeddings/daft_tutorial_embeddings_stackexchange.ipynb @@ -309,12 +309,12 @@ " similarity_search(\n", " df[\"embedding\"],\n", " top_embeddings=torch.stack(top_questions[\"embedding\"]),\n", - " top_urls=top_questions[\"URL\"],\n", + " top_urls=top_questions[\"url\"],\n", " ),\n", ")\n", "\n", "df = df.select(\n", - " df[\"URL\"],\n", + " df[\"url\"],\n", " df[\"question_score\"],\n", " 
df[\"search_result.related_top_question\"].alias(\"related_top_question\"),\n", " df[\"search_result.similarity\"].alias(\"similarity\"),\n", From 3f37f872bcf583134e202f373e68d1e894604d07 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Wed, 25 Sep 2024 14:59:25 -0700 Subject: [PATCH 34/35] [CHORE] update GH bug template (#2932) --- .github/ISSUE_TEMPLATE/bug_report.md | 81 +++++++++++++++------------- 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index dd84ea7824..d7b84ef2b0 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -1,38 +1,45 @@ ---- name: Bug report -about: Create a report to help us improve -title: '' -labels: '' -assignees: '' - ---- - -**Describe the bug** -A clear and concise description of what the bug is. - -**To Reproduce** -Steps to reproduce the behavior: -1. Go to '...' -2. Click on '....' -3. Scroll down to '....' -4. See error - -**Expected behavior** -A clear and concise description of what you expected to happen. - -**Screenshots** -If applicable, add screenshots to help explain your problem. - -**Desktop (please complete the following information):** - - OS: [e.g. iOS] - - Browser [e.g. chrome, safari] - - Version [e.g. 22] - -**Smartphone (please complete the following information):** - - Device: [e.g. iPhone6] - - OS: [e.g. iOS8.1] - - Browser [e.g. stock browser, safari] - - Version [e.g. 22] - -**Additional context** -Add any other context about the problem here. +description: Create a report to help us improve Daft +labels: [bug, needs triage] +body: + - type: textarea + attributes: + label: Describe the bug + description: Describe the bug. + placeholder: > + A clear and concise description of what the bug is. + validations: + required: true + - type: textarea + attributes: + label: To Reproduce + placeholder: > + Steps to reproduce the behavior: + - type: textarea + attributes: + label: Expected behavior + placeholder: > + A clear and concise description of what you expected to happen. + - type: dropdown + id: component + attributes: + label: Component(s) + multiple: true + options: + - Expressions + - SQL + - Python Runner + - Ray Runner + - Parquet + - CSV + - Continuous Integration + - Developer Tools + - Documentation + - Other + validations: + required: true + - type: textarea + attributes: + label: Additional context + placeholder: > + Add any other context about the problem here. From 412dfbcdd332d26d27c622ddf67a30458bacff33 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Wed, 25 Sep 2024 15:03:50 -0700 Subject: [PATCH 35/35] [CHORE] update GH template name from md to yml (#2934) --- .github/ISSUE_TEMPLATE/bug_report.md | 45 --------------------------- .github/ISSUE_TEMPLATE/bug_report.yml | 45 +++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 45 deletions(-) delete mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/bug_report.yml diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index d7b84ef2b0..0000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,45 +0,0 @@ -name: Bug report -description: Create a report to help us improve Daft -labels: [bug, needs triage] -body: - - type: textarea - attributes: - label: Describe the bug - description: Describe the bug. - placeholder: > - A clear and concise description of what the bug is. 
- validations: - required: true - - type: textarea - attributes: - label: To Reproduce - placeholder: > - Steps to reproduce the behavior: - - type: textarea - attributes: - label: Expected behavior - placeholder: > - A clear and concise description of what you expected to happen. - - type: dropdown - id: component - attributes: - label: Component(s) - multiple: true - options: - - Expressions - - SQL - - Python Runner - - Ray Runner - - Parquet - - CSV - - Continuous Integration - - Developer Tools - - Documentation - - Other - validations: - required: true - - type: textarea - attributes: - label: Additional context - placeholder: > - Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000000..b54b425d77 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,45 @@ +name: Bug report +description: Create a report to help us improve Daft +labels: [bug, needs triage] +body: +- type: textarea + attributes: + label: Describe the bug + description: Describe the bug. + placeholder: > + A clear and concise description of what the bug is. + validations: + required: true +- type: textarea + attributes: + label: To Reproduce + placeholder: > + Steps to reproduce the behavior: +- type: textarea + attributes: + label: Expected behavior + placeholder: > + A clear and concise description of what you expected to happen. +- type: dropdown + id: component + attributes: + label: Component(s) + multiple: true + options: + - Expressions + - SQL + - Python Runner + - Ray Runner + - Parquet + - CSV + - Continuous Integration + - Developer Tools + - Documentation + - Other + validations: + required: true +- type: textarea + attributes: + label: Additional context + placeholder: > + Add any other context about the problem here.