diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index ef7507bc66..c017f4cf5b 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1 +1,2 @@ d5e444d0a71409ae3701d4249ad877f1fb9e2235 # introduced `rustfmt.toml` and ran formatter; ignoring large formatting changes +45e2944e252ccdd563dc20edd9b29762e05cec1d # auto-fix prefer `Self` over explicit type diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 2b185084c2..dda04e253d 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -754,7 +754,7 @@ jobs: channel: stable - name: Install Machete - run: cargo +stable install cargo-machete + run: cargo +stable install cargo-machete@0.7.0 --locked - name: Run Machete run: cargo machete --with-metadata diff --git a/Cargo.lock b/Cargo.lock index 592a0793ba..db446053c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2084,6 +2084,7 @@ dependencies = [ "serde", "snafu", "test-log", + "uuid 1.10.0", ] [[package]] @@ -2170,6 +2171,7 @@ version = "0.3.0-dev0" dependencies = [ "common-daft-config", "common-error", + "common-io-config", "daft-core", "daft-dsl", "daft-functions", diff --git a/Cargo.toml b/Cargo.toml index 34c6a28c80..1d1065f026 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,8 +82,12 @@ parquet2 = {path = "src/parquet2"} debug = true [profile.dev] +debug = "line-tables-only" overflow-checks = false +[profile.dev.build-override] +opt-level = 3 + [profile.dev-bench] codegen-units = 16 debug = 1 # include symbols diff --git a/README.rst b/README.rst index d1eb6eb42a..cea9a29283 100644 --- a/README.rst +++ b/README.rst @@ -4,13 +4,13 @@ `Website `_ • `Docs `_ • `Installation`_ • `10-minute tour of Daft `_ • `Community and Support `_ -Daft: Distributed dataframes for multimodal data -======================================================= +Daft: Unified Engine for Data Analytics, Engineering & ML/AI +============================================================ -`Daft `_ is a distributed query engine for large-scale data processing in Python and is implemented in Rust. +`Daft `_ is a distributed query engine for large-scale data processing using Python or SQL, implemented in Rust. -* **Familiar interactive API:** Lazy Python Dataframe for rapid and interactive iteration +* **Familiar interactive API:** Lazy Python Dataframe for rapid and interactive iteration, or SQL for analytical queries * **Focus on the what:** Powerful Query Optimizer that rewrites queries to be as efficient as possible * **Data Catalog integrations:** Full integration with data catalogs such as Apache Iceberg * **Rich multimodal type-system:** Supports multimodal types such as Images, URLs, Tensors and more @@ -51,7 +51,7 @@ Quickstart In this example, we load images from an AWS S3 bucket's URLs and resize each image in the dataframe: -.. code:: +.. 
code:: python import daft diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index e2cb5e1eaa..c90817dfc2 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -9,7 +9,7 @@ from daft.io.scan import ScanOperator from daft.plan_scheduler.physical_plan_scheduler import PartitionT from daft.runners.partitioning import PartitionCacheEntry from daft.sql.sql_connection import SQLConnection -from daft.udf import PartialStatefulUDF, PartialStatelessUDF +from daft.udf import InitArgsType, PartialStatefulUDF, PartialStatelessUDF if TYPE_CHECKING: import pyarrow as pa @@ -1150,12 +1150,14 @@ def stateful_udf( expressions: list[PyExpr], return_dtype: PyDataType, resource_request: ResourceRequest | None, - init_args: tuple[tuple[Any, ...], dict[str, Any]] | None, + init_args: InitArgsType, batch_size: int | None, concurrency: int | None, ) -> PyExpr: ... def check_column_name_validity(name: str, schema: PySchema): ... -def extract_partial_stateful_udf_py(expression: PyExpr) -> dict[str, PartialStatefulUDF]: ... +def extract_partial_stateful_udf_py( + expression: PyExpr, +) -> dict[str, tuple[PartialStatefulUDF, InitArgsType]]: ... def bind_stateful_udfs(expression: PyExpr, initialized_funcs: dict[str, Callable]) -> PyExpr: ... def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ... def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ... @@ -1195,8 +1197,21 @@ def minhash( ngram_size: int, seed: int = 1, ) -> PyExpr: ... + +# ----- +# SQL functions +# ----- +class SQLFunctionStub: + @property + def name(self) -> str: ... + @property + def docstring(self) -> str: ... + @property + def arg_names(self) -> list[str]: ... + def sql(sql: str, catalog: PyCatalog, daft_planning_config: PyDaftPlanningConfig) -> LogicalPlanBuilder: ... def sql_expr(sql: str) -> PyExpr: ... +def list_sql_functions() -> list[SQLFunctionStub]: ... def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ... def to_struct(inputs: list[PyExpr]) -> PyExpr: ... diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 91be85e02a..03c88b4779 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -241,6 +241,7 @@ def actor_pool_project( with get_context().runner().actor_pool_context( actor_pool_name, actor_resource_request, + task_resource_request, num_actors, projection, ) as actor_pool_id: diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 1ae7e90dac..2701aebc77 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -2,6 +2,7 @@ import math import os +import warnings from datetime import date, datetime, time from decimal import Decimal from typing import ( @@ -2936,6 +2937,21 @@ def count(self, mode: CountMode = CountMode.Valid) -> Expression: def lengths(self) -> Expression: """Gets the length of each list + (DEPRECATED) Please use Expression.list.length instead + + Returns: + Expression: a UInt64 expression which is the length of each list + """ + warnings.warn( + "This function will be deprecated from Daft version >= 0.3.5! 
Instead, please use 'Expression.list.length'", + category=DeprecationWarning, + ) + + return Expression._from_pyexpr(native.list_count(self._expr, CountMode.All)) + + def length(self) -> Expression: + """Gets the length of each list + Returns: Expression: a UInt64 expression which is the length of each list """ diff --git a/daft/iceberg/iceberg_write.py b/daft/iceberg/iceberg_write.py index 0de4c950d8..9fda932db7 100644 --- a/daft/iceberg/iceberg_write.py +++ b/daft/iceberg/iceberg_write.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Any, Iterator, List, Tuple from daft import Expression, col +from daft.datatype import DataType +from daft.io.common import _get_schema_from_dict from daft.table import MicroPartition from daft.table.partitioning import PartitionedTable, partition_strings_to_path @@ -211,7 +213,10 @@ def visitor(self, partition_record: "IcebergRecord") -> "IcebergWriteVisitors.Fi return self.FileVisitor(self, partition_record) def to_metadata(self) -> MicroPartition: - return MicroPartition.from_pydict({"data_file": self.data_files}) + col_name = "data_file" + if len(self.data_files) == 0: + return MicroPartition.empty(_get_schema_from_dict({col_name: DataType.python()})) + return MicroPartition.from_pydict({col_name: self.data_files}) def partitioned_table_to_iceberg_iter( diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index e80acb03cb..d861e2e060 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -138,10 +138,15 @@ def initialize_actor_global_state(uninitialized_projection: ExpressionsProjectio logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys())) - # TODO: Account for Stateful Actor initialization arguments as well as user-provided batch_size - PyActorPool.initialized_stateful_udfs_process_singleton = { - name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items() - } + PyActorPool.initialized_stateful_udfs_process_singleton = {} + for name, (partial_udf, init_args) in partial_stateful_udfs.items(): + if init_args is None: + PyActorPool.initialized_stateful_udfs_process_singleton[name] = partial_udf.func_cls() + else: + args, kwargs = init_args + PyActorPool.initialized_stateful_udfs_process_singleton[name] = partial_udf.func_cls( + *args, **kwargs + ) @staticmethod def build_partitions_with_stateful_project( @@ -332,20 +337,27 @@ def run_iter_tables( @contextlib.contextmanager def actor_pool_context( - self, name: str, resource_request: ResourceRequest, num_actors: int, projection: ExpressionsProjection + self, + name: str, + actor_resource_request: ResourceRequest, + _task_resource_request: ResourceRequest, + num_actors: int, + projection: ExpressionsProjection, ) -> Iterator[str]: actor_pool_id = f"py_actor_pool-{name}" - total_resource_request = resource_request * num_actors + total_resource_request = actor_resource_request * num_actors admitted = self._attempt_admit_task(total_resource_request) if not admitted: raise RuntimeError( - f"Not enough resources available to admit {num_actors} actors, each with resource request: {resource_request}" + f"Not enough resources available to admit {num_actors} actors, each with resource request: {actor_resource_request}" ) try: - self._actor_pools[actor_pool_id] = PyActorPool(actor_pool_id, num_actors, resource_request, projection) + self._actor_pools[actor_pool_id] = PyActorPool( + actor_pool_id, num_actors, actor_resource_request, projection + ) self._actor_pools[actor_pool_id].setup() logger.debug("Created actor pool %s with 
resources: %s", actor_pool_id, total_resource_request) yield actor_pool_id diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 7c84f7b9dc..c0579e11e5 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -931,9 +931,14 @@ def __init__(self, daft_execution_config: PyDaftExecutionConfig, uninitialized_p for name, psu in extract_partial_stateful_udf_py(expr._expr).items() } logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys())) - self.initialized_stateful_udfs = { - name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items() - } + + self.initialized_stateful_udfs = {} + for name, (partial_udf, init_args) in partial_stateful_udfs.items(): + if init_args is None: + self.initialized_stateful_udfs[name] = partial_udf.func_cls() + else: + args, kwargs = init_args + self.initialized_stateful_udfs[name] = partial_udf.func_cls(*args, **kwargs) @ray.method(num_returns=2) def run( @@ -981,8 +986,12 @@ def __init__( self._projection = projection def setup(self) -> None: + ray_options = _get_ray_task_options(self._resource_request_per_actor) + self._actors = [ - DaftRayActor.options(name=f"rank={rank}-{self._id}").remote(self._execution_config, self._projection) # type: ignore + DaftRayActor.options(name=f"rank={rank}-{self._id}", **ray_options).remote( # type: ignore + self._execution_config, self._projection + ) for rank in range(self._num_actors) ] @@ -1150,8 +1159,16 @@ def run_iter_tables( @contextlib.contextmanager def actor_pool_context( - self, name: str, resource_request: ResourceRequest, num_actors: PartID, projection: ExpressionsProjection + self, + name: str, + actor_resource_request: ResourceRequest, + task_resource_request: ResourceRequest, + num_actors: PartID, + projection: ExpressionsProjection, ) -> Iterator[str]: + # Ray runs actor methods serially, so the resource request for an actor should be both the actor's resources and the task's resources + resource_request = actor_resource_request + task_resource_request + execution_config = get_context().daft_execution_config if self.ray_client_mode: try: diff --git a/daft/runners/runner.py b/daft/runners/runner.py index c1dd30f64e..730f5e1a4a 100644 --- a/daft/runners/runner.py +++ b/daft/runners/runner.py @@ -67,7 +67,8 @@ def run_iter_tables( def actor_pool_context( self, name: str, - resource_request: ResourceRequest, + actor_resource_request: ResourceRequest, + task_resource_request: ResourceRequest, num_actors: int, projection: ExpressionsProjection, ) -> Iterator[str]: diff --git a/daft/series.py b/daft/series.py index 440d570b4a..15c5295b4c 100644 --- a/daft/series.py +++ b/daft/series.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import Any, Literal, TypeVar from daft.arrow_utils import ensure_array, ensure_chunked_array @@ -927,6 +928,14 @@ def iceberg_truncate(self, w: int) -> Series: class SeriesListNamespace(SeriesNamespace): def lengths(self) -> Series: + warnings.warn( + "This function will be deprecated from Daft version >= 0.3.5! 
Instead, please use 'length'", + category=DeprecationWarning, + ) + + return Series._from_pyseries(self._series.list_count(CountMode.All)) + + def length(self) -> Series: return Series._from_pyseries(self._series.list_count(CountMode.All)) def get(self, idx: Series, default: Series) -> Series: diff --git a/daft/sql/_sql_funcs.py b/daft/sql/_sql_funcs.py new file mode 100644 index 0000000000..030cd3b53f --- /dev/null +++ b/daft/sql/_sql_funcs.py @@ -0,0 +1,30 @@ +"""This module is used for Sphinx documentation only. We procedurally generate Python functions to allow +Sphinx to generate documentation pages for every SQL function. +""" + +from __future__ import annotations + +from inspect import Parameter as _Parameter +from inspect import Signature as _Signature + +from daft.daft import list_sql_functions as _list_sql_functions + + +def _create_sql_function(func_name: str, docstring: str, arg_names: list[str]): + def sql_function(*args, **kwargs): + raise NotImplementedError("This function is for documentation purposes only and should not be called.") + + sql_function.__name__ = func_name + sql_function.__qualname__ = func_name + sql_function.__doc__ = docstring + sql_function.__signature__ = _Signature([_Parameter(name, _Parameter.POSITIONAL_OR_KEYWORD) for name in arg_names]) # type: ignore[attr-defined] + + # Register the function in the current module + globals()[func_name] = sql_function + + +__all__ = [] + +for sql_function_stub in _list_sql_functions(): + _create_sql_function(sql_function_stub.name, sql_function_stub.docstring, sql_function_stub.arg_names) + __all__.append(sql_function_stub.name) diff --git a/daft/sql/sql.py b/daft/sql/sql.py index 987a9baeb0..2c9bb78554 100644 --- a/daft/sql/sql.py +++ b/daft/sql/sql.py @@ -1,7 +1,7 @@ # isort: dont-add-import: from __future__ import annotations import inspect -from typing import Optional, overload +from typing import Optional from daft.api_annotations import PublicAPI from daft.context import get_context @@ -38,22 +38,120 @@ def _copy_from(self, other: "SQLCatalog") -> None: @PublicAPI def sql_expr(sql: str) -> Expression: - return Expression._from_pyexpr(_sql_expr(sql)) - + """Parses a SQL string into a Daft Expression -@overload -def sql(sql: str) -> DataFrame: ... + This function allows you to create Daft Expressions from SQL snippets, which can then be used + in Daft operations or combined with other Daft Expressions. + Args: + sql (str): A SQL string to be parsed into a Daft Expression. -@overload -def sql(sql: str, catalog: SQLCatalog, register_globals: bool = ...) -> DataFrame: ... + Returns: + Expression: A Daft Expression representing the parsed SQL. 
+ + Examples: + Create a simple SQL expression: + + >>> import daft + >>> expr = daft.sql_expr("1 + 2") + >>> print(expr) + lit(1) + lit(2) + + Use SQL expression in a Daft DataFrame operation: + + >>> df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> df = df.with_column("c", daft.sql_expr("a + b")) + >>> df.show() + ╭───────┬───────┬───────╮ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ Int64 ┆ Int64 ┆ Int64 │ + ╞═══════╪═══════╪═══════╡ + │ 1 ┆ 4 ┆ 5 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 2 ┆ 5 ┆ 7 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 3 ┆ 6 ┆ 9 │ + ╰───────┴───────┴───────╯ + + (Showing first 3 of 3 rows) + + `daft.sql_expr` is also called automatically for you in some DataFrame operations such as filters: + + >>> df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> result = df.where("x < 3 AND y > 4") + >>> result.show() + ╭───────┬───────╮ + │ x ┆ y │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 2 ┆ 5 │ + ╰───────┴───────╯ + + (Showing first 1 of 1 rows) + """ + return Expression._from_pyexpr(_sql_expr(sql)) @PublicAPI def sql(sql: str, catalog: Optional[SQLCatalog] = None, register_globals: bool = True) -> DataFrame: - """Create a DataFrame from an SQL query. - - EXPERIMENTAL: This features is early in development and will change. + """Run a SQL query, returning the results as a DataFrame + + .. WARNING:: + This features is early in development and will likely experience API changes. + + Examples: + + A simple example joining 2 dataframes together using a SQL statement, relying on Daft to detect the names of + SQL tables using their corresponding Python variable names. + + >>> import daft + >>> + >>> df1 = daft.from_pydict({"a": [1, 2, 3], "b": ["foo", "bar", "baz"]}) + >>> df2 = daft.from_pydict({"a": [1, 2, 3], "c": ["daft", None, None]}) + >>> + >>> # Daft automatically detects `df1` and `df2` from your Python global namespace + >>> result_df = daft.sql("SELECT * FROM df1 JOIN df2 ON df1.a = df2.a") + >>> result_df.show() + ╭───────┬──────┬──────╮ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ Int64 ┆ Utf8 ┆ Utf8 │ + ╞═══════╪══════╪══════╡ + │ 1 ┆ foo ┆ daft │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 2 ┆ bar ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 3 ┆ baz ┆ None │ + ╰───────┴──────┴──────╯ + + (Showing first 3 of 3 rows) + + A more complex example using a SQLCatalog to create a named table called `"my_table"`, which can then be referenced from inside your SQL statement. 
+ + >>> import daft + >>> from daft.sql import SQLCatalog + >>> + >>> df = daft.from_pydict({"a": [1, 2, 3], "b": ["foo", "bar", "baz"]}) + >>> + >>> # Register dataframes as tables in SQL explicitly with names + >>> catalog = SQLCatalog({"my_table": df}) + >>> + >>> daft.sql("SELECT a FROM my_table", catalog=catalog).show() + ╭───────╮ + │ a │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 1 │ + ├╌╌╌╌╌╌╌┤ + │ 2 │ + ├╌╌╌╌╌╌╌┤ + │ 3 │ + ╰───────╯ + + (Showing first 3 of 3 rows) Args: sql (str): SQL query to execute diff --git a/daft/table/table_io.py b/daft/table/table_io.py index ba07fab8a4..0f892534d9 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -22,6 +22,7 @@ PythonStorageConfig, StorageConfig, ) +from daft.datatype import DataType from daft.dependencies import pa, pacsv, pads, pajson, pq from daft.expressions import ExpressionsProjection, col from daft.filesystem import ( @@ -29,6 +30,7 @@ canonicalize_protocol, get_protocol_from_path, ) +from daft.io.common import _get_schema_from_dict from daft.logical.schema import Schema from daft.runners.partitioning import ( TableParseCSVOptions, @@ -426,16 +428,22 @@ def __call__(self, written_file): self.parent.paths.append(written_file.path) self.parent.partition_indices.append(self.idx) - def __init__(self, partition_values: MicroPartition | None, path_key: str = "path"): + def __init__(self, partition_values: MicroPartition | None, schema: Schema): self.paths: list[str] = [] self.partition_indices: list[int] = [] self.partition_values = partition_values - self.path_key = path_key + self.path_key = schema.column_names()[ + 0 + ] # I kept this from our original code, but idk why it's the first column name -kevin + self.schema = schema def visitor(self, partition_idx: int) -> TabularWriteVisitors.FileVisitor: return self.FileVisitor(self, partition_idx) def to_metadata(self) -> MicroPartition: + if len(self.paths) == 0: + return MicroPartition.empty(self.schema) + metadata: dict[str, Any] = {self.path_key: self.paths} if self.partition_values: @@ -488,10 +496,7 @@ def write_tabular( partitioned = PartitionedTable(table, partition_cols) - # I kept this from our original code, but idk why it's the first column name -kevin - path_key = schema.column_names()[0] - - visitors = TabularWriteVisitors(partitioned.partition_values(), path_key) + visitors = TabularWriteVisitors(partitioned.partition_values(), schema) for i, (part_table, part_path) in enumerate(partitioned_table_to_hive_iter(partitioned, resolved_path)): size_bytes = part_table.nbytes @@ -686,7 +691,10 @@ def visitor(self, partition_values: dict[str, str | None]) -> DeltaLakeWriteVisi return self.FileVisitor(self, partition_values) def to_metadata(self) -> MicroPartition: - return MicroPartition.from_pydict({"add_action": self.add_actions}) + col_name = "add_action" + if len(self.add_actions) == 0: + return MicroPartition.empty(_get_schema_from_dict({col_name: DataType.python()})) + return MicroPartition.from_pydict({col_name: self.add_actions}) def write_deltalake( diff --git a/daft/udf.py b/daft/udf.py index e2afb495ff..c662dc6ced 100644 --- a/daft/udf.py +++ b/daft/udf.py @@ -4,7 +4,7 @@ import functools import inspect from abc import abstractmethod -from typing import Any, Callable, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union from daft.context import get_context from daft.daft import PyDataType, ResourceRequest @@ -13,6 +13,7 @@ from daft.expressions import Expression from daft.series import PySeries, Series +InitArgsType = 
Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] UserProvidedPythonFunction = Callable[..., Union[Series, "np.ndarray", list]] @@ -294,7 +295,7 @@ class StatefulUDF(UDF): name: str cls: type return_dtype: DataType - init_args: tuple[tuple[Any, ...], dict[str, Any]] | None = None + init_args: InitArgsType = None concurrency: int | None = None def __post_init__(self): diff --git a/docs/source/10-min.ipynb b/docs/source/10-min.ipynb index cbda803752..d4444c2cd5 100644 --- a/docs/source/10-min.ipynb +++ b/docs/source/10-min.ipynb @@ -569,7 +569,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "See: [Expressions](user_guide/basic_concepts/expressions.rst)\n", + "See: [Expressions](user_guide/expressions.rst)\n", "\n", "Expressions are an API for defining computation that needs to happen over your columns.\n", "\n", @@ -1516,7 +1516,7 @@ "source": [ "### User-Defined Functions\n", "\n", - "See: [UDF User Guide](user_guide/daft_in_depth/udf)" + "See: [UDF User Guide](user_guide/udf)" ] }, { diff --git a/docs/source/_static/high_level_architecture.png b/docs/source/_static/high_level_architecture.png index 8e645c8899..f5133b2736 100644 Binary files a/docs/source/_static/high_level_architecture.png and b/docs/source/_static/high_level_architecture.png differ diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index ae34b7bb22..ec86e0bb5e 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -214,7 +214,7 @@ List :template: autosummary/accessor_method.rst Expression.list.join - Expression.list.lengths + Expression.list.length Expression.list.get Expression.list.slice Expression.list.chunk diff --git a/docs/source/api_docs/index.rst b/docs/source/api_docs/index.rst index 3079870df6..6bee44ad95 100644 --- a/docs/source/api_docs/index.rst +++ b/docs/source/api_docs/index.rst @@ -7,6 +7,7 @@ API Documentation Table of Contents creation dataframe + sql expressions schema datatype diff --git a/docs/source/api_docs/sql.rst b/docs/source/api_docs/sql.rst new file mode 100644 index 0000000000..33cf0c25dd --- /dev/null +++ b/docs/source/api_docs/sql.rst @@ -0,0 +1,15 @@ +SQL +=== + +.. autofunction:: daft.sql + +.. autofunction:: daft.sql_expr + +SQL Functions +------------- + +This is a full list of functions that can be used from within SQL. + + +.. 
sql-autosummary:: + :toctree: doc_gen/sql_funcs diff --git a/docs/source/conf.py b/docs/source/conf.py index 36e66be49a..fd59d32625 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,12 +9,16 @@ import inspect import os import subprocess +import sys import sphinx_autosummary_accessors # Set environment variable to help code determine whether or not we are running a Sphinx doc build process os.environ["DAFT_SPHINX_BUILD"] = "1" +# Help Sphinx find local custom extensions/directives that we build +sys.path.insert(0, os.path.abspath("ext")) + # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = "Daft" @@ -45,10 +49,15 @@ "myst_nb", "sphinx_copybutton", "sphinx_autosummary_accessors", + "sphinx_tabs.tabs", + # Local extensions + "sql_autosummary", ] templates_path = ["_templates", sphinx_autosummary_accessors.templates_path] +# Removes module names that prefix our classes +add_module_names = False # -- Options for Notebook rendering # https://myst-nb.readthedocs.io/en/latest/configuration.html?highlight=nb_execution_mode#execution @@ -86,6 +95,13 @@ "learn/user_guides/remote_cluster_execution": "distributed-computing.html", "learn/quickstart": "learn/10-min.html", "learn/10-min": "../10-min.html", + "user_guide/basic_concepts/expressions": "user_guide/expressions", + "user_guide/basic_concepts/dataframe_introduction": "user_guide/basic_concepts", + "user_guide/basic_concepts/introduction": "user_guide/basic_concepts", + "user_guide/daft_in_depth/aggregations": "user_guide/aggregations", + "user_guide/daft_in_depth/dataframe-operations": "user_guide/dataframe-operations", + "user_guide/daft_in_depth/datatypes": "user_guide/datatypes", + "user_guide/daft_in_depth/udf": "user_guide/udf", } # Resolving code links to github diff --git a/docs/source/ext/__init__.py b/docs/source/ext/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/source/ext/sql_autosummary.py b/docs/source/ext/sql_autosummary.py new file mode 100644 index 0000000000..5e37456cbe --- /dev/null +++ b/docs/source/ext/sql_autosummary.py @@ -0,0 +1,80 @@ +import inspect +import os + +from sphinx.ext.autosummary import Autosummary +from sphinx.util import logging + +logger = logging.getLogger(__name__) + + +TOCTREE = "doc_gen/sql_funcs" +SQL_MODULE_NAME = "daft.sql._sql_funcs" + +STUB_TEMPLATE = """ +.. currentmodule:: None + +.. 
autofunction:: {module_name}.{name} +""" + + +class SQLAutosummary(Autosummary): + def run(self): + func_names = get_sql_func_names() + # Run the normal autosummary stuff, override self.content + self.content = [f"~{SQL_MODULE_NAME}.{f}" for f in func_names] + nodes = super().run() + return nodes + + def get_sql_module_name(self): + return self.arguments[0] + + +def get_sql_func_names(): + # Import the SQL functions module + module = __import__(SQL_MODULE_NAME, fromlist=[""]) + + names = [] + for name, obj in inspect.getmembers(module): + if inspect.isfunction(obj) and not name.startswith("_"): + names.append(name) + + return names + + +def generate_stub(name: str): + """Generates a stub string for a SQL function""" + stub = name + "\n" + stub += "=" * len(name) + "\n\n" + stub += STUB_TEMPLATE.format(module_name=SQL_MODULE_NAME, name=name) + return stub + + +def generate_files(app): + # Determine where to write .rst files to + output_dir = os.path.join(app.srcdir, "api_docs", TOCTREE) + os.makedirs(output_dir, exist_ok=True) + + # Write stubfiles + func_names = get_sql_func_names() + for name in func_names: + stub_content = generate_stub(name) + filename = f"{SQL_MODULE_NAME}.{name}.rst" + filepath = os.path.join(output_dir, filename) + with open(filepath, "w") as f: + f.write(stub_content) + + # HACK: Not sure if this is ok? + app.env.found_docs.add(filepath) + + +def setup(app): + app.add_directive("sql-autosummary", SQLAutosummary) + + # Generate and register files when the builder is initialized + app.connect("builder-inited", generate_files) + + return { + "version": "0.1", + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/docs/source/index.rst b/docs/source/index.rst index 3a2d3eabb5..6ee5c431b7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,14 +1,53 @@ Daft Documentation ================== -Daft is a distributed query engine for large-scale data processing in Python and is implemented in Rust. - -* **Familiar interactive API:** Lazy Python Dataframe for rapid and interactive iteration -* **Focus on the what:** Powerful Query Optimizer that rewrites queries to be as efficient as possible -* **Data Catalog integrations:** Full integration with data catalogs such as Apache Iceberg -* **Rich multimodal type-system:** Supports multimodal types such as Images, URLs, Tensors and more -* **Seamless Interchange**: Built on the `Apache Arrow `_ In-Memory Format -* **Built for the cloud:** `Record-setting `_ I/O performance for integrations with S3 cloud storage +Daft is a unified data engine for **data engineering, analytics and ML/AI**. + +Daft exposes both **SQL and Python DataFrame interfaces** as first-class citizens and is written in Rust. + +Daft provides a **snappy and delightful local interactive experience**, but also seamlessly **scales to petabyte-scale distributed workloads**. 
+ +Use-Cases +--------- + +Data Engineering +**************** + +*Combine the performance of DuckDB, Pythonic UX of Polars and scalability of Apache Spark for data engineering from MB to PB scale* + +* Scale ETL workflows effortlessly from local to distributed environments +* Enjoy a Python-first experience without JVM dependency hell +* Leverage native integrations with cloud storage, open catalogs, and data formats + +Data Analytics +************** + +*Blend the snappiness of DuckDB with the scalability of Spark/Trino for unified local and distributed analytics* + +* Utilize complementary SQL and Python interfaces for versatile analytics +* Perform snappy local exploration with DuckDB-like performance +* Seamlessly scale to the cloud, outperforming distributed engines like Spark and Trino + +ML/AI +***** + +*Streamline ML/AI workflows with efficient dataloading from open formats like Parquet and JPEG* + +* Load data efficiently from open formats directly into PyTorch or NumPy +* Schedule large-scale model batch inference on distributed GPU clusters +* Optimize data curation with advanced clustering, deduplication, and filtering + +Technology +---------- + +Daft boasts strong integrations with technologies common across these workloads: + +* **Cloud Object Storage:** Record-setting I/O performance for integrations with S3 cloud storage, `battle-tested at exabyte-scale at Amazon `_ +* **ML/AI Python Ecosystem:** first-class integrations with `PyTorch `_ and `NumPy `_ for efficient interoperability with your ML/AI stack +* **Data Catalogs/Table Formats:** capabilities to effectively query table formats such as `Apache Iceberg `_, `Delta Lake `_ and `Apache Hudi `_ +* **Seamless Data Interchange:** zero-copy integration with `Apache Arrow `_ +* **Multimodal/ML Data:** native functionality for data modalities such as tensors, images, URLs, long-form text and embeddings + Installing Daft --------------- diff --git a/docs/source/migration_guides/coming_from_dask.rst b/docs/source/migration_guides/coming_from_dask.rst index 4e649ec8d3..99606c3ff9 100644 --- a/docs/source/migration_guides/coming_from_dask.rst +++ b/docs/source/migration_guides/coming_from_dask.rst @@ -30,7 +30,7 @@ Daft does not use an index Dask aims for as much feature-parity with pandas as possible, including maintaining the presence of an Index in the DataFrame. But keeping an Index is difficult when moving to a distributed computing environment. Dask doesn’t support row-based positional indexing (with .iloc) because it does not track the length of its partitions. It also does not support pandas MultiIndex. The argument for keeping the Index is that it makes some operations against the sorted index column very fast. In reality, resetting the Index forces a data shuffle and is an expensive operation. -Daft drops the need for an Index to make queries more readable and consistent. How you write a query should not change because of the state of an index or a reset_index call. In our opinion, eliminating the index makes things simpler, more explicit, more readable and therefore less error-prone. Daft achieves this by using the [Expressions API](../user_guide/basic_concepts/expressions). +Daft drops the need for an Index to make queries more readable and consistent. How you write a query should not change because of the state of an index or a reset_index call. In our opinion, eliminating the index makes things simpler, more explicit, more readable and therefore less error-prone. 
Daft achieves this by using the [Expressions API](../user_guide/expressions). In Dask you would index your DataFrame to return row ``b`` as follows: @@ -80,7 +80,7 @@ For example: res = ddf.map_partitions(my_function, **kwargs) -Daft implements two APIs for mapping computations over the data in your DataFrame in parallel: :doc:`Expressions <../user_guide/basic_concepts/expressions>` and :doc:`UDFs <../user_guide/daft_in_depth/udf>`. Expressions are most useful when you need to define computation over your columns. +Daft implements two APIs for mapping computations over the data in your DataFrame in parallel: :doc:`Expressions <../user_guide/expressions>` and :doc:`UDFs <../user_guide/udf>`. Expressions are most useful when you need to define computation over your columns. .. code:: python @@ -113,7 +113,7 @@ Daft is built as a DataFrame API for distributed Machine learning. You can use D Daft supports Multimodal Data Types ----------------------------------- -Dask supports the same data types as pandas. Daft is built to support many more data types, including Images, nested JSON, tensors, etc. See :doc:`the documentation <../user_guide/daft_in_depth/datatypes>` for a list of all supported data types. +Dask supports the same data types as pandas. Daft is built to support many more data types, including Images, nested JSON, tensors, etc. See :doc:`the documentation <../user_guide/datatypes>` for a list of all supported data types. Distributed Computing and Remote Clusters ----------------------------------------- diff --git a/docs/source/user_guide/daft_in_depth/aggregations.rst b/docs/source/user_guide/aggregations.rst similarity index 100% rename from docs/source/user_guide/daft_in_depth/aggregations.rst rename to docs/source/user_guide/aggregations.rst diff --git a/docs/source/user_guide/basic_concepts.rst b/docs/source/user_guide/basic_concepts.rst index 3bb3a89023..50fb8641cc 100644 --- a/docs/source/user_guide/basic_concepts.rst +++ b/docs/source/user_guide/basic_concepts.rst @@ -1,9 +1,407 @@ Basic Concepts ============== -.. toctree:: +Daft is a distributed data engine. The main abstraction in Daft is the :class:`DataFrame `, which conceptually can be thought of as a "table" of data with rows and columns. - basic_concepts/introduction - basic_concepts/dataframe_introduction - basic_concepts/expressions - basic_concepts/read-and-write +Daft also exposes a :doc:`sql` interface which interoperates closely with the DataFrame interface, allowing you to express data transformations and queries on your tables as SQL strings. + +.. image:: /_static/daft_illustration.png + :alt: Daft python dataframes make it easy to load any data such as PDF documents, images, protobufs, csv, parquet and audio files into a table dataframe structure for easy querying + :width: 500 + :align: center + +Terminology +----------- + +DataFrames +^^^^^^^^^^ + +The :class:`DataFrame ` is the core concept in Daft. Think of it as a table with rows and columns, similar to a spreadsheet or a database table. It's designed to handle large amounts of data efficiently. + +Daft DataFrames are lazy. This means that calling most methods on a DataFrame will not execute that operation immediately - instead, DataFrames expose explicit methods such as :meth:`daft.DataFrame.show` and :meth:`daft.DataFrame.write_parquet` +which will actually trigger computation of the DataFrame. + +Expressions +^^^^^^^^^^^ + +An :class:`Expression ` is a fundamental concept in Daft that allows you to define computations on DataFrame columns. 
They are the building blocks for transforming and manipulating data +within your DataFrame and will be your best friend if you are working with Daft primarily using the Python API. + +Query Plan +^^^^^^^^^^ + +As mentioned earlier, Daft DataFrames are lazy. Under the hood, each DataFrame in Daft is represented by a plan of operations that describes how to compute that DataFrame. + +This plan is called the "query plan" and calling methods on the DataFrame actually adds steps to the query plan! + +When your DataFrame is executed, Daft will read this plan, optimize it to make it run faster and then execute it to compute the requested results. + +Structured Query Language (SQL) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +SQL is a common query language for expressing queries over tables of data. Daft exposes a SQL API as an alternative (but often also complementary API) to the Python :class:`DataFrame ` and +:class:`Expression ` APIs for building queries. + +You can use SQL in Daft via the :func:`daft.sql` function, and Daft will also convert many SQL-compatible strings into Expressions via :func:`daft.sql_expr` for easy interoperability with DataFrames. + +DataFrame +--------- + +If you are coming from other DataFrame libraries such as Pandas or Polars, here are some key differences about Daft DataFrames: + +1. **Distributed:** When running in a distributed cluster, Daft splits your data into smaller "chunks" called *Partitions*. This allows Daft to process your data in parallel across multiple machines, leveraging more resources to work with large datasets. + +2. **Lazy:** When you write operations on a DataFrame, Daft doesn't execute them immediately. Instead, it creates a plan (called a query plan) of what needs to be done. This plan is optimized and only executed when you specifically request the results, which can lead to more efficient computations. + +3. **Multimodal:** Unlike traditional tables that usually contain simple data types like numbers and text, Daft DataFrames can handle complex data types in its columns. This includes things like images, audio files, or even custom Python objects. + +Common data operations that you would perform on DataFrames are: + +1. **Filtering rows:** Use :meth:`df.where(...) ` to keep only the rows that meet certain conditions. +2. **Creating new columns:** Use :meth:`df.with_column(...) ` to add a new column based on calculations from existing ones. +3. **Joining tables:** Use :meth:`df.join(other_df, ...) ` to combine two DataFrames based on common columns. +4. **Sorting:** Use :meth:`df.sort(...) ` to arrange your data based on values in one or more columns. +5. **Grouping and aggregating:** Use :meth:`df.groupby(...).agg(...) ` to summarize your data by groups. + +Creating a Dataframe +^^^^^^^^^^^^^^^^^^^^ + +.. seealso:: + + :doc:`read-and-write` - a more in-depth guide on various options for reading/writing data to/from Daft DataFrames from in-memory data (Python, Arrow), files (Parquet, CSV, JSON), SQL Databases and Data Catalogs + +Let's create our first Dataframe from a Python dictionary of columns. + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + import daft + + df = daft.from_pydict({ + "A": [1, 2, 3, 4], + "B": [1.5, 2.5, 3.5, 4.5], + "C": [True, True, False, False], + "D": [None, None, None, None], + }) + +Examine your Dataframe by printing it: + +.. code:: python + + df + +.. 
code-block:: text + :caption: Output + + ╭───────┬─────────┬─────────┬──────╮ + │ A ┆ B ┆ C ┆ D │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ Int64 ┆ Float64 ┆ Boolean ┆ Null │ + ╞═══════╪═════════╪═════════╪══════╡ + │ 1 ┆ 1.5 ┆ true ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 2 ┆ 2.5 ┆ true ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 3 ┆ 3.5 ┆ false ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 4 ┆ 4.5 ┆ false ┆ None │ + ╰───────┴─────────┴─────────┴──────╯ + + (Showing first 4 of 4 rows) + + +Congratulations - you just created your first DataFrame! It has 4 columns, "A", "B", "C", and "D". Let's try to select only the "A", "B", and "C" columns: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = df.select("A", "B", "C") + df + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.sql("SELECT A, B, C FROM df") + df + +.. code-block:: text + :caption: Output + + ╭───────┬─────────┬─────────╮ + │ A ┆ B ┆ C │ + │ --- ┆ --- ┆ --- │ + │ Int64 ┆ Float64 ┆ Boolean │ + ╰───────┴─────────┴─────────╯ + + (No data to display: Dataframe not materialized) + + +But wait - why is it printing the message ``(No data to display: Dataframe not materialized)`` and where are the rows of each column? + +Executing our DataFrame and Viewing Data +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The reason that our DataFrame currently does not display its rows is that Daft DataFrames are **lazy**. This just means that Daft DataFrames will defer all its work until you tell it to execute. + +In this case, Daft is just deferring the work required to read the data and select columns, however in practice this laziness can be very useful for helping Daft optimize your queries before execution! + +.. NOTE:: + + When you call methods on a Daft Dataframe, it defers the work by adding to an internal "plan". You can examine the current plan of a DataFrame by calling :meth:`df.explain() `! + + Passing the ``show_all=True`` argument will show you the plan after Daft applies its query optimizations and the physical (lower-level) plan. + + .. code-block:: text + :caption: Plan Output + + == Unoptimized Logical Plan == + + * Project: col(A), col(B), col(C) + | + * Source: + | Number of partitions = 1 + | Output schema = A#Int64, B#Float64, C#Boolean, D#Null + + + == Optimized Logical Plan == + + * Project: col(A), col(B), col(C) + | + * Source: + | Number of partitions = 1 + | Output schema = A#Int64, B#Float64, C#Boolean, D#Null + + + == Physical Plan == + + * Project: col(A), col(B), col(C) + | Clustering spec = { Num partitions = 1 } + | + * InMemoryScan: + | Schema = A#Int64, B#Float64, C#Boolean, D#Null, + | Size bytes = 65, + | Clustering spec = { Num partitions = 1 } + +We can tell Daft to execute our DataFrame and store the results in-memory using :meth:`df.collect() `: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df.collect() + df + +.. 
code-block:: text + :caption: Output + + ╭───────┬─────────┬─────────┬──────╮ + │ A ┆ B ┆ C ┆ D │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ Int64 ┆ Float64 ┆ Boolean ┆ Null │ + ╞═══════╪═════════╪═════════╪══════╡ + │ 1 ┆ 1.5 ┆ true ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 2 ┆ 2.5 ┆ true ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 3 ┆ 3.5 ┆ false ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 4 ┆ 4.5 ┆ false ┆ None │ + ╰───────┴─────────┴─────────┴──────╯ + + (Showing first 4 of 4 rows) + +Now your DataFrame object ``df`` is **materialized** - Daft has executed all the steps required to compute the results, and has cached the results in memory so that it can display this preview. + +Any subsequent operations on ``df`` will avoid recomputations, and just use this materialized result! + +When should I materialize my DataFrame? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you "eagerly" call :meth:`df.collect() ` immediately on every DataFrame, you may run into issues: + +1. If data is too large at any step, materializing all of it may cause memory issues +2. Optimizations are not possible since we cannot "predict future operations" + +However, data science is all about experimentation and trying different things on the same data. This means that materialization is crucial when working interactively with DataFrames, since it speeds up all subsequent experimentation on that DataFrame. + +We suggest materializing DataFrames using :meth:`df.collect() ` when they contain expensive operations (e.g. sorts or expensive function calls) and have to be called multiple times by downstream code: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = df.sort("A") # expensive sort + df.collect() # materialize the DataFrame + + # All subsequent work on df avoids recomputing previous steps + df.sum("B").show() + df.mean("B").show() + df.with_column("try_this", df["A"] + 1).show(5) + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.sql("SELECT * FROM df ORDER BY A") + df.collect() + + # All subsequent work on df avoids recomputing previous steps + daft.sql("SELECT sum(B) FROM df").show() + daft.sql("SELECT mean(B) FROM df").show() + daft.sql("SELECT *, (A + 1) AS try_this FROM df").show(5) + +.. code-block:: text + :caption: Output + + ╭─────────╮ + │ B │ + │ --- │ + │ Float64 │ + ╞═════════╡ + │ 12 │ + ╰─────────╯ + + (Showing first 1 of 1 rows) + + ╭─────────╮ + │ B │ + │ --- │ + │ Float64 │ + ╞═════════╡ + │ 3 │ + ╰─────────╯ + + (Showing first 1 of 1 rows) + + ╭───────┬─────────┬─────────┬──────────╮ + │ A ┆ B ┆ C ┆ try_this │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ Int64 ┆ Float64 ┆ Boolean ┆ Int64 │ + ╞═══════╪═════════╪═════════╪══════════╡ + │ 1 ┆ 1.5 ┆ true ┆ 2 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ + │ 2 ┆ 2.5 ┆ true ┆ 3 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ + │ 3 ┆ 3.5 ┆ false ┆ 4 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ + │ 4 ┆ 4.5 ┆ false ┆ 5 │ + ╰───────┴─────────┴─────────┴──────────╯ + + (Showing first 4 of 4 rows) + + +In many other cases however, there are better options than materializing your entire DataFrame with :meth:`df.collect() `: + +1. **Peeking with df.show(N)**: If you only want to "peek" at the first few rows of your data for visualization purposes, you can use :meth:`df.show(N) `, which processes and shows only the first ``N`` rows. +2. **Writing to disk**: The ``df.write_*`` methods will process and write your data to disk per-partition, avoiding materializing it all in memory at once. +3. 
**Pruning data**: You can materialize your DataFrame after performing a :meth:`df.limit() `, :meth:`df.where() ` or :meth:`df.select() ` operation which processes your data or prune it down to a smaller size. + +Schemas and Types +^^^^^^^^^^^^^^^^^ + +Notice also that when we printed our DataFrame, Daft displayed its **schema**. Each column of your DataFrame has a **name** and a **type**, and all data in that column will adhere to that type! + +Daft can display your DataFrame's schema without materializing it. Under the hood, it performs intelligent sampling of your data to determine the appropriate schema, and if you make any modifications to your DataFrame it can infer the resulting types based on the operation. + +.. NOTE:: + + Under the hood, Daft represents data in the `Apache Arrow `_ format, which allows it to efficiently represent and work on data using high-performance kernels which are written in Rust. + + +Running Computation with Expressions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To run computations on data in our DataFrame, we use Expressions. + +The following statement will :meth:`df.show() ` a DataFrame that has only one column - the column ``A`` from our original DataFrame but with every row incremented by 1. + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df.select(df["A"] + 1).show() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + daft.sql("SELECT A + 1 FROM df").show() + +.. code-block:: text + :caption: Output + + ╭───────╮ + │ A │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 2 │ + ├╌╌╌╌╌╌╌┤ + │ 3 │ + ├╌╌╌╌╌╌╌┤ + │ 4 │ + ├╌╌╌╌╌╌╌┤ + │ 5 │ + ╰───────╯ + + (Showing first 4 of 4 rows) + +.. NOTE:: + + A common pattern is to create a new columns using ``DataFrame.with_column``: + + .. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + # Creates a new column named "foo" which takes on values + # of column "A" incremented by 1 + df = df.with_column("foo", df["A"] + 1) + df.show() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + # Creates a new column named "foo" which takes on values + # of column "A" incremented by 1 + df = daft.sql("SELECT *, A + 1 AS foo FROM df") + df.show() + +.. code-block:: text + :caption: Output + + ╭───────┬─────────┬─────────┬───────╮ + │ A ┆ B ┆ C ┆ foo │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ Int64 ┆ Float64 ┆ Boolean ┆ Int64 │ + ╞═══════╪═════════╪═════════╪═══════╡ + │ 1 ┆ 1.5 ┆ true ┆ 2 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 2 ┆ 2.5 ┆ true ┆ 3 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 3 ┆ 3.5 ┆ false ┆ 4 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 4 ┆ 4.5 ┆ false ┆ 5 │ + ╰───────┴─────────┴─────────┴───────╯ + + (Showing first 4 of 4 rows) + +Congratulations, you have just written your first **Expression**: ``df["A"] + 1``! + +Expressions are a powerful way of describing computation on columns. For more details, check out the next section on :doc:`expressions` diff --git a/docs/source/user_guide/basic_concepts/dataframe_introduction.rst b/docs/source/user_guide/basic_concepts/dataframe_introduction.rst deleted file mode 100644 index 7e1075b34b..0000000000 --- a/docs/source/user_guide/basic_concepts/dataframe_introduction.rst +++ /dev/null @@ -1,203 +0,0 @@ -Dataframe -========= - -Data in Daft is represented as a DataFrame, which is a collection of data organized as a **table** with **rows** and **columns**. - -.. 
image:: /_static/daft_illustration.png - :alt: Daft python dataframes make it easy to load any data such as PDF documents, images, protobufs, csv, parquet and audio files into a table dataframe structure for easy querying - :width: 500 - :align: center - -This document provides an introduction to the Daft Dataframe. - -Creating a Dataframe --------------------- - -Let's create our first Dataframe from a Python dictionary of columns. - -.. code:: python - - import daft - - df = daft.from_pydict({ - "A": [1, 2, 3, 4], - "B": [1.5, 2.5, 3.5, 4.5], - "C": [True, True, False, False], - "D": [None, None, None, None], - }) - -Examine your Dataframe by printing it: - -.. code:: python - - df - -.. code:: none - - +---------+-----------+-----------+-----------+ - | A | B | C | D | - | Int64 | Float64 | Boolean | Null | - +=========+===========+===========+===========+ - | 1 | 1.5 | true | None | - +---------+-----------+-----------+-----------+ - | 2 | 2.5 | true | None | - +---------+-----------+-----------+-----------+ - | 3 | 3.5 | false | None | - +---------+-----------+-----------+-----------+ - | 4 | 4.5 | false | None | - +---------+-----------+-----------+-----------+ - (Showing first 4 of 4 rows) - - -Congratulations - you just created your first DataFrame! It has 4 columns, "A", "B", "C", and "D". Let's try to select only the "A", "B", and "C" columns: - -.. code:: python - - df.select("A", "B", "C") - -.. code:: none - - +---------+-----------+-----------+ - | A | B | C | - | Int64 | Float64 | Boolean | - +=========+===========+===========+ - +---------+-----------+-----------+ - (No data to display: Dataframe not materialized) - - -But wait - why is it printing the message ``(No data to display: Dataframe not materialized)`` and where are the rows of each column? - -Executing our DataFrame and Viewing Data ----------------------------------------- - -The reason that our DataFrame currently does not display its rows is that Daft DataFrames are **lazy**. This just means that Daft DataFrames will defer all its work until you tell it to execute. - -In this case, Daft is just deferring the work required to read the data and select columns, however in practice this laziness can be very useful for helping Daft optimize your queries before execution! - -.. NOTE:: - - When you call methods on a Daft Dataframe, it defers the work by adding to an internal "plan". You can examine the current plan of a DataFrame by calling :meth:`df.explain() `! - - Passing the ``show_all=True`` argument will show you the plan after Daft applies its query optimizations and the physical (lower-level) plan. - -We can tell Daft to execute our DataFrame and cache the results using :meth:`df.collect() `: - -.. code:: python - - df.collect() - df - -.. code:: none - - +---------+-----------+-----------+ - | A | B | C | - | Int64 | Float64 | Boolean | - +=========+===========+===========+ - | 1 | 1.5 | true | - +---------+-----------+-----------+ - | 2 | 2.5 | true | - +---------+-----------+-----------+ - | 3 | 3.5 | false | - +---------+-----------+-----------+ - | 4 | 4.5 | false | - +---------+-----------+-----------+ - (Showing first 4 of 4 rows) - -Now your DataFrame object ``df`` is **materialized** - Daft has executed all the steps required to compute the results, and has cached the results in memory so that it can display this preview. - -Any subsequent operations on ``df`` will avoid recomputations, and just use this materialized result! - -When should I materialize my DataFrame? 
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -If you "eagerly" call :meth:`df.collect() ` immediately on every DataFrame, you may run into issues: - -1. If data is too large at any step, materializing all of it may cause memory issues -2. Optimizations are not possible since we cannot "predict future operations" - -However, data science is all about experimentation and trying different things on the same data. This means that materialization is crucial when working interactively with DataFrames, since it speeds up all subsequent experimentation on that DataFrame. - -We suggest materializing DataFrames using :meth:`df.collect() ` when they contain expensive operations (e.g. sorts or expensive function calls) and have to be called multiple times by downstream code: - -.. code:: python - - df = df.with_column("A", df["A"].apply(expensive_function)) # expensive function - df = df.sort("A") # expensive sort - df.collect() # materialize the DataFrame - - # All subsequent work on df avoids recomputing previous steps - df.sum().show() - df.mean().show() - df.with_column("try_this", df["A"] + 1).show(5) - -In many other cases however, there are better options than materializing your entire DataFrame with :meth:`df.collect() `: - -1. **Peeking with df.show(N)**: If you only want to "peek" at the first few rows of your data for visualization purposes, you can use :meth:`df.show(N) `, which processes and shows only the first ``N`` rows. -2. **Writing to disk**: The ``df.write_*`` methods will process and write your data to disk per-partition, avoiding materializing it all in memory at once. -3. **Pruning data**: You can materialize your DataFrame after performing a :meth:`df.limit() `, :meth:`df.where() ` or :meth:`df.select() ` operation which processes your data or prune it down to a smaller size. - -Schemas and Types ------------------ - -Notice also that when we printed our DataFrame, Daft displayed its **schema**. Each column of your DataFrame has a **name** and a **type**, and all data in that column will adhere to that type! - -Daft can display your DataFrame's schema without materializing it. Under the hood, it performs intelligent sampling of your data to determine the appropriate schema, and if you make any modifications to your DataFrame it can infer the resulting types based on the operation. - -.. NOTE:: - - Under the hood, Daft represents data in the `Apache Arrow `_ format, which allows it to efficiently represent and work on data using high-performance kernels which are written in Rust. - - -Running Computations --------------------- - -To run computations on data in our DataFrame, we use Expressions. - -The following statement will :meth:`df.show() ` a DataFrame that has only one column - the column ``A`` from our original DataFrame but with every row incremented by 1. - -.. code:: python - - df.select(df["A"] + 1).show() - -.. code:: none - - +---------+ - | A | - | Int64 | - +=========+ - | 2 | - +---------+ - | 3 | - +---------+ - | 4 | - +---------+ - | 5 | - +---------+ - (Showing first 4 rows) - -.. NOTE:: - - A common pattern is to create a new columns using ``DataFrame.with_column``: - - .. code:: python - - # Creates a new column named "foo" which takes on values - # of column "A" incremented by 1 - df = df.with_column("foo", df["A"] + 1) - -Congratulations, you have just written your first **Expression**: ``df["A"] + 1``! - -Expressions -^^^^^^^^^^^ - -Expressions are how you define computations on your columns in Daft. 
- -The world of Daft contains much more than just numbers, and you can do much more than just add numbers together. Daft's rich Expressions API allows you to do things such as: - -1. Convert between different types with :meth:`df["numbers"].cast(float) ` -2. Download Bytes from a column containing String URLs using :meth:`df["urls"].url.download() ` -3. Run arbitrary Python functions on your data using :meth:`df["objects"].apply(my_python_function) ` - -We are also constantly looking to improve Daft and add more Expression functionality. Please contribute to the project with your ideas and code if you have an Expression in mind! - -The next section on :doc:`expressions` will provide a much deeper look at the Expressions that Daft provides. diff --git a/docs/source/user_guide/basic_concepts/expressions.rst b/docs/source/user_guide/basic_concepts/expressions.rst deleted file mode 100644 index db62ddb2fb..0000000000 --- a/docs/source/user_guide/basic_concepts/expressions.rst +++ /dev/null @@ -1,343 +0,0 @@ -Expressions -=========== - -Expressions are how you can express computations that should be run over columns of data. - -Creating Expressions --------------------- - -Referring to a column in a DataFrame -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Most commonly you will be creating expressions by referring to a column from an existing DataFrame. - -To do so, simply index a DataFrame with the string name of the column: - -.. code:: python - - import daft - - df = daft.from_pydict({"A": [1, 2, 3]}) - - # Refers to column "A" in `df` - df["A"] - -.. code:: none - - col(A) - -When we evaluate this ``df["A"]`` Expression, it will evaluate to the column from the ``df`` DataFrame with name "A"! - -Refer to a column with a certain name -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -You may also find it necessary in certain situations to create an Expression with just the name of a column, without having an existing DataFrame to refer to. You can do this with the :func:`~daft.expressions.col` helper: - -.. code:: python - - from daft import col - - # Refers to a column named "A" - col("A") - -When this Expression is evaluated, it will resolve to "the column named A" in whatever evaluation context it is used within! - -Refer to multiple columns using a wildcard -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -You can create expressions on multiple columns at once using a wildcard. The expression `col("*")` selects every column in a DataFrame, and you can operate on this expression in the same way as a single column: - -.. code:: python - - import daft - from daft import col - - df = daft.from_pydict({"A": [1, 2, 3], "B": [4, 5, 6]}) - df.select(col("*") * 3).show() - -.. code:: none - - ╭───────┬───────╮ - │ A ┆ B │ - │ --- ┆ --- │ - │ Int64 ┆ Int64 │ - ╞═══════╪═══════╡ - │ 3 ┆ 12 │ - ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ - │ 6 ┆ 15 │ - ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ - │ 9 ┆ 18 │ - ╰───────┴───────╯ - -Literals -^^^^^^^^ - -You may find yourself needing to hardcode a "single value" oftentimes as an expression. Daft provides a :func:`~daft.expressions.lit` helper to do so: - -.. code:: python - - from daft import lit - - # Refers to an expression which always evaluates to 42 - lit(42) - -This special :func:`~daft.expressions.lit` expression we just created evaluates always to the value ``42``. - -.. _userguide-numeric-expressions: - -Numeric Expressions -------------------- - -Since column "A" is an integer, we can run numeric computation such as addition, division and checking its value. 
Here are some examples where we create new columns using the results of such computations: - -.. code:: python - - # Add 1 to each element in column "A" - df = df.with_column("A_add_one", df["A"] + 1) - - # Divide each element in column A by 2 - df = df.with_column("A_divide_two", df["A"] / 2.) - - # Check if each element in column A is more than 1 - df = df.with_column("A_gt_1", df["A"] > 1) - - df.collect() - -.. code:: none - - +---------+-------------+----------------+-----------+ - | A | A_add_one | A_divide_two | A_gt_1 | - | Int64 | Int64 | Float64 | Boolean | - +=========+=============+================+===========+ - | 1 | 2 | 0.5 | false | - +---------+-------------+----------------+-----------+ - | 2 | 3 | 1 | true | - +---------+-------------+----------------+-----------+ - | 3 | 4 | 1.5 | true | - +---------+-------------+----------------+-----------+ - (Showing first 3 of 3 rows) - -Notice that the returned types of these operations are also well-typed according to their input types. For example, calling ``df["A"] > 1`` returns a column of type :meth:`Boolean `. - -Both the :meth:`Float ` and :meth:`Int ` types are numeric types, and inherit many of the same arithmetic Expression operations. You may find the full list of numeric operations in the :ref:`Expressions API reference `. - -.. _userguide-string-expressions: - -String Expressions ------------------- - -Daft also lets you have columns of strings in a DataFrame. Let's take a look! - -.. code:: python - - df = daft.from_pydict({"B": ["foo", "bar", "baz"]}) - df.show() - -.. code:: none - - +--------+ - | B | - | Utf8 | - +========+ - | foo | - +--------+ - | bar | - +--------+ - | baz | - +--------+ - (Showing first 3 rows) - -Unlike the numeric types, the string type does not support arithmetic operations such as ``*`` and ``/``. The one exception to this is the ``+`` operator, which is overridden to concatenate two string expressions as is commonly done in Python. Let's try that! - -.. code:: python - - df = df.with_column("B2", df["B"] + "foo") - df.show() - -.. code:: none - - +--------+--------+ - | B | B2 | - | Utf8 | Utf8 | - +========+========+ - | foo | foofoo | - +--------+--------+ - | bar | barfoo | - +--------+--------+ - | baz | bazfoo | - +--------+--------+ - (Showing first 3 rows) - -There are also many string operators that are accessed through a separate :meth:`.str.* ` "method namespace". - -For example, to check if each element in column "B" contains the substring "a", we can use the :meth:`.str.contains ` method: - -.. code:: python - - df = df.with_column("B2_contains_B", df["B2"].str.contains(df["B"])) - df.show() - -.. code:: none - - +--------+--------+-----------------+ - | B | B2 | B2_contains_B | - | Utf8 | Utf8 | Boolean | - +========+========+=================+ - | foo | foofoo | true | - +--------+--------+-----------------+ - | bar | barfoo | true | - +--------+--------+-----------------+ - | baz | bazfoo | true | - +--------+--------+-----------------+ - (Showing first 3 rows) - -You may find a full list of string operations in the :ref:`Expressions API reference `. - -URL Expressions -^^^^^^^^^^^^^^^ - -One special case of a String column you may find yourself working with is a column of URL strings. - -Daft provides the :meth:`.url.* ` method namespace with functionality for working with URL strings. For example, to download data from URLs: - -.. 
code:: python - - df = daft.from_pydict({ - "urls": [ - "https://www.google.com", - "s3://daft-public-data/open-images/validation-images/0001eeaf4aed83f9.jpg", - ], - }) - df = df.with_column("data", df["urls"].url.download()) - df.collect() - -.. code:: none - - +----------------------+----------------------+ - | urls | data | - | Utf8 | Binary | - +======================+======================+ - | https://www.google.c | b'`_ as the underlying executor, so you can find the full list of supported filters in the `jaq documentation `_. - -.. _userguide-logical-expressions: - -Logical Expressions -------------------- - -Logical Expressions are an expression that refers to a column of type :meth:`Boolean `, and can only take on the values True or False. - -.. code:: python - - df = daft.from_pydict({"C": [True, False, True]}) - df["C"] - -Daft supports logical operations such as ``&`` (and) and ``|`` (or) between logical expressions. - -Comparisons -^^^^^^^^^^^ - -Many of the types in Daft support comparisons between expressions that returns a Logical Expression. - -For example, here we can compare if each element in column "A" is equal to elements in column "B": - -.. code:: python - - df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 4]}) - - df = df.with_column("A_eq_B", df["A"] == df["B"]) - - df.collect() - -.. code:: none - - +---------+---------+-----------+ - | A | B | A_eq_B | - | Int64 | Int64 | Boolean | - +=========+=========+===========+ - | 1 | 1 | true | - +---------+---------+-----------+ - | 2 | 2 | true | - +---------+---------+-----------+ - | 3 | 4 | false | - +---------+---------+-----------+ - (Showing first 3 of 3 rows) - -Other useful comparisons can be found in the :ref:`Expressions API reference `. - -If Else Pattern -^^^^^^^^^^^^^^^ - -The :meth:`.if_else() ` method is a useful expression to have up your sleeve for choosing values between two other expressions based on a logical expression: - -.. code:: python - - df = daft.from_pydict({"A": [1, 2, 3], "B": [0, 2, 4]}) - - # Pick values from column A if the value in column A is bigger - # than the value in column B. Otherwise, pick values from column B. - df = df.with_column( - "A_if_bigger_else_B", - (df["A"] > df["B"]).if_else(df["A"], df["B"]), - ) - - df.collect() - -.. code:: none - - +---------+---------+----------------------+ - | A | B | A_if_bigger_else_B | - | Int64 | Int64 | Int64 | - +=========+=========+======================+ - | 1 | 0 | 1 | - +---------+---------+----------------------+ - | 2 | 2 | 2 | - +---------+---------+----------------------+ - | 3 | 4 | 4 | - +---------+---------+----------------------+ - (Showing first 3 of 3 rows) - -This is a useful expression for cleaning your data! diff --git a/docs/source/user_guide/basic_concepts/introduction.rst b/docs/source/user_guide/basic_concepts/introduction.rst deleted file mode 100644 index 2fa1c8fa94..0000000000 --- a/docs/source/user_guide/basic_concepts/introduction.rst +++ /dev/null @@ -1,92 +0,0 @@ -Introduction -============ - -Daft is a distributed query engine with a DataFrame API. The two key concepts to Daft are: - -1. :class:`DataFrame `: a Table-like structure that represents rows and columns of data -2. :class:`Expression `: a symbolic representation of computation that transforms columns of the DataFrame to a new one. - -With Daft, you create :class:`DataFrame ` from a variety of sources (e.g. reading data from files, data catalogs or from Python dictionaries) and use :class:`Expression ` to manipulate data in that DataFrame. 
Let's take a closer look at these two abstractions! - -DataFrame ---------- - -Conceptually, a DataFrame is a "table" of data, with rows and columns. - -.. image:: /_static/daft_illustration.png - :alt: Daft python dataframes make it easy to load any data such as PDF documents, images, protobufs, csv, parquet and audio files into a table dataframe structure for easy querying - :width: 500 - :align: center - -Using this abstraction of a DataFrame, you can run common tabular operations such as: - -1. Filtering rows: :meth:`df.where(...) ` -2. Creating new columns as a computation of existing columns: :meth:`df.with_column(...) ` -3. Joining two tables together: :meth:`df.join(...) ` -4. Sorting a table by the values in specified column(s): :meth:`df.sort(...) ` -5. Grouping and aggregations: :meth:`df.groupby(...).agg(...) ` - -Daft DataFrames are: - -1. **Distributed:** your data is split into *Partitions* and can be processed in parallel/on different machines -2. **Lazy:** computations are enqueued in a query plan which is then optimized and executed only when requested -3. **Multimodal:** columns can contain complex datatypes such as tensors, images and Python objects - -Since Daft is lazy, it can actually execute the query plan on a variety of different backends. By default, it will run computations locally using Python multithreading. However if you need to scale to large amounts of data that cannot be processed on a single machine, using the Ray runner allows Daft to run computations on a `Ray `_ cluster instead. - -Expressions ------------ - -The other important concept to understand when working with Daft are **expressions**. - -Because Daft is "lazy", it needs a way to represent computations that need to be performed on its data so that it can execute these computations at some later time. The answer to this is an :class:`~daft.expressions.Expression`! - -The simplest Expressions are: - -1. The column expression: :func:`col("a") ` which is used to refer to "some column named 'a'" -2. Or, if you already have an existing DataFrame ``df`` with a column named "a", you can refer to its column with Python's square bracket indexing syntax: ``df["a"]`` -3. The literal expression: :func:`lit(100) ` which represents a column that always takes on the provided value - -Daft then provides an extremely rich Expressions library to allow you to compose different computations that need to happen. For example: - -.. code:: python - - from daft import col, DataType - - # Take the column named "a" and add 1 to each element - col("a") + 1 - - # Take the column named "a", cast it to a string and check each element, returning True if it starts with "1" - col("a").cast(DataType.string()).str.startswith("1") - -Expressions are used in DataFrame operations, and the names of these Expressions are resolved to column names on the DataFrame that they are running on. Here is an example: - -.. code:: python - - import daft - - # Create a dataframe with a column "a" that has values [1, 2, 3] - df = daft.from_pydict({"a": [1, 2, 3]}) - - # Create new columns called "a_plus_1" and "a_startswith_1" using Expressions - df = df.select( - col("a"), - (col("a") + 1).alias("a_plus_1"), - col("a").cast(DataType.string()).str.startswith("1").alias("a_startswith_1"), - ) - - df.show() - -.. 
code:: none - - +---------+------------+------------------+ - | a | a_plus_1 | a_startswith_1 | - | Int64 | Int64 | Boolean | - +=========+============+==================+ - | 1 | 2 | true | - +---------+------------+------------------+ - | 2 | 3 | false | - +---------+------------+------------------+ - | 3 | 4 | false | - +---------+------------+------------------+ - (Showing first 3 rows) diff --git a/docs/source/user_guide/daft_in_depth.rst b/docs/source/user_guide/daft_in_depth.rst deleted file mode 100644 index 9b9702daca..0000000000 --- a/docs/source/user_guide/daft_in_depth.rst +++ /dev/null @@ -1,9 +0,0 @@ -Daft in Depth -============= - -.. toctree:: - - daft_in_depth/datatypes - daft_in_depth/dataframe-operations - daft_in_depth/aggregations - daft_in_depth/udf diff --git a/docs/source/user_guide/daft_in_depth/dataframe-operations.rst b/docs/source/user_guide/dataframe-operations.rst similarity index 100% rename from docs/source/user_guide/daft_in_depth/dataframe-operations.rst rename to docs/source/user_guide/dataframe-operations.rst diff --git a/docs/source/user_guide/daft_in_depth/datatypes.rst b/docs/source/user_guide/datatypes.rst similarity index 100% rename from docs/source/user_guide/daft_in_depth/datatypes.rst rename to docs/source/user_guide/datatypes.rst diff --git a/docs/source/user_guide/expressions.rst b/docs/source/user_guide/expressions.rst new file mode 100644 index 0000000000..54147a9401 --- /dev/null +++ b/docs/source/user_guide/expressions.rst @@ -0,0 +1,584 @@ +Expressions +=========== + +Expressions are how you can express computations that should be run over columns of data. + +Creating Expressions +^^^^^^^^^^^^^^^^^^^^ + +Referring to a column in a DataFrame +#################################### + +Most commonly you will be creating expressions by using the :func:`daft.col` function. + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + # Refers to column "A" + daft.col("A") + + .. group-tab:: ⚙️ SQL + + .. code:: python + + daft.sql_expr("A") + +.. code-block:: text + :caption: Output + + col(A) + +The above code creates an Expression that refers to a column named ``"A"``. + +Using SQL +######### + +Daft can also parse valid SQL as expressions. + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + daft.sql_expr("A + 1") + +.. code-block:: text + :caption: Output + + col(A) + lit(1) + +The above code will create an expression representing "the column named 'x' incremented by 1". For many APIs, sql_expr will actually be applied for you as syntactic sugar! + +Literals +######## + +You may find yourself needing to hardcode a "single value" oftentimes as an expression. Daft provides a :func:`~daft.expressions.lit` helper to do so: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + from daft import lit + + # Refers to an expression which always evaluates to 42 + lit(42) + + .. group-tab:: ⚙️ SQL + + .. code:: python + + # Refers to an expression which always evaluates to 42 + daft.sql_expr("42") + +.. code-block:: text + :caption: Output + + lit(42) + +This special :func:`~daft.expressions.lit` expression we just created evaluates always to the value ``42``. + +Wildcard Expressions +#################### + +You can create expressions on multiple columns at once using a wildcard. The expression `col("*")` selects every column in a DataFrame, and you can operate on this expression in the same way as a single column: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. 
code:: python + + import daft + from daft import col + + df = daft.from_pydict({"A": [1, 2, 3], "B": [4, 5, 6]}) + df.select(col("*") * 3).show() + +.. code-block:: text + :caption: Output + + ╭───────┬───────╮ + │ A ┆ B │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 3 ┆ 12 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 6 ┆ 15 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 9 ┆ 18 │ + ╰───────┴───────╯ + +Wildcards also work very well for accessing all members of a struct column: + + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + import daft + from daft import col + + df = daft.from_pydict({ + "person": [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + {"name": "Charlie", "age": 35} + ] + }) + + # Access all fields of the 'person' struct + df.select(col("person.*")).show() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + import daft + + df = daft.from_pydict({ + "person": [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + {"name": "Charlie", "age": 35} + ] + }) + + # Access all fields of the 'person' struct using SQL + daft.sql("SELECT person.* FROM df").show() + +.. code-block:: text + :caption: Output + + ╭──────────┬───────╮ + │ name ┆ age │ + │ --- ┆ --- │ + │ String ┆ Int64 │ + ╞══════════╪═══════╡ + │ Alice ┆ 30 │ + ├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ Bob ┆ 25 │ + ├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ Charlie ┆ 35 │ + ╰──────────┴───────╯ + +In this example, we use the wildcard `*` to access all fields of the `person` struct column. This is equivalent to selecting each field individually (`person.name`, `person.age`), but is more concise and flexible, especially when dealing with structs that have many fields. + + + +Composing Expressions +^^^^^^^^^^^^^^^^^^^^^ + +.. _userguide-numeric-expressions: + +Numeric Expressions +################### + +Since column "A" is an integer, we can run numeric computation such as addition, division and checking its value. Here are some examples where we create new columns using the results of such computations: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + # Add 1 to each element in column "A" + df = df.with_column("A_add_one", df["A"] + 1) + + # Divide each element in column A by 2 + df = df.with_column("A_divide_two", df["A"] / 2.) + + # Check if each element in column A is more than 1 + df = df.with_column("A_gt_1", df["A"] > 1) + + df.collect() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.sql(""" + SELECT + *, + A + 1 AS A_add_one, + A / 2.0 AS A_divide_two, + A > 1 AS A_gt_1 + FROM df + """) + df.collect() + +.. code-block:: text + :caption: Output + + +---------+-------------+----------------+-----------+ + | A | A_add_one | A_divide_two | A_gt_1 | + | Int64 | Int64 | Float64 | Boolean | + +=========+=============+================+===========+ + | 1 | 2 | 0.5 | false | + +---------+-------------+----------------+-----------+ + | 2 | 3 | 1 | true | + +---------+-------------+----------------+-----------+ + | 3 | 4 | 1.5 | true | + +---------+-------------+----------------+-----------+ + (Showing first 3 of 3 rows) + +Notice that the returned types of these operations are also well-typed according to their input types. For example, calling ``df["A"] > 1`` returns a column of type :meth:`Boolean `. + +Both the :meth:`Float ` and :meth:`Int ` types are numeric types, and inherit many of the same arithmetic Expression operations. You may find the full list of numeric operations in the :ref:`Expressions API reference `. + +.. 
_userguide-string-expressions: + +String Expressions +################## + +Daft also lets you have columns of strings in a DataFrame. Let's take a look! + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({"B": ["foo", "bar", "baz"]}) + df.show() + +.. code-block:: text + :caption: Output + + +--------+ + | B | + | Utf8 | + +========+ + | foo | + +--------+ + | bar | + +--------+ + | baz | + +--------+ + (Showing first 3 rows) + +Unlike the numeric types, the string type does not support arithmetic operations such as ``*`` and ``/``. The one exception to this is the ``+`` operator, which is overridden to concatenate two string expressions as is commonly done in Python. Let's try that! + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = df.with_column("B2", df["B"] + "foo") + df.show() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.sql("SELECT *, B + 'foo' AS B2 FROM df") + df.show() + +.. code-block:: text + :caption: Output + + +--------+--------+ + | B | B2 | + | Utf8 | Utf8 | + +========+========+ + | foo | foofoo | + +--------+--------+ + | bar | barfoo | + +--------+--------+ + | baz | bazfoo | + +--------+--------+ + (Showing first 3 rows) + +There are also many string operators that are accessed through a separate :meth:`.str.* ` "method namespace". + +For example, to check if each element in column "B" contains the substring "a", we can use the :meth:`.str.contains ` method: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = df.with_column("B2_contains_B", df["B2"].str.contains(df["B"])) + df.show() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.sql("SELECT *, contains(B2, B) AS B2_contains_B FROM df") + df.show() + +.. code-block:: text + :caption: Output + + +--------+--------+-----------------+ + | B | B2 | B2_contains_B | + | Utf8 | Utf8 | Boolean | + +========+========+=================+ + | foo | foofoo | true | + +--------+--------+-----------------+ + | bar | barfoo | true | + +--------+--------+-----------------+ + | baz | bazfoo | true | + +--------+--------+-----------------+ + (Showing first 3 rows) + +You may find a full list of string operations in the :ref:`Expressions API reference `. + +URL Expressions +############### + +One special case of a String column you may find yourself working with is a column of URL strings. + +Daft provides the :meth:`.url.* ` method namespace with functionality for working with URL strings. For example, to download data from URLs: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({ + "urls": [ + "https://www.google.com", + "s3://daft-public-data/open-images/validation-images/0001eeaf4aed83f9.jpg", + ], + }) + df = df.with_column("data", df["urls"].url.download()) + df.collect() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + + df = daft.from_pydict({ + "urls": [ + "https://www.google.com", + "s3://daft-public-data/open-images/validation-images/0001eeaf4aed83f9.jpg", + ], + }) + df = daft.sql(""" + SELECT + urls, + url_download(urls) AS data + FROM df + """) + df.collect() + +.. code-block:: text + :caption: Output + + +----------------------+----------------------+ + | urls | data | + | Utf8 | Binary | + +======================+======================+ + | https://www.google.c | b'`_ as the underlying executor, so you can find the full list of supported filters in the `jaq documentation `_. + +.. 
_userguide-logical-expressions: + +Logical Expressions +################### + +Logical Expressions are an expression that refers to a column of type :meth:`Boolean `, and can only take on the values True or False. + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({"C": [True, False, True]}) + +Daft supports logical operations such as ``&`` (and) and ``|`` (or) between logical expressions. + +Comparisons +########### + +Many of the types in Daft support comparisons between expressions that returns a Logical Expression. + +For example, here we can compare if each element in column "A" is equal to elements in column "B": + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 4]}) + + df = df.with_column("A_eq_B", df["A"] == df["B"]) + + df.collect() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 4]}) + + df = daft.sql(""" + SELECT + A, + B, + A = B AS A_eq_B + FROM df + """) + + df.collect() + +.. code-block:: text + :caption: Output + + ╭───────┬───────┬─────────╮ + │ A ┆ B ┆ A_eq_B │ + │ --- ┆ --- ┆ --- │ + │ Int64 ┆ Int64 ┆ Boolean │ + ╞═══════╪═══════╪═════════╡ + │ 1 ┆ 1 ┆ true │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 2 ┆ 2 ┆ true │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 3 ┆ 4 ┆ false │ + ╰───────┴───────┴─────────╯ + + (Showing first 3 of 3 rows) + +Other useful comparisons can be found in the :ref:`Expressions API reference `. + +If Else Pattern +############### + +The :meth:`.if_else() ` method is a useful expression to have up your sleeve for choosing values between two other expressions based on a logical expression: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [0, 2, 4]}) + + # Pick values from column A if the value in column A is bigger + # than the value in column B. Otherwise, pick values from column B. + df = df.with_column( + "A_if_bigger_else_B", + (df["A"] > df["B"]).if_else(df["A"], df["B"]), + ) + + df.collect() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [0, 2, 4]}) + + df = daft.sql(""" + SELECT + A, + B, + CASE + WHEN A > B THEN A + ELSE B + END AS A_if_bigger_else_B + FROM df + """) + + df.collect() + +.. code-block:: text + :caption: Output + + ╭───────┬───────┬────────────────────╮ + │ A ┆ B ┆ A_if_bigger_else_B │ + │ --- ┆ --- ┆ --- │ + │ Int64 ┆ Int64 ┆ Int64 │ + ╞═══════╪═══════╪════════════════════╡ + │ 1 ┆ 0 ┆ 1 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ 2 ┆ 2 ┆ 2 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ 3 ┆ 4 ┆ 4 │ + ╰───────┴───────┴────────────────────╯ + + (Showing first 3 of 3 rows) + +This is a useful expression for cleaning your data! diff --git a/docs/source/user_guide/fotw/fotw-001-images.ipynb b/docs/source/user_guide/fotw/fotw-001-images.ipynb index 827f98dd57..37d1f796d2 100644 --- a/docs/source/user_guide/fotw/fotw-001-images.ipynb +++ b/docs/source/user_guide/fotw/fotw-001-images.ipynb @@ -447,7 +447,7 @@ "metadata": {}, "source": [ "### Create Thumbnails\n", - "[Expressions](../basic_concepts/expressions) are a Daft API for defining computation that needs to happen over your columns. There are dedicated `image.(...)` Expressions for working with images.\n", + "[Expressions](../expressions) are a Daft API for defining computation that needs to happen over your columns. 
There are dedicated `image.(...)` Expressions for working with images.\n", "\n", "You can use the `image.resize` Expression to create a thumbnail of each image:" ] @@ -527,7 +527,7 @@ "\n", "We'll define a function that uses a pre-trained PyTorch model [ResNet50](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html) to classify the dog pictures. We'll then pass the `image` column to this PyTorch model and send the classification predictions to a new column `classify_breed`. \n", "\n", - "You will use Daft [User-Defined Functions (UDFs)](../daft_in_depth/udf) to do this. Daft UDFs which are the best way to run computations over multiple rows or columns.\n", + "You will use Daft [User-Defined Functions (UDFs)](../udf) to do this. Daft UDFs which are the best way to run computations over multiple rows or columns.\n", "\n", "#### Setting up PyTorch\n", "\n", diff --git a/docs/source/user_guide/index.rst b/docs/source/user_guide/index.rst index e79607a84d..b4b7150215 100644 --- a/docs/source/user_guide/index.rst +++ b/docs/source/user_guide/index.rst @@ -6,7 +6,13 @@ Daft User Guide :maxdepth: 1 basic_concepts - daft_in_depth + read-and-write + expressions + datatypes + dataframe-operations + sql + aggregations + udf poweruser integrations tutorials @@ -14,22 +20,7 @@ Daft User Guide Welcome to **Daft**! -Daft is a Python dataframe library that enables Pythonic data processing at large scale. - -* **Fast** - Daft kernels are written and accelerated using Rust on Apache Arrow arrays. - -* **Flexible** - you can work with any Python object in a Daft Dataframe. - -* **Interactive** - Daft provides a first-class notebook experience. - -* **Scalable** - Daft uses out-of-core algorithms to work with datasets that cannot fit in memory. - -* **Distributed** - Daft scales to a cluster of machines using Ray to crunch terabytes of data. - -* **Intelligent** - Daft performs query optimizations to speed up your work. - -The core interface provided by Daft is the *DataFrame*, which is a table of data consisting of rows and columns. This user guide -aims to help Daft users master the usage of the Daft *DataFrame* for all your data processing needs! +This user guide aims to help Daft users master the usage of the Daft for all your data needs. .. NOTE:: @@ -39,8 +30,7 @@ aims to help Daft users master the usage of the Daft *DataFrame* for all your da code you may wish to take a look at these resources: 1. :doc:`../10-min`: Itching to run some Daft code? Hit the ground running with our 10 minute quickstart notebook. - 2. (Coming soon!) Cheatsheet: Quick reference to commonly-used Daft APIs and usage patterns - useful to keep next to your laptop as you code! - 3. :doc:`../api_docs/index`: Searchable documentation and reference material to Daft's public Python API. + 2. :doc:`../api_docs/index`: Searchable documentation and reference material to Daft's public API. Table of Contents ----------------- @@ -52,11 +42,23 @@ The Daft User Guide is laid out as follows: High-level overview of Daft interfaces and usage to give you a better understanding of how Daft will fit into your day-to-day workflow. -:doc:`Daft in Depth ` -************************************ +Daft in Depth +************* Core Daft concepts all Daft users will find useful to understand deeply. 
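As a rough illustration of how these core concepts combine in practice, here is a minimal sketch (the data is made up) that reads data, applies an expression, filters, and aggregates:

.. code:: python

    import daft
    from daft import col

    # Read data into a DataFrame (here from a Python dictionary)
    df = daft.from_pydict({"value": [1, 2, 3, 4]})

    # Derive a new column with an expression
    df = df.with_column("value_plus_one", col("value") + 1)

    # Filter rows, then aggregate and display the result
    df.where(col("value_plus_one") > 2).sum().show()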
+* :doc:`read-and-write` +* :doc:`expressions` +* :doc:`datatypes` +* :doc:`dataframe-operations` +* :doc:`aggregations` +* :doc:`udf` + +:doc:`Structured Query Language (SQL) ` +******************************************** + +A look into Daft's SQL interface and how it complements Daft's Pythonic DataFrame APIs. + :doc:`The Daft Poweruser ` ************************************* diff --git a/docs/source/user_guide/basic_concepts/read-and-write.rst b/docs/source/user_guide/read-and-write.rst similarity index 64% rename from docs/source/user_guide/basic_concepts/read-and-write.rst rename to docs/source/user_guide/read-and-write.rst index 1d1a481fea..f8585111d9 100644 --- a/docs/source/user_guide/basic_concepts/read-and-write.rst +++ b/docs/source/user_guide/read-and-write.rst @@ -1,5 +1,5 @@ -Reading/Writing -=============== +Reading/Writing Data +==================== Daft can read data from a variety of sources, and write data to many destinations. @@ -37,7 +37,7 @@ To learn more about each of these constructors, as well as the options that they From Data Catalogs ^^^^^^^^^^^^^^^^^^ -If you use catalogs such as Apache Iceberg or Hive, you may wish to consult our user guide on integrations with Data Catalogs: :doc:`Daft integration with Data Catalogs <../integrations/>`. +If you use catalogs such as Apache Iceberg or Hive, you may wish to consult our user guide on integrations with Data Catalogs: :doc:`Daft integration with Data Catalogs `. From File Paths ^^^^^^^^^^^^^^^ @@ -87,7 +87,50 @@ In order to partition the data, you can specify a partition column, which will a # Read with a partition column df = daft.read_sql("SELECT * FROM my_table", partition_col="date", uri) -To learn more, consult the :doc:`SQL User Guide <../integrations/sql>` or the API documentation on :func:`daft.read_sql`. +To learn more, consult the :doc:`SQL User Guide ` or the API documentation on :func:`daft.read_sql`. + + +Reading a column of URLs +------------------------ + +Daft provides a convenient way to read data from a column of URLs using the :meth:`.url.download() ` method. This is particularly useful when you have a DataFrame with a column containing URLs pointing to external resources that you want to fetch and incorporate into your DataFrame. + +Here's an example of how to use this feature: + +.. code:: python + + # Assume we have a DataFrame with a column named 'image_urls' + df = daft.from_pydict({ + "image_urls": [ + "https://example.com/image1.jpg", + "https://example.com/image2.jpg", + "https://example.com/image3.jpg" + ] + }) + + # Download the content from the URLs and create a new column 'image_data' + df = df.with_column("image_data", df["image_urls"].url.download()) + df.show() + +.. code-block:: text + :caption: Output + + +------------------------------------+------------------------------------+ + | image_urls | image_data | + | Utf8 | Binary | + +====================================+====================================+ + | https://example.com/image1.jpg | b'\xff\xd8\xff\xe0\x00\x10JFIF...' | + +------------------------------------+------------------------------------+ + | https://example.com/image2.jpg | b'\xff\xd8\xff\xe0\x00\x10JFIF...' | + +------------------------------------+------------------------------------+ + | https://example.com/image3.jpg | b'\xff\xd8\xff\xe0\x00\x10JFIF...' 
| + +------------------------------------+------------------------------------+ + + (Showing first 3 of 3 rows) + + +This approach allows you to efficiently download and process data from a large number of URLs in parallel, leveraging Daft's distributed computing capabilities. + Writing Data diff --git a/docs/source/user_guide/sql.rst b/docs/source/user_guide/sql.rst new file mode 100644 index 0000000000..fec2761e05 --- /dev/null +++ b/docs/source/user_guide/sql.rst @@ -0,0 +1,244 @@ +SQL +=== + +Daft supports Structured Query Language (SQL) as a way of constructing query plans (represented in Python as a :class:`daft.DataFrame`) and expressions (:class:`daft.Expression`). + +SQL is a human-readable way of constructing these query plans, and can often be more ergonomic than using DataFrames for writing queries. + +.. NOTE:: + Daft's SQL support is new and is constantly being improved on! Please give us feedback and we'd love to hear more about what you would like. + +Running SQL on DataFrames +------------------------- + +Daft's :func:`daft.sql` function will automatically detect any :class:`daft.DataFrame` objects in your current Python environment to let you query them easily by name. + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + # Note the variable name `my_special_df` + my_special_df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) + + # Use the SQL table name "my_special_df" to refer to the above DataFrame! + sql_df = daft.sql("SELECT A, B FROM my_special_df") + + sql_df.show() + +.. code-block:: text + :caption: Output + + ╭───────┬───────╮ + │ A ┆ B │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 1 ┆ 1 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 2 ┆ 2 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 3 ┆ 3 │ + ╰───────┴───────╯ + + (Showing first 3 of 3 rows) + +In the above example, we query the DataFrame called `"my_special_df"` by simply referring to it in the SQL command. This produces a new DataFrame `sql_df` which can +natively integrate with the rest of your Daft query. + +Reading data from SQL +--------------------- + +.. WARNING:: + + This feature is a WIP and will be coming soon! We will support reading common datasources directly from SQL: + + .. code-block:: python + + daft.sql("SELECT * FROM read_parquet('s3://...')") + daft.sql("SELECT * FROM read_delta_lake('s3://...')") + + Today, a workaround for this is to construct your dataframe in Python first and use it from SQL instead: + + .. code-block:: python + + df = daft.read_parquet("s3://...") + daft.sql("SELECT * FROM df") + + We appreciate your patience with us and hope to deliver this crucial feature soon! + +SQL Expressions +--------------- + +SQL has the concept of expressions as well. Here is an example of a simple addition expression, adding columns "a" and "b" in SQL to produce a new column C. + +We also present here the equivalent query for SQL and DataFrame. Notice how similar the concepts are! + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) + df = daft.sql("SELECT A + B as C FROM df") + df.show() + + .. group-tab:: 🐍 Python + + .. code:: python + + expr = (daft.col("A") + daft.col("B")).alias("C") + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) + df = df.select(expr) + df.show() + +.. 
code-block:: text + :caption: Output + + ╭───────╮ + │ C │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 2 │ + ├╌╌╌╌╌╌╌┤ + │ 4 │ + ├╌╌╌╌╌╌╌┤ + │ 6 │ + ╰───────╯ + + (Showing first 3 of 3 rows) + +In the above query, both the SQL version of the query and the DataFrame version of the query produce the same result. + +Under the hood, they run the same Expression ``col("A") + col("B")``! + +One really cool trick you can do is to use the :func:`daft.sql_expr` function as a helper to easily create Expressions. The following are equivalent: + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + sql_expr = daft.sql_expr("A + B as C") + print("SQL expression:", sql_expr) + + .. group-tab:: 🐍 Python + + .. code:: python + + py_expr = (daft.col("A") + daft.col("B")).alias("C") + print("Python expression:", py_expr) + + +.. code-block:: text + :caption: Output + + SQL expression: col(A) + col(B) as C + Python expression: col(A) + col(B) as C + +This means that you can pretty much use SQL anywhere you use Python expressions, making Daft extremely versatile at mixing workflows which leverage both SQL and Python. + +As an example, consider the filter query below and compare the two equivalent Python and SQL queries: + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) + + # Daft automatically converts this string using `daft.sql_expr` + df = df.where("A < 2") + + df.show() + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) + + # Using Daft's Python Expression API + df = df.where(df["A"] < 2) + + df.show() + +.. code-block:: text + :caption: Output + + ╭───────┬───────╮ + │ A ┆ B │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 1 ┆ 1 │ + ╰───────┴───────╯ + + (Showing first 1 of 1 rows) + +Pretty sweet! Of course, this support for running Expressions on your columns extends well beyond arithmetic as we'll see in the next section on SQL Functions. + +SQL Functions +------------- + +SQL also has access to all of Daft's powerful :class:`daft.Expression` functionality through SQL functions. + +However, unlike the Python Expression API which encourages method-chaining (e.g. ``col("a").url.download().image.decode()``), in SQL you have to do function nesting instead (e.g. ``"image_decode(url_download(a))""``). + +.. NOTE:: + + A full catalog of the available SQL Functions in Daft is available in the :doc:`../api_docs/sql`. + + Note that it closely mirrors the Python API, with some function naming differences vs the available Python methods. + We also have some aliased functions for ANSI SQL-compliance or familiarity to users coming from other common SQL dialects such as PostgreSQL and SparkSQL to easily find their functionality. + +Here is an example of an equivalent function call in SQL vs Python: + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.from_pydict({"urls": [ + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + ]}) + df = daft.sql("SELECT image_decode(url_download(urls)) FROM df") + df.show() + + .. group-tab:: 🐍 Python + + .. 
code:: python + + df = daft.from_pydict({"urls": [ + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + ]}) + df = df.select(daft.col("urls").url.download().image.decode()) + df.show() + +.. code-block:: text + :caption: Output + + ╭──────────────╮ + │ urls │ + │ --- │ + │ Image[MIXED] │ + ╞══════════════╡ + │ │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ │ + ╰──────────────╯ + + (Showing first 3 of 3 rows) diff --git a/docs/source/user_guide/daft_in_depth/udf.rst b/docs/source/user_guide/udf.rst similarity index 100% rename from docs/source/user_guide/daft_in_depth/udf.rst rename to docs/source/user_guide/udf.rst diff --git a/requirements-dev.txt b/requirements-dev.txt index a67574df90..9c7809ac80 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -88,3 +88,4 @@ sphinx-book-theme==1.1.0; python_version >= "3.9" sphinx-reredirects>=0.1.1 sphinx-copybutton>=0.5.2 sphinx-autosummary-accessors==2023.4.0; python_version >= "3.9" +sphinx-tabs==3.4.5 diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 44bb95c1b0..4da0140e01 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -27,13 +27,21 @@ impl PyDaftPlanningConfig { } } - fn with_config_values(&mut self, default_io_config: Option) -> PyResult { + fn with_config_values( + &mut self, + default_io_config: Option, + enable_actor_pool_projections: Option, + ) -> PyResult { let mut config = self.config.as_ref().clone(); if let Some(default_io_config) = default_io_config { config.default_io_config = default_io_config.config; } + if let Some(enable_actor_pool_projections) = enable_actor_pool_projections { + config.enable_actor_pool_projections = enable_actor_pool_projections; + } + Ok(Self { config: Arc::new(config), }) diff --git a/src/common/file-formats/src/file_format_config.rs b/src/common/file-formats/src/file_format_config.rs index fe659bc444..6054907861 100644 --- a/src/common/file-formats/src/file_format_config.rs +++ b/src/common/file-formats/src/file_format_config.rs @@ -115,6 +115,17 @@ impl ParquetSourceConfig { } } +impl Default for ParquetSourceConfig { + fn default() -> Self { + Self { + coerce_int96_timestamp_unit: TimeUnit::Nanoseconds, + field_id_mapping: None, + row_groups: None, + chunk_size: None, + } + } +} + #[cfg(feature = "python")] #[pymethods] impl ParquetSourceConfig { diff --git a/src/daft-core/src/array/ops/arithmetic.rs b/src/daft-core/src/array/ops/arithmetic.rs index aa6b067f78..21e23657c6 100644 --- a/src/daft-core/src/array/ops/arithmetic.rs +++ b/src/daft-core/src/array/ops/arithmetic.rs @@ -59,9 +59,7 @@ where let opt_lhs = lhs.get(0); match opt_lhs { None => Ok(DataArray::full_null(rhs.name(), lhs.data_type(), rhs.len())), - // NOTE: naming logic here is wrong, and assigns the rhs name. However, renaming is handled at the Series level so this - // error is obfuscated. 
- Some(lhs) => rhs.apply(|rhs| operation(lhs, rhs)), + Some(scalar) => Ok(rhs.apply(|rhs| operation(scalar, rhs))?.rename(lhs.name())), } } (a, b) => Err(DaftError::ValueError(format!( diff --git a/src/daft-core/src/array/ops/as_arrow.rs b/src/daft-core/src/array/ops/as_arrow.rs index 51ba52dd2c..c2315d39cd 100644 --- a/src/daft-core/src/array/ops/as_arrow.rs +++ b/src/daft-core/src/array/ops/as_arrow.rs @@ -15,7 +15,9 @@ use crate::{ pub trait AsArrow { type Output; - // Retrieve the underlying concrete Arrow2 array. + /// This does not correct for the logical types and will just yield the physical type of the array. + /// For example, a TimestampArray will yield an arrow Int64Array rather than a arrow Timestamp Array. + /// To get a corrected arrow type, see `.to_arrow()`. fn as_arrow(&self) -> &Self::Output; } diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index ebac895e20..fed18367e1 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -342,7 +342,7 @@ pub enum PadPlacement { Right, } -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Default)] pub struct Utf8NormalizeOptions { pub remove_punct: bool, pub lowercase: bool, diff --git a/src/daft-core/src/datatypes/infer_datatype.rs b/src/daft-core/src/datatypes/infer_datatype.rs index 9c05eb0b02..ab80f4eac4 100644 --- a/src/daft-core/src/datatypes/infer_datatype.rs +++ b/src/daft-core/src/datatypes/infer_datatype.rs @@ -138,27 +138,27 @@ impl<'a> Add for InferDataType<'a> { (du_self @ &DataType::Duration(..), du_other @ &DataType::Duration(..)) => Err(DaftError::TypeError( format!("Cannot add due to differing precision: {}, {}. Please explicitly cast to the precision you wish to add in.", du_self, du_other) )), - (DataType::Null, other) | (other, DataType::Null) => { + (dtype @ DataType::Null, other) | (other, dtype @ DataType::Null) => { match other { // Condition is for backwards compatibility. TODO: remove DataType::Binary | DataType::FixedSizeBinary(..) | DataType::Date => Err(DaftError::TypeError( - format!("Cannot add types: {}, {}", self, other) + format!("Cannot add types: {}, {}", dtype, other) )), other if other.is_physical() => Ok(other.clone()), _ => Err(DaftError::TypeError( - format!("Cannot add types: {}, {}", self, other) + format!("Cannot add types: {}, {}", dtype, other) )), } } - (DataType::Utf8, other) | (other, DataType::Utf8) => { + (dtype @ DataType::Utf8, other) | (other, dtype @ DataType::Utf8) => { match other { // DataType::Date condition is for backwards compatibility. TODO: remove DataType::Binary | DataType::FixedSizeBinary(..) 
| DataType::Date => Err(DaftError::TypeError( - format!("Cannot add types: {}, {}", self, other) + format!("Cannot add types: {}, {}", dtype, other) )), other if other.is_physical() => Ok(DataType::Utf8), _ => Err(DaftError::TypeError( - format!("Cannot add types: {}, {}", self, other) + format!("Cannot add types: {}, {}", dtype, other) )), } } diff --git a/src/daft-core/src/series/array_impl/binary_ops.rs b/src/daft-core/src/series/array_impl/binary_ops.rs deleted file mode 100644 index 16e1a80b8d..0000000000 --- a/src/daft-core/src/series/array_impl/binary_ops.rs +++ /dev/null @@ -1,404 +0,0 @@ -use std::ops::{Add, Div, Mul, Rem, Sub}; - -use common_error::DaftResult; - -use super::{ArrayWrapper, IntoSeries, Series}; -use crate::{ - array::{ - ops::{DaftCompare, DaftLogical}, - FixedSizeListArray, ListArray, StructArray, - }, - datatypes::{ - logical::{ - DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, MapArray, - SparseTensorArray, TensorArray, TimeArray, TimestampArray, - }, - BinaryArray, BooleanArray, DataType, ExtensionArray, Field, FixedSizeBinaryArray, - Float32Array, Float64Array, InferDataType, Int128Array, Int16Array, Int32Array, Int64Array, - Int8Array, NullArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, - }, - series::series_like::SeriesLike, - with_match_comparable_daft_types, with_match_integer_daft_types, with_match_numeric_daft_types, -}; -#[cfg(feature = "python")] -use crate::{datatypes::PythonArray, series::ops::py_binary_op_utilfn}; - -#[cfg(feature = "python")] -macro_rules! py_binary_op { - ($lhs:expr, $rhs:expr, $pyoperator:expr) => { - py_binary_op_utilfn!($lhs, $rhs, $pyoperator, "map_operator_arrow_semantics") - }; -} -#[cfg(feature = "python")] -macro_rules! py_binary_op_bool { - ($lhs:expr, $rhs:expr, $pyoperator:expr) => { - py_binary_op_utilfn!($lhs, $rhs, $pyoperator, "map_operator_arrow_semantics_bool") - }; -} - -macro_rules! cast_downcast_op { - ($lhs:expr, $rhs:expr, $ty_expr:expr, $ty_type:ty, $op:ident) => {{ - let lhs = $lhs.cast($ty_expr)?; - let rhs = $rhs.cast($ty_expr)?; - let lhs = lhs.downcast::<$ty_type>()?; - let rhs = rhs.downcast::<$ty_type>()?; - lhs.$op(rhs) - }}; -} - -macro_rules! cast_downcast_op_into_series { - ($lhs:expr, $rhs:expr, $ty_expr:expr, $ty_type:ty, $op:ident) => {{ - Ok(cast_downcast_op!($lhs, $rhs, $ty_expr, $ty_type, $op)? - .into_series() - .rename($lhs.name())) - }}; -} - -macro_rules! apply_fixed_numeric_op { - ($lhs:expr, $rhs:expr, $op:ident) => {{ - $lhs.$op($rhs)? - }}; -} - -macro_rules! 
fixed_sized_numeric_binary_op { - ($left:expr, $right:expr, $output_type:expr, $op:ident) => {{ - assert!($left.data_type().is_fixed_size_numeric()); - assert!($right.data_type().is_fixed_size_numeric()); - - match ($left.data_type(), $right.data_type()) { - (DataType::FixedSizeList(..), DataType::FixedSizeList(..)) => { - Ok(apply_fixed_numeric_op!( - $left.downcast::().unwrap(), - $right.downcast::().unwrap(), - $op - ) - .into_series()) - } - (DataType::Embedding(..), DataType::Embedding(..)) => { - let physical = apply_fixed_numeric_op!( - &$left.downcast::().unwrap().physical, - &$right.downcast::().unwrap().physical, - $op - ); - let array = - EmbeddingArray::new(Field::new($left.name(), $output_type.clone()), physical); - Ok(array.into_series()) - } - (DataType::FixedShapeTensor(..), DataType::FixedShapeTensor(..)) => { - let physical = apply_fixed_numeric_op!( - &$left.downcast::().unwrap().physical, - &$right.downcast::().unwrap().physical, - $op - ); - let array = FixedShapeTensorArray::new( - Field::new($left.name(), $output_type.clone()), - physical, - ); - Ok(array.into_series()) - } - (left, right) => unimplemented!("cannot add {left} and {right} types"), - } - }}; -} - -macro_rules! binary_op_unimplemented { - ($lhs:expr, $op:expr, $rhs:expr, $output_ty:expr) => { - unimplemented!( - "No implementation for {} {} {} -> {}", - $lhs.data_type(), - $op, - $rhs.data_type(), - $output_ty, - ) - }; -} - -macro_rules! py_numeric_binary_op { - ($self:expr, $rhs:expr, $op:ident, $pyop:expr) => {{ - let output_type = - InferDataType::from($self.data_type()).$op(InferDataType::from($rhs.data_type()))?; - let lhs = $self.into_series(); - match &output_type { - #[cfg(feature = "python")] - DataType::Python => Ok(py_binary_op!(lhs, $rhs, $pyop)), - output_type if output_type.is_numeric() => { - with_match_numeric_daft_types!(output_type, |$T| { - cast_downcast_op_into_series!( - lhs, - $rhs, - output_type, - <$T as DaftDataType>::ArrayType, - $op - ) - }) - } - output_type if output_type.is_fixed_size_numeric() => { - fixed_sized_numeric_binary_op!(&lhs, $rhs, output_type, $op) - } - _ => binary_op_unimplemented!(lhs, $pyop, $rhs, output_type), - } - }}; -} - -macro_rules! physical_logic_op { - ($self:expr, $rhs:expr, $op:ident, $pyop:expr) => {{ - let output_type = InferDataType::from($self.data_type()) - .logical_op(&InferDataType::from($rhs.data_type()))?; - let lhs = $self.into_series(); - match &output_type { - #[cfg(feature = "python")] - DataType::Boolean => match (&lhs.data_type(), &$rhs.data_type()) { - #[cfg(feature = "python")] - (DataType::Python, _) | (_, DataType::Python) => { - Ok(py_binary_op_bool!(lhs, $rhs, $pyop)) - } - _ => { - cast_downcast_op_into_series!(lhs, $rhs, &DataType::Boolean, BooleanArray, $op) - } - }, - output_type if output_type.is_integer() => { - with_match_integer_daft_types!(output_type, |$T| { - cast_downcast_op_into_series!( - lhs, - $rhs, - output_type, - <$T as DaftDataType>::ArrayType, - $op - ) - }) - } - _ => binary_op_unimplemented!(lhs, $pyop, $rhs, output_type), - } - }}; -} - -macro_rules! physical_compare_op { - ($self:expr, $rhs:expr, $op:ident, $pyop:expr) => {{ - let (output_type, intermediate, comp_type) = InferDataType::from($self.data_type()) - .comparison_op(&InferDataType::from($rhs.data_type()))?; - let lhs = $self.into_series(); - let (lhs, rhs) = if let Some(ref it) = intermediate { - (lhs.cast(it)?, $rhs.cast(it)?) 
- } else { - (lhs, $rhs.clone()) - }; - - if let DataType::Boolean = output_type { - match comp_type { - #[cfg(feature = "python")] - DataType::Python => py_binary_op_bool!(lhs, rhs, $pyop) - .downcast::() - .cloned(), - _ => with_match_comparable_daft_types!(comp_type, |$T| { - cast_downcast_op!(lhs, rhs, &comp_type, <$T as DaftDataType>::ArrayType, $op) - }), - } - } else { - unreachable!() - } - }}; -} - -pub(crate) trait SeriesBinaryOps: SeriesLike { - fn add(&self, rhs: &Series) -> DaftResult { - let output_type = - InferDataType::from(self.data_type()).add(InferDataType::from(rhs.data_type()))?; - let lhs = self.into_series(); - match &output_type { - #[cfg(feature = "python")] - DataType::Python => Ok(py_binary_op!(lhs, rhs, "add")), - DataType::Utf8 => { - cast_downcast_op_into_series!(lhs, rhs, &DataType::Utf8, Utf8Array, add) - } - output_type if output_type.is_numeric() => { - with_match_numeric_daft_types!(output_type, |$T| { - cast_downcast_op_into_series!(lhs, rhs, output_type, <$T as DaftDataType>::ArrayType, add) - }) - } - output_type if output_type.is_fixed_size_numeric() => { - fixed_sized_numeric_binary_op!(&lhs, rhs, output_type, add) - } - _ => binary_op_unimplemented!(lhs, "+", rhs, output_type), - } - } - fn sub(&self, rhs: &Series) -> DaftResult { - py_numeric_binary_op!(self, rhs, sub, "sub") - } - fn mul(&self, rhs: &Series) -> DaftResult { - py_numeric_binary_op!(self, rhs, mul, "mul") - } - fn div(&self, rhs: &Series) -> DaftResult { - let output_type = - InferDataType::from(self.data_type()).div(InferDataType::from(rhs.data_type()))?; - let lhs = self.into_series(); - match &output_type { - #[cfg(feature = "python")] - DataType::Python => Ok(py_binary_op!(lhs, rhs, "truediv")), - DataType::Float64 => { - cast_downcast_op_into_series!(lhs, rhs, &DataType::Float64, Float64Array, div) - } - output_type if output_type.is_fixed_size_numeric() => { - fixed_sized_numeric_binary_op!(&lhs, rhs, output_type, div) - } - _ => binary_op_unimplemented!(lhs, "/", rhs, output_type), - } - } - fn rem(&self, rhs: &Series) -> DaftResult { - py_numeric_binary_op!(self, rhs, rem, "mod") - } - fn and(&self, rhs: &Series) -> DaftResult { - physical_logic_op!(self, rhs, and, "and_") - } - fn or(&self, rhs: &Series) -> DaftResult { - physical_logic_op!(self, rhs, or, "or_") - } - fn xor(&self, rhs: &Series) -> DaftResult { - physical_logic_op!(self, rhs, xor, "xor") - } - fn equal(&self, rhs: &Series) -> DaftResult { - physical_compare_op!(self, rhs, equal, "eq") - } - fn not_equal(&self, rhs: &Series) -> DaftResult { - physical_compare_op!(self, rhs, not_equal, "ne") - } - fn lt(&self, rhs: &Series) -> DaftResult { - physical_compare_op!(self, rhs, lt, "lt") - } - fn lte(&self, rhs: &Series) -> DaftResult { - physical_compare_op!(self, rhs, lte, "le") - } - fn gt(&self, rhs: &Series) -> DaftResult { - physical_compare_op!(self, rhs, gt, "gt") - } - fn gte(&self, rhs: &Series) -> DaftResult { - physical_compare_op!(self, rhs, gte, "ge") - } -} - -#[cfg(feature = "python")] -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl 
SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper { - fn add(&self, rhs: &Series) -> DaftResult { - let output_type = - (InferDataType::from(self.data_type()) + InferDataType::from(rhs.data_type()))?; - match rhs.data_type() { - DataType::Duration(..) => { - let days = rhs.duration()?.cast_to_days()?; - let physical_result = self.0.physical.add(&days)?; - physical_result.cast(&output_type) - } - _ => binary_op_unimplemented!(self, "+", rhs, output_type), - } - } - fn sub(&self, rhs: &Series) -> DaftResult { - let output_type = - (InferDataType::from(self.data_type()) - InferDataType::from(rhs.data_type()))?; - match rhs.data_type() { - DataType::Date => { - let physical_result = self.0.physical.sub(&rhs.date()?.physical)?; - physical_result.cast(&output_type) - } - DataType::Duration(..) => { - let days = rhs.duration()?.cast_to_days()?; - let physical_result = self.0.physical.sub(&days)?; - physical_result.cast(&output_type) - } - _ => binary_op_unimplemented!(self, "-", rhs, output_type), - } - } -} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper { - fn add(&self, rhs: &Series) -> DaftResult { - let output_type = - (InferDataType::from(self.data_type()) + InferDataType::from(rhs.data_type()))?; - let lhs = self.0.clone().into_series(); - match rhs.data_type() { - DataType::Timestamp(..) => { - let physical_result = self.0.physical.add(&rhs.timestamp()?.physical)?; - physical_result.cast(&output_type) - } - DataType::Duration(..) => { - let physical_result = self.0.physical.add(&rhs.duration()?.physical)?; - physical_result.cast(&output_type) - } - DataType::Date => { - let days = self.0.cast_to_days()?; - let physical_result = days.add(&rhs.date()?.physical)?; - physical_result.cast(&output_type) - } - _ => binary_op_unimplemented!(lhs, "+", rhs, output_type), - } - } - - fn sub(&self, rhs: &Series) -> DaftResult { - let output_type = - (InferDataType::from(self.data_type()) - InferDataType::from(rhs.data_type()))?; - match rhs.data_type() { - DataType::Duration(..) => { - let physical_result = self.0.physical.sub(&rhs.duration()?.physical)?; - physical_result.cast(&output_type) - } - _ => binary_op_unimplemented!(self, "-", rhs, output_type), - } - } -} - -impl SeriesBinaryOps for ArrayWrapper { - fn add(&self, rhs: &Series) -> DaftResult { - let output_type = - (InferDataType::from(self.data_type()) + InferDataType::from(rhs.data_type()))?; - match rhs.data_type() { - DataType::Duration(..) => { - let physical_result = self.0.physical.add(&rhs.duration()?.physical)?; - physical_result.cast(&output_type) - } - _ => binary_op_unimplemented!(self, "+", rhs, output_type), - } - } - fn sub(&self, rhs: &Series) -> DaftResult { - let output_type = - (InferDataType::from(self.data_type()) - InferDataType::from(rhs.data_type()))?; - match rhs.data_type() { - DataType::Duration(..) => { - let physical_result = self.0.physical.sub(&rhs.duration()?.physical)?; - physical_result.cast(&output_type) - } - DataType::Timestamp(..) 
=> { - let physical_result = self.0.physical.sub(&rhs.timestamp()?.physical)?; - physical_result.cast(&output_type) - } - _ => binary_op_unimplemented!(self, "-", rhs, output_type), - } - } -} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} -impl SeriesBinaryOps for ArrayWrapper {} diff --git a/src/daft-core/src/series/array_impl/data_array.rs b/src/daft-core/src/series/array_impl/data_array.rs index c210d5cdb2..1506d8cc9e 100644 --- a/src/daft-core/src/series/array_impl/data_array.rs +++ b/src/daft-core/src/series/array_impl/data_array.rs @@ -12,7 +12,7 @@ use crate::{ DataArray, }, datatypes::{DaftArrowBackedType, DataType, FixedSizeBinaryArray}, - series::{array_impl::binary_ops::SeriesBinaryOps, series_like::SeriesLike}, + series::series_like::SeriesLike, with_match_integer_daft_types, }; @@ -159,51 +159,6 @@ macro_rules! impl_series_like_for_data_array { None => Ok(self.0.list()?.into_series()), } } - - fn add(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::add(self, rhs) - } - fn sub(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::sub(self, rhs) - } - fn mul(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::mul(self, rhs) - } - fn div(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::div(self, rhs) - } - fn rem(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::rem(self, rhs) - } - - fn and(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::and(self, rhs) - } - fn or(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::or(self, rhs) - } - fn xor(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::xor(self, rhs) - } - - fn equal(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::equal(self, rhs) - } - fn not_equal(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::not_equal(self, rhs) - } - fn lt(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::lt(self, rhs) - } - fn lte(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::lte(self, rhs) - } - fn gt(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::gt(self, rhs) - } - fn gte(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::gte(self, rhs) - } } }; } diff --git a/src/daft-core/src/series/array_impl/logical_array.rs b/src/daft-core/src/series/array_impl/logical_array.rs index 9076907579..759af30ccf 100644 --- a/src/daft-core/src/series/array_impl/logical_array.rs +++ b/src/daft-core/src/series/array_impl/logical_array.rs @@ -4,7 +4,7 @@ use super::{ArrayWrapper, IntoSeries, Series}; use crate::{ array::{ops::GroupIndices, prelude::*}, datatypes::prelude::*, - series::{array_impl::binary_ops::SeriesBinaryOps, DaftResult, SeriesLike}, + series::{DaftResult, SeriesLike}, with_match_integer_daft_types, }; @@ -165,53 +165,6 @@ macro_rules! 
impl_series_like_for_logical_array { ) .into_series()) } - - fn add(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::add(self, rhs) - } - - fn sub(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::sub(self, rhs) - } - - fn mul(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::mul(self, rhs) - } - - fn div(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::div(self, rhs) - } - - fn rem(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::rem(self, rhs) - } - fn and(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::and(self, rhs) - } - fn or(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::or(self, rhs) - } - fn xor(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::xor(self, rhs) - } - fn equal(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::equal(self, rhs) - } - fn not_equal(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::not_equal(self, rhs) - } - fn lt(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::lt(self, rhs) - } - fn lte(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::lte(self, rhs) - } - fn gt(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::gt(self, rhs) - } - fn gte(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::gte(self, rhs) - } } }; } diff --git a/src/daft-core/src/series/array_impl/mod.rs b/src/daft-core/src/series/array_impl/mod.rs index 6a3f0839ad..61f12aa926 100644 --- a/src/daft-core/src/series/array_impl/mod.rs +++ b/src/daft-core/src/series/array_impl/mod.rs @@ -1,4 +1,3 @@ -pub mod binary_ops; pub mod data_array; pub mod logical_array; pub mod nested_array; diff --git a/src/daft-core/src/series/array_impl/nested_array.rs b/src/daft-core/src/series/array_impl/nested_array.rs index 5a6adb8b4c..1bd618e616 100644 --- a/src/daft-core/src/series/array_impl/nested_array.rs +++ b/src/daft-core/src/series/array_impl/nested_array.rs @@ -9,7 +9,7 @@ use crate::{ FixedSizeListArray, ListArray, StructArray, }, datatypes::{BooleanArray, DataType, Field}, - series::{array_impl::binary_ops::SeriesBinaryOps, IntoSeries, Series, SeriesLike}, + series::{IntoSeries, Series, SeriesLike}, with_match_integer_daft_types, }; @@ -148,51 +148,6 @@ macro_rules! 
impl_series_like_for_nested_arrays { fn str_value(&self, idx: usize) -> DaftResult { self.0.str_value(idx) } - - fn add(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::add(self, rhs) - } - fn sub(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::sub(self, rhs) - } - fn mul(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::mul(self, rhs) - } - fn div(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::div(self, rhs) - } - fn rem(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::rem(self, rhs) - } - - fn and(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::and(self, rhs) - } - fn or(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::or(self, rhs) - } - fn xor(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::xor(self, rhs) - } - - fn equal(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::equal(self, rhs) - } - fn not_equal(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::not_equal(self, rhs) - } - fn lt(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::lt(self, rhs) - } - fn lte(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::lte(self, rhs) - } - fn gt(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::gt(self, rhs) - } - fn gte(&self, rhs: &Series) -> DaftResult { - SeriesBinaryOps::gte(self, rhs) - } } }; } diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs index 0aa91d281c..276fdfde87 100644 --- a/src/daft-core/src/series/mod.rs +++ b/src/daft-core/src/series/mod.rs @@ -3,6 +3,7 @@ mod from; mod ops; mod serdes; mod series_like; +mod utils; use std::sync::Arc; pub use array_impl::IntoSeries; @@ -37,6 +38,12 @@ impl PartialEq for Series { } impl Series { + /// Exports this Series into an Arrow array that is corrected for the Arrow type system. + /// For example, Daft's TimestampArray is a logical type that is backed by an Int64Array physical array. + /// If we were to call `.as_arrow()` or `.physical` on the TimestampArray, we would get an Int64Array that represents the time units. + /// However, if we want to export our Timestamp array to another Arrow consumer (arrow2 kernels, Python, DuckDB, and so on), + /// we should convert it back to the canonical Arrow dtype of Timestamp rather than Int64. + /// To get the internal physical type without conversion, see `as_arrow()`. pub fn to_arrow(&self) -> Box { self.inner.to_arrow() } diff --git a/src/daft-core/src/series/ops/arithmetic.rs b/src/daft-core/src/series/ops/arithmetic.rs index b49eba83d4..1c9b1df7ea 100644 --- a/src/daft-core/src/series/ops/arithmetic.rs +++ b/src/daft-core/src/series/ops/arithmetic.rs @@ -1,18 +1,19 @@ use std::ops::{Add, Div, Mul, Rem, Sub}; use common_error::DaftResult; +use daft_schema::prelude::*; -use crate::series::Series; +#[cfg(feature = "python")] +use crate::series::utils::python_fn::run_python_binary_operator_fn; +use crate::{ + array::prelude::*, + datatypes::{InferDataType, Utf8Array}, + series::{utils::cast::cast_downcast_op, IntoSeries, Series}, + with_match_numeric_daft_types, +}; -macro_rules! impl_arithmetic_for_series { +macro_rules! impl_arithmetic_ref_for_series { ($trait:ident, $op:ident) => { - impl $trait for &Series { - type Output = DaftResult; - fn $op(self, rhs: Self) -> Self::Output { - self.inner.$op(rhs) - } - } - impl $trait for Series { type Output = DaftResult; fn $op(self, rhs: Self) -> Self::Output { @@ -22,11 +23,262 @@ macro_rules!
impl_arithmetic_for_series { }; } -impl_arithmetic_for_series!(Add, add); -impl_arithmetic_for_series!(Sub, sub); -impl_arithmetic_for_series!(Mul, mul); -impl_arithmetic_for_series!(Div, div); -impl_arithmetic_for_series!(Rem, rem); +macro_rules! arithmetic_op_not_implemented { + ($lhs:expr, $op:expr, $rhs:expr, $output_ty:expr) => { + unimplemented!( + "No implementation for {} {} {} -> {}", + $lhs.data_type(), + $op, + $rhs.data_type(), + $output_ty, + ) + }; +} + +impl Add for &Series { + type Output = DaftResult; + fn add(self, rhs: Self) -> Self::Output { + let output_type = + InferDataType::from(self.data_type()).add(InferDataType::from(rhs.data_type()))?; + let lhs = self; + match &output_type { + #[cfg(feature = "python")] + DataType::Python => run_python_binary_operator_fn(lhs, rhs, "add"), + DataType::Utf8 => { + Ok(cast_downcast_op!(lhs, rhs, &DataType::Utf8, Utf8Array, add)?.into_series()) + } + output_type if output_type.is_numeric() => { + with_match_numeric_daft_types!(output_type, |$T| { + Ok(cast_downcast_op!(lhs, rhs, output_type, <$T as DaftDataType>::ArrayType, add)?.into_series()) + }) + } + output_type if output_type.is_fixed_size_numeric() => { + fixed_size_binary_op(lhs, rhs, output_type, FixedSizeBinaryOp::Add) + } + output_type + if output_type.is_temporal() || matches!(output_type, DataType::Duration(..)) => + { + match (self.data_type(), rhs.data_type()) { + (DataType::Date, DataType::Duration(..)) => { + let days = rhs.duration()?.cast_to_days()?; + let physical_result = self.date()?.physical.add(&days)?; + physical_result.cast(output_type) + } + (DataType::Duration(..), DataType::Date) => { + let days = lhs.duration()?.cast_to_days()?; + let physical_result = days.add(&rhs.date()?.physical)?; + physical_result.cast(output_type) + } + (DataType::Duration(..), DataType::Duration(..)) => { + let physical_result = + lhs.duration()?.physical.add(&rhs.duration()?.physical)?; + physical_result.cast(output_type) + } + (DataType::Timestamp(..), DataType::Duration(..)) => { + let physical_result = + self.timestamp()?.physical.add(&rhs.duration()?.physical)?; + physical_result.cast(output_type) + } + (DataType::Duration(..), DataType::Timestamp(..)) => { + let physical_result = + lhs.duration()?.physical.add(&rhs.timestamp()?.physical)?; + physical_result.cast(output_type) + } + _ => arithmetic_op_not_implemented!(self, "+", rhs, output_type), + } + } + _ => arithmetic_op_not_implemented!(self, "+", rhs, output_type), + } + } +} + +impl Sub for &Series { + type Output = DaftResult; + fn sub(self, rhs: Self) -> Self::Output { + let output_type = + InferDataType::from(self.data_type()).sub(InferDataType::from(rhs.data_type()))?; + let lhs = self; + match &output_type { + #[cfg(feature = "python")] + DataType::Python => run_python_binary_operator_fn(lhs, rhs, "sub"), + output_type if output_type.is_numeric() => { + with_match_numeric_daft_types!(output_type, |$T| { + Ok(cast_downcast_op!(lhs, rhs, output_type, <$T as DaftDataType>::ArrayType, sub)?.into_series()) + }) + } + output_type + if output_type.is_temporal() || matches!(output_type, DataType::Duration(..)) => + { + match (self.data_type(), rhs.data_type()) { + (DataType::Date, DataType::Duration(..)) => { + let days = rhs.duration()?.cast_to_days()?; + let physical_result = self.date()?.physical.sub(&days)?; + physical_result.cast(output_type) + } + (DataType::Date, DataType::Date) => { + let physical_result = self.date()?.physical.sub(&rhs.date()?.physical)?; + physical_result.cast(output_type) + } + 
(DataType::Duration(..), DataType::Duration(..)) => { + let physical_result = + lhs.duration()?.physical.sub(&rhs.duration()?.physical)?; + physical_result.cast(output_type) + } + (DataType::Timestamp(..), DataType::Duration(..)) => { + let physical_result = + self.timestamp()?.physical.sub(&rhs.duration()?.physical)?; + physical_result.cast(output_type) + } + (DataType::Timestamp(..), DataType::Timestamp(..)) => { + let physical_result = + self.timestamp()?.physical.sub(&rhs.timestamp()?.physical)?; + physical_result.cast(output_type) + } + _ => arithmetic_op_not_implemented!(self, "-", rhs, output_type), + } + } + output_type if output_type.is_fixed_size_numeric() => { + fixed_size_binary_op(lhs, rhs, output_type, FixedSizeBinaryOp::Sub) + } + _ => arithmetic_op_not_implemented!(self, "-", rhs, output_type), + } + } +} + +impl Mul for &Series { + type Output = DaftResult; + fn mul(self, rhs: Self) -> Self::Output { + let output_type = + InferDataType::from(self.data_type()).mul(InferDataType::from(rhs.data_type()))?; + let lhs = self; + match &output_type { + #[cfg(feature = "python")] + DataType::Python => run_python_binary_operator_fn(lhs, rhs, "mul"), + output_type if output_type.is_numeric() => { + with_match_numeric_daft_types!(output_type, |$T| { + Ok(cast_downcast_op!(lhs, rhs, output_type, <$T as DaftDataType>::ArrayType, mul)?.into_series()) + }) + } + output_type if output_type.is_fixed_size_numeric() => { + fixed_size_binary_op(lhs, rhs, output_type, FixedSizeBinaryOp::Mul) + } + _ => arithmetic_op_not_implemented!(self, "*", rhs, output_type), + } + } +} + +impl Div for &Series { + type Output = DaftResult; + fn div(self, rhs: Self) -> Self::Output { + let output_type = + InferDataType::from(self.data_type()).div(InferDataType::from(rhs.data_type()))?; + let lhs = self; + match &output_type { + #[cfg(feature = "python")] + DataType::Python => run_python_binary_operator_fn(lhs, rhs, "truediv"), + DataType::Float64 => { + Ok( + cast_downcast_op!(lhs, rhs, &DataType::Float64, Float64Array, div)? 
+ .into_series(), + ) + } + output_type if output_type.is_fixed_size_numeric() => { + fixed_size_binary_op(lhs, rhs, output_type, FixedSizeBinaryOp::Div) + } + _ => arithmetic_op_not_implemented!(self, "/", rhs, output_type), + } + } +} + +impl Rem for &Series { + type Output = DaftResult; + fn rem(self, rhs: Self) -> Self::Output { + let output_type = + InferDataType::from(self.data_type()).rem(InferDataType::from(rhs.data_type()))?; + let lhs = self; + match &output_type { + #[cfg(feature = "python")] + DataType::Python => run_python_binary_operator_fn(lhs, rhs, "mod"), + output_type if output_type.is_numeric() => { + with_match_numeric_daft_types!(output_type, |$T| { + Ok(cast_downcast_op!(lhs, rhs, output_type, <$T as DaftDataType>::ArrayType, rem)?.into_series()) + }) + } + output_type if output_type.is_fixed_size_numeric() => { + fixed_size_binary_op(lhs, rhs, output_type, FixedSizeBinaryOp::Rem) + } + _ => arithmetic_op_not_implemented!(self, "%", rhs, output_type), + } + } +} +enum FixedSizeBinaryOp { + Add, + Sub, + Mul, + Div, + Rem, +} + +fn fixed_size_binary_op( + left: &Series, + right: &Series, + output_type: &DataType, + op: FixedSizeBinaryOp, +) -> DaftResult { + fn run_fixed_size_binary_op(lhs: &A, rhs: &A, op: FixedSizeBinaryOp) -> DaftResult + where + for<'a> &'a A: Add> + + Sub> + + Mul> + + Div> + + Rem>, + { + match op { + FixedSizeBinaryOp::Add => lhs.add(rhs), + FixedSizeBinaryOp::Sub => lhs.sub(rhs), + FixedSizeBinaryOp::Mul => lhs.mul(rhs), + FixedSizeBinaryOp::Div => lhs.div(rhs), + FixedSizeBinaryOp::Rem => lhs.rem(rhs), + } + } + + match (left.data_type(), right.data_type()) { + (DataType::FixedSizeList(..), DataType::FixedSizeList(..)) => { + let array = run_fixed_size_binary_op( + left.downcast::().unwrap(), + right.downcast::().unwrap(), + op, + )?; + Ok(array.into_series()) + } + (DataType::Embedding(..), DataType::Embedding(..)) => { + let physical = run_fixed_size_binary_op( + &left.downcast::().unwrap().physical, + &right.downcast::().unwrap().physical, + op, + )?; + let array = EmbeddingArray::new(Field::new(left.name(), output_type.clone()), physical); + Ok(array.into_series()) + } + (DataType::FixedShapeTensor(..), DataType::FixedShapeTensor(..)) => { + let physical = run_fixed_size_binary_op( + &left.downcast::().unwrap().physical, + &right.downcast::().unwrap().physical, + op, + )?; + let array = + FixedShapeTensorArray::new(Field::new(left.name(), output_type.clone()), physical); + Ok(array.into_series()) + } + (left, right) => unimplemented!("cannot add {left} and {right} types"), + } +} + +impl_arithmetic_ref_for_series!(Add, add); +impl_arithmetic_ref_for_series!(Sub, sub); +impl_arithmetic_ref_for_series!(Mul, mul); +impl_arithmetic_ref_for_series!(Div, div); +impl_arithmetic_ref_for_series!(Rem, rem); #[cfg(test)] mod tests { diff --git a/src/daft-core/src/series/ops/between.rs b/src/daft-core/src/series/ops/between.rs index 4e3d8c89d5..6c53cbb86c 100644 --- a/src/daft-core/src/series/ops/between.rs +++ b/src/daft-core/src/series/ops/between.rs @@ -1,7 +1,7 @@ use common_error::DaftResult; #[cfg(feature = "python")] -use crate::series::ops::py_between_op_utilfn; +use crate::series::utils::python_fn::py_between_op_utilfn; use crate::{ array::ops::DaftBetween, datatypes::{BooleanArray, DataType, InferDataType}, diff --git a/src/daft-core/src/series/ops/comparison.rs b/src/daft-core/src/series/ops/comparison.rs index 67ac7c66ec..2d0fd65c79 100644 --- a/src/daft-core/src/series/ops/comparison.rs +++ b/src/daft-core/src/series/ops/comparison.rs @@ 
-1,34 +1,68 @@ +use std::borrow::Cow; + use common_error::DaftResult; +use daft_schema::prelude::DataType; +#[cfg(feature = "python")] +use crate::series::utils::python_fn::run_python_binary_bool_operator; use crate::{ - array::ops::{DaftCompare, DaftLogical}, - datatypes::BooleanArray, - series::Series, + array::ops::DaftCompare, + datatypes::{BooleanArray, InferDataType}, + series::{utils::cast::cast_downcast_op, Series}, + with_match_comparable_daft_types, }; -macro_rules! call_inner { - ($fname:ident) => { - fn $fname(&self, other: &Series) -> Self::Output { - self.inner.$fname(other) +macro_rules! impl_compare_method { + ($fname:ident, $pyoperator:expr) => { + fn $fname(&self, rhs: &Series) -> Self::Output { + let lhs = self; + let (output_type, intermediate_type, comparison_type) = + InferDataType::from(self.data_type()) + .comparison_op(&InferDataType::from(rhs.data_type()))?; + assert_eq!( + output_type, + DataType::Boolean, + "All {} Comparisons should result in an Boolean output type, got {output_type}", + stringify!($fname) + ); + let (lhs, rhs) = if let Some(intermediate_type) = intermediate_type { + ( + Cow::Owned(lhs.cast(&intermediate_type)?), + Cow::Owned(rhs.cast(&intermediate_type)?), + ) + } else { + (Cow::Borrowed(lhs), Cow::Borrowed(rhs)) + }; + match comparison_type { + #[cfg(feature = "python")] + DataType::Python => { + let output = + run_python_binary_bool_operator(&lhs, &rhs, stringify!($pyoperator))?; + let bool_array = output + .bool() + .expect("We expected a Boolean Series from this Python Comparison"); + Ok(bool_array.clone()) + } + _ => with_match_comparable_daft_types!(comparison_type, |$T| { + cast_downcast_op!( + lhs, + rhs, + &comparison_type, + <$T as DaftDataType>::ArrayType, + $fname + ) + }), + } } }; } impl DaftCompare<&Self> for Series { type Output = DaftResult; - - call_inner!(equal); - call_inner!(not_equal); - call_inner!(lt); - call_inner!(lte); - call_inner!(gt); - call_inner!(gte); -} - -impl DaftLogical<&Self> for Series { - type Output = DaftResult; - - call_inner!(and); - call_inner!(or); - call_inner!(xor); + impl_compare_method!(equal, eq); + impl_compare_method!(not_equal, ne); + impl_compare_method!(lt, lt); + impl_compare_method!(lte, le); + impl_compare_method!(gt, gt); + impl_compare_method!(gte, ge); } diff --git a/src/daft-core/src/series/ops/is_in.rs b/src/daft-core/src/series/ops/is_in.rs index 7b2386b745..d6655d4bb9 100644 --- a/src/daft-core/src/series/ops/is_in.rs +++ b/src/daft-core/src/series/ops/is_in.rs @@ -1,7 +1,7 @@ use common_error::DaftResult; #[cfg(feature = "python")] -use crate::series::ops::py_membership_op_utilfn; +use crate::series::utils::python_fn::py_membership_op_utilfn; use crate::{ array::ops::DaftIsIn, datatypes::{BooleanArray, DataType, InferDataType}, diff --git a/src/daft-core/src/series/ops/logical.rs b/src/daft-core/src/series/ops/logical.rs new file mode 100644 index 0000000000..02dc1bfe3b --- /dev/null +++ b/src/daft-core/src/series/ops/logical.rs @@ -0,0 +1,117 @@ +use common_error::DaftResult; +use daft_schema::dtype::DataType; + +#[cfg(feature = "python")] +use crate::series::utils::python_fn::run_python_binary_bool_operator; +use crate::{ + array::ops::DaftLogical, + datatypes::InferDataType, + prelude::BooleanArray, + series::{utils::cast::cast_downcast_op, IntoSeries, Series}, + with_match_integer_daft_types, +}; +macro_rules! 
logical_op_not_implemented { + ($self:expr, $rhs:expr, $op:ident) => {{ + let left_dtype = $self.data_type(); + let right_dtype = $rhs.data_type(); + let op_name = stringify!($op); + return Err(common_error::DaftError::ComputeError(format!( + "Logical Op: {op_name} not implemented for {left_dtype}, {right_dtype}" + ))); + }}; +} + +impl DaftLogical<&Self> for Series { + type Output = DaftResult; + + fn and(&self, rhs: &Self) -> Self::Output { + let lhs = self; + let output_type = InferDataType::from(lhs.data_type()) + .logical_op(&InferDataType::from(rhs.data_type()))?; + match &output_type { + DataType::Boolean => match (lhs.data_type(), rhs.data_type()) { + #[cfg(feature = "python")] + (DataType::Python, _) | (_, DataType::Python) => { + run_python_binary_bool_operator(lhs, rhs, "and_") + } + _ => Ok( + cast_downcast_op!(lhs, rhs, &DataType::Boolean, BooleanArray, and)? + .into_series(), + ), + }, + output_type if output_type.is_integer() => { + with_match_integer_daft_types!(output_type, |$T| { + Ok(cast_downcast_op!( + self, + rhs, + output_type, + <$T as DaftDataType>::ArrayType, + and + )?.into_series()) + }) + } + + _ => logical_op_not_implemented!(self, rhs, and), + } + } + + fn or(&self, rhs: &Self) -> Self::Output { + let lhs = self; + let output_type = InferDataType::from(self.data_type()) + .logical_op(&InferDataType::from(rhs.data_type()))?; + match &output_type { + DataType::Boolean => match (lhs.data_type(), rhs.data_type()) { + #[cfg(feature = "python")] + (DataType::Python, _) | (_, DataType::Python) => { + run_python_binary_bool_operator(lhs, rhs, "or_") + } + _ => Ok( + cast_downcast_op!(lhs, rhs, &DataType::Boolean, BooleanArray, or)? + .into_series(), + ), + }, + output_type if output_type.is_integer() => { + with_match_integer_daft_types!(output_type, |$T| { + Ok(cast_downcast_op!( + self, + rhs, + output_type, + <$T as DaftDataType>::ArrayType, + or + )?.into_series()) + }) + } + _ => logical_op_not_implemented!(self, rhs, or), + } + } + + fn xor(&self, rhs: &Self) -> Self::Output { + let lhs = self; + let output_type = InferDataType::from(self.data_type()) + .logical_op(&InferDataType::from(rhs.data_type()))?; + match &output_type { + DataType::Boolean => match (lhs.data_type(), rhs.data_type()) { + #[cfg(feature = "python")] + (DataType::Python, _) | (_, DataType::Python) => { + run_python_binary_bool_operator(lhs, rhs, "xor") + } + _ => Ok( + cast_downcast_op!(lhs, rhs, &DataType::Boolean, BooleanArray, xor)? + .into_series(), + ), + }, + output_type if output_type.is_integer() => { + with_match_integer_daft_types!(output_type, |$T| { + Ok(cast_downcast_op!( + self, + rhs, + output_type, + <$T as DaftDataType>::ArrayType, + xor + )?.into_series()) + }) + } + _ => logical_op_not_implemented!(self, rhs, xor), + } + } +} diff --git a/src/daft-core/src/series/ops/mod.rs b/src/daft-core/src/series/ops/mod.rs index 59f04fcd6b..1c01a200bb 100644 --- a/src/daft-core/src/series/ops/mod.rs +++ b/src/daft-core/src/series/ops/mod.rs @@ -25,6 +25,7 @@ pub mod is_in; pub mod len; pub mod list; pub mod log; +pub mod logical; pub mod map; pub mod minhash; pub mod not; @@ -53,144 +54,3 @@ pub fn cast_series_to_supertype(series: &[&Series]) -> DaftResult> { series.iter().map(|s| s.cast(&supertype)).collect() } - -#[cfg(feature = "python")] -macro_rules! 
py_binary_op_utilfn { - ($lhs:expr, $rhs:expr, $pyoperator:expr, $utilfn:expr) => {{ - use pyo3::prelude::*; - - use crate::{datatypes::DataType, python::PySeries}; - - let lhs = $lhs.cast(&DataType::Python)?; - let rhs = $rhs.cast(&DataType::Python)?; - - let (lhs, rhs) = match (lhs.len(), rhs.len()) { - (a, b) if a == b => (lhs, rhs), - (a, 1) => (lhs, rhs.broadcast(a)?), - (1, b) => (lhs.broadcast(b)?, rhs), - (a, b) => panic!("Cannot apply operation on arrays of different lengths: {a} vs {b}"), - }; - - let left_pylist = PySeries::from(lhs.clone()).to_pylist()?; - let right_pylist = PySeries::from(rhs.clone()).to_pylist()?; - - let result_series: Series = Python::with_gil(|py| -> PyResult { - let py_operator = PyModule::import_bound(py, pyo3::intern!(py, "operator"))? - .getattr(pyo3::intern!(py, $pyoperator))?; - - let result_pylist = PyModule::import_bound(py, pyo3::intern!(py, "daft.utils"))? - .getattr(pyo3::intern!(py, $utilfn))? - .call1((py_operator, left_pylist, right_pylist))?; - - PyModule::import_bound(py, pyo3::intern!(py, "daft.series"))? - .getattr(pyo3::intern!(py, "Series"))? - .getattr(pyo3::intern!(py, "from_pylist"))? - .call1((result_pylist, lhs.name(), pyo3::intern!(py, "disallow")))? - .getattr(pyo3::intern!(py, "_series"))? - .extract() - })? - .into(); - - result_series - }}; -} -#[cfg(feature = "python")] -pub(super) use py_binary_op_utilfn; - -#[cfg(feature = "python")] -pub(super) fn py_membership_op_utilfn(lhs: &Series, rhs: &Series) -> DaftResult { - use pyo3::prelude::*; - - use crate::{datatypes::DataType, python::PySeries}; - - let lhs_casted = lhs.cast(&DataType::Python)?; - let rhs_casted = rhs.cast(&DataType::Python)?; - - let left_pylist = PySeries::from(lhs_casted.clone()).to_pylist()?; - let right_pylist = PySeries::from(rhs_casted.clone()).to_pylist()?; - - let result_series: Series = Python::with_gil(|py| -> PyResult { - let result_pylist = PyModule::import_bound(py, pyo3::intern!(py, "daft.utils"))? - .getattr(pyo3::intern!(py, "python_list_membership_check"))? - .call1((left_pylist, right_pylist))?; - - PyModule::import_bound(py, pyo3::intern!(py, "daft.series"))? - .getattr(pyo3::intern!(py, "Series"))? - .getattr(pyo3::intern!(py, "from_pylist"))? - .call1(( - result_pylist, - lhs_casted.name(), - pyo3::intern!(py, "disallow"), - ))? - .getattr(pyo3::intern!(py, "_series"))? - .extract() - })? 
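// Editor's note (not part of the diff): the arithmetic, comparison, and logical changes
// above move these operators off the per-array `SeriesBinaryOps` trait and onto `Series`
// itself, with `InferDataType` choosing the output type. A minimal caller-side sketch,
// assuming two same-length numeric Series; the prelude re-exports below are assumptions.
use common_error::DaftResult;
use daft_core::prelude::*; // assumed to re-export Series, BooleanArray, and the DaftCompare trait

fn combine(a: &Series, b: &Series) -> DaftResult<Series> {
    // `Add` is implemented for `&Series` and returns DaftResult<Series>; the output dtype
    // comes from InferDataType::add on the two input dtypes (numeric, Utf8, temporal, Python).
    let sum = (a + b)?;
    // Comparisons go through DaftCompare and always produce a Boolean result.
    let _mask: BooleanArray = a.equal(b)?;
    Ok(sum)
}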
- .into(); - - Ok(result_series) -} - -#[cfg(feature = "python")] -pub(super) fn py_between_op_utilfn( - value: &Series, - lower: &Series, - upper: &Series, -) -> DaftResult { - use pyo3::prelude::*; - - use crate::{datatypes::DataType, python::PySeries}; - - let value_casted = value.cast(&DataType::Python)?; - let lower_casted = lower.cast(&DataType::Python)?; - let upper_casted = upper.cast(&DataType::Python)?; - - let (value_casted, lower_casted, upper_casted) = - match (value_casted.len(), lower_casted.len(), upper_casted.len()) { - (a, b, c) if a == b && b == c => (value_casted, lower_casted, upper_casted), - (1, a, b) if a == b => (value_casted.broadcast(a)?, lower_casted, upper_casted), - (a, 1, b) if a == b => (value_casted, lower_casted.broadcast(a)?, upper_casted), - (a, b, 1) if a == b => (value_casted, lower_casted, upper_casted.broadcast(a)?), - (a, 1, 1) => ( - value_casted, - lower_casted.broadcast(a)?, - upper_casted.broadcast(a)?, - ), - (1, a, 1) => ( - value_casted.broadcast(a)?, - lower_casted, - upper_casted.broadcast(a)?, - ), - (1, 1, a) => ( - value_casted.broadcast(a)?, - lower_casted.broadcast(a)?, - upper_casted, - ), - (a, b, c) => { - panic!("Cannot apply operation on arrays of different lengths: {a} vs {b} vs {c}") - } - }; - - let value_pylist = PySeries::from(value_casted.clone()).to_pylist()?; - let lower_pylist = PySeries::from(lower_casted.clone()).to_pylist()?; - let upper_pylist = PySeries::from(upper_casted.clone()).to_pylist()?; - - let result_series: Series = Python::with_gil(|py| -> PyResult { - let result_pylist = PyModule::import_bound(py, pyo3::intern!(py, "daft.utils"))? - .getattr(pyo3::intern!(py, "python_list_between_check"))? - .call1((value_pylist, lower_pylist, upper_pylist))?; - - PyModule::import_bound(py, pyo3::intern!(py, "daft.series"))? - .getattr(pyo3::intern!(py, "Series"))? - .getattr(pyo3::intern!(py, "from_pylist"))? - .call1(( - result_pylist, - value_casted.name(), - pyo3::intern!(py, "disallow"), - ))? - .getattr(pyo3::intern!(py, "_series"))? - .extract() - })? - .into(); - - Ok(result_series) -} diff --git a/src/daft-core/src/series/series_like.rs b/src/daft-core/src/series/series_like.rs index 463892c8bd..9d152693d9 100644 --- a/src/daft-core/src/series/series_like.rs +++ b/src/daft-core/src/series/series_like.rs @@ -34,18 +34,4 @@ pub trait SeriesLike: Send + Sync + Any + std::fmt::Debug { fn slice(&self, start: usize, end: usize) -> DaftResult; fn take(&self, idx: &Series) -> DaftResult; fn str_value(&self, idx: usize) -> DaftResult; - fn add(&self, rhs: &Series) -> DaftResult; - fn sub(&self, rhs: &Series) -> DaftResult; - fn mul(&self, rhs: &Series) -> DaftResult; - fn div(&self, rhs: &Series) -> DaftResult; - fn rem(&self, rhs: &Series) -> DaftResult; - fn and(&self, rhs: &Series) -> DaftResult; - fn or(&self, rhs: &Series) -> DaftResult; - fn xor(&self, rhs: &Series) -> DaftResult; - fn equal(&self, rhs: &Series) -> DaftResult; - fn not_equal(&self, rhs: &Series) -> DaftResult; - fn lt(&self, rhs: &Series) -> DaftResult; - fn lte(&self, rhs: &Series) -> DaftResult; - fn gt(&self, rhs: &Series) -> DaftResult; - fn gte(&self, rhs: &Series) -> DaftResult; } diff --git a/src/daft-core/src/series/utils/mod.rs b/src/daft-core/src/series/utils/mod.rs new file mode 100644 index 0000000000..a262af9755 --- /dev/null +++ b/src/daft-core/src/series/utils/mod.rs @@ -0,0 +1,14 @@ +#[cfg(feature = "python")] +pub(super) mod python_fn; +pub(crate) mod cast { + macro_rules! 
cast_downcast_op { + ($lhs:expr, $rhs:expr, $ty_expr:expr, $ty_type:ty, $op:ident) => {{ + let lhs = $lhs.cast($ty_expr)?; + let rhs = $rhs.cast($ty_expr)?; + let lhs = lhs.downcast::<$ty_type>()?; + let rhs = rhs.downcast::<$ty_type>()?; + lhs.$op(rhs) + }}; + } + pub(crate) use cast_downcast_op; +} diff --git a/src/daft-core/src/series/utils/python_fn.rs b/src/daft-core/src/series/utils/python_fn.rs new file mode 100644 index 0000000000..2fb9112775 --- /dev/null +++ b/src/daft-core/src/series/utils/python_fn.rs @@ -0,0 +1,157 @@ +use common_error::DaftResult; + +use crate::series::Series; + +pub(crate) fn run_python_binary_operator_fn( + lhs: &Series, + rhs: &Series, + operator_fn: &str, +) -> DaftResult { + python_binary_op_with_utilfn(lhs, rhs, operator_fn, "map_operator_arrow_semantics") +} + +pub(crate) fn run_python_binary_bool_operator( + lhs: &Series, + rhs: &Series, + operator_fn: &str, +) -> DaftResult { + python_binary_op_with_utilfn(lhs, rhs, operator_fn, "map_operator_arrow_semantics_bool") +} + +fn python_binary_op_with_utilfn( + lhs: &Series, + rhs: &Series, + operator_fn: &str, + util_fn: &str, +) -> DaftResult { + use pyo3::prelude::*; + + use crate::{datatypes::DataType, python::PySeries}; + + let lhs = lhs.cast(&DataType::Python)?; + let rhs = rhs.cast(&DataType::Python)?; + + let (lhs, rhs) = match (lhs.len(), rhs.len()) { + (a, b) if a == b => (lhs, rhs), + (a, 1) => (lhs, rhs.broadcast(a)?), + (1, b) => (lhs.broadcast(b)?, rhs), + (a, b) => panic!("Cannot apply operation on arrays of different lengths: {a} vs {b}"), + }; + + let left_pylist = PySeries::from(lhs.clone()).to_pylist()?; + let right_pylist = PySeries::from(rhs.clone()).to_pylist()?; + + let result_series: Series = Python::with_gil(|py| -> PyResult { + let py_operator = + PyModule::import_bound(py, pyo3::intern!(py, "operator"))?.getattr(operator_fn)?; + + let result_pylist = PyModule::import_bound(py, pyo3::intern!(py, "daft.utils"))? + .getattr(util_fn)? + .call1((py_operator, left_pylist, right_pylist))?; + + PyModule::import_bound(py, pyo3::intern!(py, "daft.series"))? + .getattr(pyo3::intern!(py, "Series"))? + .getattr(pyo3::intern!(py, "from_pylist"))? + .call1((result_pylist, lhs.name(), pyo3::intern!(py, "disallow")))? + .getattr(pyo3::intern!(py, "_series"))? + .extract() + })? + .into(); + Ok(result_series) +} + +pub(crate) fn py_membership_op_utilfn(lhs: &Series, rhs: &Series) -> DaftResult { + use pyo3::prelude::*; + + use crate::{datatypes::DataType, python::PySeries}; + + let lhs_casted = lhs.cast(&DataType::Python)?; + let rhs_casted = rhs.cast(&DataType::Python)?; + + let left_pylist = PySeries::from(lhs_casted.clone()).to_pylist()?; + let right_pylist = PySeries::from(rhs_casted.clone()).to_pylist()?; + + let result_series: Series = Python::with_gil(|py| -> PyResult { + let result_pylist = PyModule::import_bound(py, pyo3::intern!(py, "daft.utils"))? + .getattr(pyo3::intern!(py, "python_list_membership_check"))? + .call1((left_pylist, right_pylist))?; + + PyModule::import_bound(py, pyo3::intern!(py, "daft.series"))? + .getattr(pyo3::intern!(py, "Series"))? + .getattr(pyo3::intern!(py, "from_pylist"))? + .call1(( + result_pylist, + lhs_casted.name(), + pyo3::intern!(py, "disallow"), + ))? + .getattr(pyo3::intern!(py, "_series"))? + .extract() + })? 
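// Editor's note (not part of the diff): a hand-expansion of the `cast_downcast_op!` macro
// defined above, specialized to the Float64 division used by the `Div` impl earlier in this
// diff. Illustrative sketch only; the import paths are assumptions.
use std::ops::Div;

use common_error::DaftResult;
use daft_core::prelude::*; // assumed to re-export Series, DataType, Float64Array, IntoSeries

fn div_as_f64(lhs: &Series, rhs: &Series) -> DaftResult<Series> {
    // Cast both sides to the inferred output type, downcast to the concrete array,
    // then run the typed kernel.
    let lhs = lhs.cast(&DataType::Float64)?;
    let rhs = rhs.cast(&DataType::Float64)?;
    let lhs = lhs.downcast::<Float64Array>()?;
    let rhs = rhs.downcast::<Float64Array>()?;
    Ok(lhs.div(rhs)?.into_series())
}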
+ .into(); + + Ok(result_series) +} + +pub(crate) fn py_between_op_utilfn( + value: &Series, + lower: &Series, + upper: &Series, +) -> DaftResult { + use pyo3::prelude::*; + + use crate::{datatypes::DataType, python::PySeries}; + + let value_casted = value.cast(&DataType::Python)?; + let lower_casted = lower.cast(&DataType::Python)?; + let upper_casted = upper.cast(&DataType::Python)?; + + let (value_casted, lower_casted, upper_casted) = + match (value_casted.len(), lower_casted.len(), upper_casted.len()) { + (a, b, c) if a == b && b == c => (value_casted, lower_casted, upper_casted), + (1, a, b) if a == b => (value_casted.broadcast(a)?, lower_casted, upper_casted), + (a, 1, b) if a == b => (value_casted, lower_casted.broadcast(a)?, upper_casted), + (a, b, 1) if a == b => (value_casted, lower_casted, upper_casted.broadcast(a)?), + (a, 1, 1) => ( + value_casted, + lower_casted.broadcast(a)?, + upper_casted.broadcast(a)?, + ), + (1, a, 1) => ( + value_casted.broadcast(a)?, + lower_casted, + upper_casted.broadcast(a)?, + ), + (1, 1, a) => ( + value_casted.broadcast(a)?, + lower_casted.broadcast(a)?, + upper_casted, + ), + (a, b, c) => { + panic!("Cannot apply operation on arrays of different lengths: {a} vs {b} vs {c}") + } + }; + + let value_pylist = PySeries::from(value_casted.clone()).to_pylist()?; + let lower_pylist = PySeries::from(lower_casted.clone()).to_pylist()?; + let upper_pylist = PySeries::from(upper_casted.clone()).to_pylist()?; + + let result_series: Series = Python::with_gil(|py| -> PyResult { + let result_pylist = PyModule::import_bound(py, pyo3::intern!(py, "daft.utils"))? + .getattr(pyo3::intern!(py, "python_list_between_check"))? + .call1((value_pylist, lower_pylist, upper_pylist))?; + + PyModule::import_bound(py, pyo3::intern!(py, "daft.series"))? + .getattr(pyo3::intern!(py, "Series"))? + .getattr(pyo3::intern!(py, "from_pylist"))? + .call1(( + result_pylist, + value_casted.name(), + pyo3::intern!(py, "disallow"), + ))? + .getattr(pyo3::intern!(py, "_series"))? + .extract() + })? + .into(); + + Ok(result_series) +} diff --git a/src/daft-core/src/utils/mod.rs b/src/daft-core/src/utils/mod.rs index b270516ebd..2e039e6953 100644 --- a/src/daft-core/src/utils/mod.rs +++ b/src/daft-core/src/utils/mod.rs @@ -3,15 +3,3 @@ pub mod display; pub mod dyn_compare; pub mod identity_hash_set; pub mod supertype; - -#[macro_export] -macro_rules! 
impl_binary_trait_by_reference { - ($ty:ty, $trait:ident, $fname:ident) => { - impl $trait for $ty { - type Output = DaftResult<$ty>; - fn $fname(self, other: Self) -> Self::Output { - (&self).$fname(&other) - } - } - }; -} diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs index 378611851a..adbb2830e7 100644 --- a/src/daft-dsl/src/functions/python/mod.rs +++ b/src/daft-dsl/src/functions/python/mod.rs @@ -9,6 +9,8 @@ use common_resource_request::ResourceRequest; use common_treenode::{TreeNode, TreeNodeRecursion}; use daft_core::datatypes::DataType; use itertools::Itertools; +#[cfg(feature = "python")] +use pyo3::{Py, PyAny}; pub use runtime_py_object::RuntimePyObject; use serde::{Deserialize, Serialize}; pub use udf_runtime_binding::UDFRuntimeBinding; @@ -180,7 +182,7 @@ pub fn get_concurrency(exprs: &[ExprRef]) -> usize { #[cfg(feature = "python")] pub fn bind_stateful_udfs( expr: ExprRef, - initialized_funcs: &HashMap>, + initialized_funcs: &HashMap>, ) -> DaftResult { expr.transform(|e| match e.as_ref() { Expr::Function { @@ -213,7 +215,9 @@ pub fn bind_stateful_udfs( /// Helper function that extracts all PartialStatefulUDF python objects from a given expression tree #[cfg(feature = "python")] -pub fn extract_partial_stateful_udf_py(expr: ExprRef) -> HashMap> { +pub fn extract_partial_stateful_udf_py( + expr: ExprRef, +) -> HashMap, Option>)> { let mut py_partial_udfs = HashMap::new(); expr.apply(|child| { if let Expr::Function { @@ -221,12 +225,19 @@ pub fn extract_partial_stateful_udf_py(expr: ExprRef) -> HashMap), } impl Eq for LiteralValue {} @@ -112,6 +115,12 @@ impl Hash for LiteralValue { } #[cfg(feature = "python")] Python(py_obj) => py_obj.hash(state), + Struct(entries) => { + entries.iter().for_each(|(v, f)| { + v.hash(state); + f.hash(state); + }); + } } } } @@ -143,6 +152,16 @@ impl Display for LiteralValue { Python::with_gil(|py| pyobj.0.call_method0(py, pyo3::intern!(py, "__str__"))) .unwrap() }), + Struct(entries) => { + write!(f, "Struct(")?; + for (i, (field, v)) in entries.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}: {}", field.name, v)?; + } + write!(f, ")") + } } } } @@ -169,6 +188,7 @@ impl LiteralValue { Series(series) => series.data_type().clone(), #[cfg(feature = "python")] Python(_) => DataType::Python, + Struct(entries) => DataType::Struct(entries.keys().cloned().collect()), } } @@ -203,6 +223,13 @@ impl LiteralValue { Series(series) => series.clone().rename("literal"), #[cfg(feature = "python")] Python(val) => PythonArray::from(("literal", vec![val.0.clone()])).into_series(), + Struct(entries) => { + let struct_dtype = DataType::Struct(entries.keys().cloned().collect()); + let struct_field = Field::new("literal", struct_dtype); + + let values = entries.values().map(|v| v.to_series()).collect(); + StructArray::new(struct_field, values, None).into_series() + } }; result } @@ -235,6 +262,7 @@ impl LiteralValue { Decimal(..) | Series(..) | Time(..) | Binary(..) => display_sql_err, #[cfg(feature = "python")] Python(..) => display_sql_err, + Struct(..) => display_sql_err, } } @@ -304,49 +332,64 @@ impl LiteralValue { } } -pub trait Literal { +pub trait Literal: Sized { /// [Literal](Expr::Literal) expression. 
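// Editor's note (not part of the diff): the `LiteralValue::Struct` variant above and the
// `Literal` trait refactor that continues just below (a blanket `lit()` on top of
// `literal_value()`, plus `impl Literal for Option<T>`) let optional values fold directly
// into expressions. Sketch only; it assumes the usual `make_literal!` impl for i64 exists.
use daft_dsl::{lit, ExprRef};

fn optional_limit(limit: Option<i64>) -> ExprRef {
    // Some(10) becomes an Int64 literal; None becomes LiteralValue::Null, the same value
    // that `null_lit()` produces.
    lit(limit)
}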
- fn lit(self) -> ExprRef; + fn lit(self) -> ExprRef { + Expr::Literal(self.literal_value()).into() + } + fn literal_value(self) -> LiteralValue; } impl Literal for String { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::Utf8(self)).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::Utf8(self) } } impl<'a> Literal for &'a str { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::Utf8(self.to_owned())).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::Utf8(self.to_owned()) } } macro_rules! make_literal { ($TYPE:ty, $SCALAR:ident) => { impl Literal for $TYPE { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::$SCALAR(self)).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::$SCALAR(self) } } }; } impl<'a> Literal for &'a [u8] { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::Binary(self.to_vec())).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::Binary(self.to_vec()) } } impl Literal for Series { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::Series(self)).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::Series(self) } } #[cfg(feature = "python")] impl Literal for pyo3::PyObject { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::Python(PyObjectWrapper(self))).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::Python(PyObjectWrapper(self)) + } +} + +impl Literal for Option +where + T: Literal, +{ + fn literal_value(self) -> LiteralValue { + match self { + Some(val) => val.literal_value(), + None => LiteralValue::Null, + } } } @@ -361,6 +404,10 @@ pub fn lit(t: L) -> ExprRef { t.lit() } +pub fn literal_value(t: L) -> LiteralValue { + t.literal_value() +} + pub fn null_lit() -> ExprRef { Arc::new(Expr::Literal(LiteralValue::Null)) } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index a4a54e74c9..d67e522ec0 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -222,7 +222,9 @@ pub fn stateful_udf( /// Extracts the `class PartialStatefulUDF` Python objects that are in the specified expression tree #[pyfunction] -pub fn extract_partial_stateful_udf_py(expr: PyExpr) -> HashMap> { +pub fn extract_partial_stateful_udf_py( + expr: PyExpr, +) -> HashMap, Option>)> { use crate::functions::python::extract_partial_stateful_udf_py; extract_partial_stateful_udf_py(expr.expr) } diff --git a/src/daft-functions/src/count_matches.rs b/src/daft-functions/src/count_matches.rs index 89df9274a9..a5b5596681 100644 --- a/src/daft-functions/src/count_matches.rs +++ b/src/daft-functions/src/count_matches.rs @@ -7,9 +7,9 @@ use daft_dsl::{ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -struct CountMatchesFunction { - pub(super) whole_words: bool, - pub(super) case_sensitive: bool, +pub struct CountMatchesFunction { + pub whole_words: bool, + pub case_sensitive: bool, } #[typetag::serde] diff --git a/src/daft-functions/src/minhash.rs b/src/daft-functions/src/minhash.rs index 48d13e0a65..6c000c4a1a 100644 --- a/src/daft-functions/src/minhash.rs +++ b/src/daft-functions/src/minhash.rs @@ -7,10 +7,10 @@ use daft_dsl::{ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub(super) struct MinHashFunction { - num_hashes: usize, - ngram_size: usize, - seed: u32, +pub struct MinHashFunction { + pub num_hashes: usize, + pub ngram_size: usize, + pub seed: u32, } #[typetag::serde] diff --git 
a/src/daft-functions/src/tokenize/decode.rs b/src/daft-functions/src/tokenize/decode.rs index 30a713f993..e486f274e8 100644 --- a/src/daft-functions/src/tokenize/decode.rs +++ b/src/daft-functions/src/tokenize/decode.rs @@ -66,11 +66,11 @@ fn tokenize_decode_series( } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub(super) struct TokenizeDecodeFunction { - pub(super) tokens_path: String, - pub(super) io_config: Option>, - pub(super) pattern: Option, - pub(super) special_tokens: Option, +pub struct TokenizeDecodeFunction { + pub tokens_path: String, + pub io_config: Option>, + pub pattern: Option, + pub special_tokens: Option, } #[typetag::serde] diff --git a/src/daft-functions/src/tokenize/encode.rs b/src/daft-functions/src/tokenize/encode.rs index e36f9be4d2..a101cf930f 100644 --- a/src/daft-functions/src/tokenize/encode.rs +++ b/src/daft-functions/src/tokenize/encode.rs @@ -70,12 +70,12 @@ fn tokenize_encode_series( } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub(super) struct TokenizeEncodeFunction { - pub(super) tokens_path: String, - pub(super) io_config: Option>, - pub(super) pattern: Option, - pub(super) special_tokens: Option, - pub(super) use_special_tokens: bool, +pub struct TokenizeEncodeFunction { + pub tokens_path: String, + pub io_config: Option>, + pub pattern: Option, + pub special_tokens: Option, + pub use_special_tokens: bool, } #[typetag::serde] diff --git a/src/daft-functions/src/tokenize/mod.rs b/src/daft-functions/src/tokenize/mod.rs index 8eb1aee7e1..564ca79226 100644 --- a/src/daft-functions/src/tokenize/mod.rs +++ b/src/daft-functions/src/tokenize/mod.rs @@ -1,7 +1,7 @@ use daft_dsl::{functions::ScalarFunction, ExprRef}; use daft_io::IOConfig; -use decode::TokenizeDecodeFunction; -use encode::TokenizeEncodeFunction; +pub use decode::TokenizeDecodeFunction; +pub use encode::TokenizeEncodeFunction; mod bpe; mod decode; diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index 1356689a08..2968c2b04b 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -63,6 +63,9 @@ pub enum Error { OneShotRecvError { source: tokio::sync::oneshot::error::RecvError, }, + #[cfg(feature = "python")] + #[snafu(display("PyIOError: {}", source))] + PyIO { source: PyErr }, #[snafu(display("Error creating pipeline from {}: {}", plan_name, source))] PipelineCreationError { source: DaftError, diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index 7d36ba6a22..5b9f95d96e 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -137,162 +137,140 @@ async fn stream_scan_task( } let source = scan_task.sources.first().unwrap(); let url = source.get_path(); - let table_stream = match scan_task.storage_config.as_ref() { - StorageConfig::Native(native_storage_config) => { - let io_config = Arc::new( - native_storage_config - .io_config - .as_ref() - .cloned() - .unwrap_or_default(), - ); - let multi_threaded_io = native_storage_config.multithreaded_io; - let io_client = daft_io::get_io_client(multi_threaded_io, io_config)?; - - match scan_task.file_format_config.as_ref() { - // ******************** - // Native Parquet Reads - // ******************** - FileFormatConfig::Parquet(ParquetSourceConfig { - coerce_int96_timestamp_unit, - field_id_mapping, - .. 
- }) => { - let inference_options = - ParquetSchemaInferenceOptions::new(Some(*coerce_int96_timestamp_unit)); - - if source.get_iceberg_delete_files().is_some() { - return Err(common_error::DaftError::TypeError( - "Streaming reads not supported for Iceberg delete files".to_string(), - )); - } + let (io_config, multi_threaded_io) = match scan_task.storage_config.as_ref() { + StorageConfig::Native(native_storage_config) => ( + native_storage_config.io_config.as_ref(), + native_storage_config.multithreaded_io, + ), - let row_groups = - if let Some(ChunkSpec::Parquet(row_groups)) = source.get_chunk_spec() { - Some(row_groups.clone()) - } else { - None - }; - let metadata = scan_task - .sources - .first() - .and_then(|s| s.get_parquet_metadata().cloned()); - daft_parquet::read::stream_parquet( - url, - file_column_names.as_deref(), - None, - scan_task.pushdowns.limit, - row_groups, - scan_task.pushdowns.filters.clone(), - io_client.clone(), - io_stats, - &inference_options, - field_id_mapping.clone(), - metadata, - maintain_order, - ) - .await? - } + #[cfg(feature = "python")] + StorageConfig::Python(python_storage_config) => { + (python_storage_config.io_config.as_ref(), true) + } + }; + let io_config = Arc::new(io_config.cloned().unwrap_or_default()); + let io_client = daft_io::get_io_client(multi_threaded_io, io_config)?; + let table_stream = match scan_task.file_format_config.as_ref() { + FileFormatConfig::Parquet(ParquetSourceConfig { + coerce_int96_timestamp_unit, + field_id_mapping, + .. + }) => { + let inference_options = + ParquetSchemaInferenceOptions::new(Some(*coerce_int96_timestamp_unit)); - // **************** - // Native CSV Reads - // **************** - FileFormatConfig::Csv(cfg) => { - let schema_of_file = scan_task.schema.clone(); - let col_names = if !cfg.has_headers { - Some( - schema_of_file - .fields - .values() - .map(|f| f.name.as_str()) - .collect::>(), - ) - } else { - None - }; - let convert_options = CsvConvertOptions::new_internal( - scan_task.pushdowns.limit, - file_column_names - .as_ref() - .map(|cols| cols.iter().map(|col| col.to_string()).collect()), - col_names - .as_ref() - .map(|cols| cols.iter().map(|col| col.to_string()).collect()), - Some(schema_of_file), - scan_task.pushdowns.filters.clone(), - ); - let parse_options = CsvParseOptions::new_with_defaults( - cfg.has_headers, - cfg.delimiter, - cfg.double_quote, - cfg.quote, - cfg.allow_variable_columns, - cfg.escape_char, - cfg.comment, - )?; - let read_options = - CsvReadOptions::new_internal(cfg.buffer_size, cfg.chunk_size); - daft_csv::stream_csv( - url.to_string(), - Some(convert_options), - Some(parse_options), - Some(read_options), - io_client.clone(), - io_stats.clone(), - None, - // maintain_order, TODO: Implement maintain_order for CSV - ) - .await? 
- } + if source.get_iceberg_delete_files().is_some() { + return Err(common_error::DaftError::TypeError( + "Streaming reads not supported for Iceberg delete files".to_string(), + )); + } - // **************** - // Native JSON Reads - // **************** - FileFormatConfig::Json(cfg) => { - let schema_of_file = scan_task.schema.clone(); - let convert_options = JsonConvertOptions::new_internal( - scan_task.pushdowns.limit, - file_column_names - .as_ref() - .map(|cols| cols.iter().map(|col| col.to_string()).collect()), - Some(schema_of_file), - scan_task.pushdowns.filters.clone(), - ); - // let - let parse_options = JsonParseOptions::new_internal(); - let read_options = - JsonReadOptions::new_internal(cfg.buffer_size, cfg.chunk_size); + let row_groups = if let Some(ChunkSpec::Parquet(row_groups)) = source.get_chunk_spec() { + Some(row_groups.clone()) + } else { + None + }; + let metadata = scan_task + .sources + .first() + .and_then(|s| s.get_parquet_metadata().cloned()); + daft_parquet::read::stream_parquet( + url, + file_column_names.as_deref(), + None, + scan_task.pushdowns.limit, + row_groups, + scan_task.pushdowns.filters.clone(), + io_client, + io_stats, + &inference_options, + field_id_mapping.clone(), + metadata, + maintain_order, + ) + .await? + } + FileFormatConfig::Csv(cfg) => { + let schema_of_file = scan_task.schema.clone(); + let col_names = if !cfg.has_headers { + Some( + schema_of_file + .fields + .values() + .map(|f| f.name.as_str()) + .collect::>(), + ) + } else { + None + }; + let convert_options = CsvConvertOptions::new_internal( + scan_task.pushdowns.limit, + file_column_names + .as_ref() + .map(|cols| cols.iter().map(|col| col.to_string()).collect()), + col_names + .as_ref() + .map(|cols| cols.iter().map(|col| col.to_string()).collect()), + Some(schema_of_file), + scan_task.pushdowns.filters.clone(), + ); + let parse_options = CsvParseOptions::new_with_defaults( + cfg.has_headers, + cfg.delimiter, + cfg.double_quote, + cfg.quote, + cfg.allow_variable_columns, + cfg.escape_char, + cfg.comment, + )?; + let read_options = CsvReadOptions::new_internal(cfg.buffer_size, cfg.chunk_size); + daft_csv::stream_csv( + url.to_string(), + Some(convert_options), + Some(parse_options), + Some(read_options), + io_client, + io_stats.clone(), + None, + // maintain_order, TODO: Implement maintain_order for CSV + ) + .await? + } + FileFormatConfig::Json(cfg) => { + let schema_of_file = scan_task.schema.clone(); + let convert_options = JsonConvertOptions::new_internal( + scan_task.pushdowns.limit, + file_column_names + .as_ref() + .map(|cols| cols.iter().map(|col| col.to_string()).collect()), + Some(schema_of_file), + scan_task.pushdowns.filters.clone(), + ); + let parse_options = JsonParseOptions::new_internal(); + let read_options = JsonReadOptions::new_internal(cfg.buffer_size, cfg.chunk_size); - daft_json::read::stream_json( - url.to_string(), - Some(convert_options), - Some(parse_options), - Some(read_options), - io_client, - io_stats, - None, - // maintain_order, TODO: Implement maintain_order for JSON - ) - .await? 
- } - #[cfg(feature = "python")] - FileFormatConfig::Database(_) => { - return Err(common_error::DaftError::TypeError( - "Native reads for Database file format not implemented".to_string(), - )); - } - #[cfg(feature = "python")] - FileFormatConfig::PythonFunction => { - return Err(common_error::DaftError::TypeError( - "Native reads for PythonFunction file format not implemented".to_string(), - )); - } - } + daft_json::read::stream_json( + url.to_string(), + Some(convert_options), + Some(parse_options), + Some(read_options), + io_client, + io_stats, + None, + // maintain_order, TODO: Implement maintain_order for JSON + ) + .await? + } + #[cfg(feature = "python")] + FileFormatConfig::Database(_) => { + return Err(common_error::DaftError::TypeError( + "Database file format not implemented".to_string(), + )); } #[cfg(feature = "python")] - StorageConfig::Python(_) => { + FileFormatConfig::PythonFunction => { return Err(common_error::DaftError::TypeError( - "Streaming reads not supported for Python storage config".to_string(), + "PythonFunction file format not implemented".to_string(), )); } }; diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 2060b3de30..810cadcba2 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -749,7 +749,7 @@ impl PyMicroPartition { } } -pub(crate) fn read_json_into_py_table( +pub fn read_json_into_py_table( py: Python, uri: &str, schema: PySchema, @@ -776,7 +776,7 @@ pub(crate) fn read_json_into_py_table( } #[allow(clippy::too_many_arguments)] -pub(crate) fn read_csv_into_py_table( +pub fn read_csv_into_py_table( py: Python, uri: &str, has_header: bool, @@ -810,7 +810,7 @@ pub(crate) fn read_csv_into_py_table( .extract() } -pub(crate) fn read_parquet_into_py_table( +pub fn read_parquet_into_py_table( py: Python, uri: &str, schema: PySchema, diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs index 2bcf080f0b..a3b36d4a34 100644 --- a/src/daft-parquet/src/file.rs +++ b/src/daft-parquet/src/file.rs @@ -502,17 +502,29 @@ impl ParquetFileReader { .into_iter() .collect::>>()?; - let table_iter = arrow_column_iters_to_table_iter( - arr_iters, - row_range.start, - daft_schema, - uri, - predicate, - original_columns, - original_num_rows, - ); rayon::spawn(move || { - for table_result in table_iter { + // Even if there are no columns to read, we still need to create a empty table with the correct number of rows + // This is because the columns may be present in other files. 
See https://github.com/Eventual-Inc/Daft/pull/2514 + let table_iter = arrow_column_iters_to_table_iter( + arr_iters, + row_range.start, + daft_schema.clone(), + uri, + predicate, + original_columns, + original_num_rows, + ); + if table_iter.is_none() { + let table = + Table::new_with_size(daft_schema, vec![], row_range.num_rows); + if let Err(crossbeam_channel::TrySendError::Full(_)) = + sender.try_send(table) + { + panic!("Parquet stream channel should not be full") + } + return; + } + for table_result in table_iter.unwrap() { let is_err = table_result.is_err(); if let Err(crossbeam_channel::TrySendError::Full(_)) = sender.try_send(table_result) diff --git a/src/daft-parquet/src/stream_reader.rs b/src/daft-parquet/src/stream_reader.rs index 178141c64d..1e8c3f9d27 100644 --- a/src/daft-parquet/src/stream_reader.rs +++ b/src/daft-parquet/src/stream_reader.rs @@ -56,7 +56,10 @@ pub(crate) fn arrow_column_iters_to_table_iter( predicate: Option, original_columns: Option>, original_num_rows: Option, -) -> impl Iterator> { +) -> Option>> { + if arr_iters.is_empty() { + return None; + } pub struct ParallelLockStepIter { pub iters: ArrowChunkIters, } @@ -73,7 +76,7 @@ pub(crate) fn arrow_column_iters_to_table_iter( // and slice arrays that are partially needed. let mut index_so_far = 0; let owned_schema_ref = schema_ref.clone(); - par_lock_step_iter.into_iter().map(move |chunk| { + let table_iter = par_lock_step_iter.into_iter().map(move |chunk| { let chunk = chunk.with_context(|_| { super::UnableToCreateChunkFromStreamingFileReaderSnafu { path: uri.clone() } })?; @@ -96,7 +99,10 @@ pub(crate) fn arrow_column_iters_to_table_iter( }) .collect::>>()?; - let len = all_series[0].len(); + let len = all_series + .first() + .map(|s| s.len()) + .expect("All series should not be empty when creating table from parquet chunks"); if all_series.iter().any(|s| s.len() != len) { return Err(super::Error::ParquetColumnsDontHaveEqualRows { path: uri.clone() }.into()); } @@ -115,7 +121,8 @@ pub(crate) fn arrow_column_iters_to_table_iter( } } Ok(table) - }) + }); + Some(table_iter) } struct CountingReader { @@ -524,36 +531,41 @@ pub(crate) fn local_parquet_stream( .unzip(); let owned_uri = uri.to_string(); - let table_iters = - column_iters - .into_iter() - .zip(row_ranges) - .map(move |(rg_col_iter_result, rg_range)| { - let rg_col_iter = rg_col_iter_result?; - let table_iter = arrow_column_iters_to_table_iter( - rg_col_iter, - rg_range.start, - schema_ref.clone(), - owned_uri.clone(), - predicate.clone(), - original_columns.clone(), - original_num_rows, - ); - DaftResult::Ok(table_iter) - }); rayon::spawn(move || { // Once a row group has been read into memory and we have the column iterators, // we can start processing them in parallel. - let par_table_iters = table_iters.zip(senders).par_bridge(); + let par_column_iters = column_iters.zip(row_ranges).zip(senders).par_bridge(); // For each vec of column iters, iterate through them in parallel lock step such that each iteration // produces a chunk of the row group that can be converted into a table. 
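// Editor's note (not part of the diff): `arrow_column_iters_to_table_iter` now returns an
// Option, and both call sites fall back to an empty Table that still reports the row group's
// row count. Condensed shape of that caller-side pattern (names mirror the diff; channel
// error handling elided):
//
//     match arrow_column_iters_to_table_iter(/* ... */) {
//         Some(table_iter) => { /* stream each chunk through the channel as before */ }
//         None => {
//             // No columns were materialized from this file, but downstream concatenation
//             // still needs the correct number of rows (see the PR referenced above).
//             let table = Table::new_with_size(schema_ref.clone(), vec![], rg_range.num_rows);
//             let _ = tx.try_send(table);
//         }
//     }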
- par_table_iters.for_each(move |(table_iter_result, tx)| { - let table_iter = match table_iter_result { - Ok(t) => t, + par_column_iters.for_each(move |((rg_column_iters_result, rg_range), tx)| { + let table_iter = match rg_column_iters_result { + Ok(rg_column_iters) => { + let table_iter = arrow_column_iters_to_table_iter( + rg_column_iters, + rg_range.start, + schema_ref.clone(), + owned_uri.clone(), + predicate.clone(), + original_columns.clone(), + original_num_rows, + ); + // Even if there are no columns to read, we still need to create a empty table with the correct number of rows + // This is because the columns may be present in other files. See https://github.com/Eventual-Inc/Daft/pull/2514 + if let Some(table_iter) = table_iter { + table_iter + } else { + let table = + Table::new_with_size(schema_ref.clone(), vec![], rg_range.num_rows); + if let Err(crossbeam_channel::TrySendError::Full(_)) = tx.try_send(table) { + panic!("Parquet stream channel should not be full") + } + return; + } + } Err(e) => { - let _ = tx.send(Err(e)); + let _ = tx.send(Err(e.into())); return; } }; diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index 7de191fe81..13d2c4307f 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -34,6 +34,7 @@ log = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true, features = ["rc"]} snafu = {workspace = true} +uuid = {version = "1", features = ["v4"]} [dev-dependencies] daft-dsl = {path = "../daft-dsl", features = ["test-utils"]} diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 982a3634a9..a9e05ec6cb 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -1,17 +1,27 @@ use std::{ - collections::{HashMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, sync::Arc, }; use common_daft_config::DaftPlanningConfig; use common_display::mermaid::MermaidDisplayOptions; use common_error::DaftResult; -use common_file_formats::FileFormat; +use common_file_formats::{FileFormat, FileFormatConfig, ParquetSourceConfig}; use common_io_config::IOConfig; -use daft_core::join::{JoinStrategy, JoinType}; +use daft_core::{ + join::{JoinStrategy, JoinType}, + prelude::TimeUnit, +}; use daft_dsl::{col, ExprRef}; -use daft_scan::{PhysicalScanInfo, Pushdowns, ScanOperatorRef}; -use daft_schema::schema::{Schema, SchemaRef}; +use daft_scan::{ + glob::GlobScanOperator, + storage_config::{NativeStorageConfig, StorageConfig}, + PhysicalScanInfo, Pushdowns, ScanOperatorRef, +}; +use daft_schema::{ + field::Field, + schema::{Schema, SchemaRef}, +}; #[cfg(feature = "python")] use { crate::sink_info::{CatalogInfo, IcebergCatalogInfo}, @@ -73,7 +83,29 @@ impl From<&LogicalPlanBuilder> for LogicalPlanRef { value.plan.clone() } } - +pub trait IntoGlobPath { + fn into_glob_path(self) -> Vec; +} +impl IntoGlobPath for Vec { + fn into_glob_path(self) -> Vec { + self + } +} +impl IntoGlobPath for String { + fn into_glob_path(self) -> Vec { + vec![self] + } +} +impl IntoGlobPath for &str { + fn into_glob_path(self) -> Vec { + vec![self.to_string()] + } +} +impl IntoGlobPath for Vec<&str> { + fn into_glob_path(self) -> Vec { + self.iter().map(|s| s.to_string()).collect() + } +} impl LogicalPlanBuilder { /// Replace the LogicalPlanBuilder's plan with the provided plan pub fn with_new_plan>>(&self, plan: LP) -> Self { @@ -105,9 +137,51 @@ impl LogicalPlanBuilder { )); let logical_plan: LogicalPlan = logical_ops::Source::new(schema.clone(), source_info.into()).into(); + 
Ok(Self::new(logical_plan.into(), None)) } + #[cfg(feature = "python")] + pub fn delta_scan>( + glob_path: T, + io_config: Option, + multithreaded_io: bool, + ) -> DaftResult { + use daft_scan::storage_config::PyStorageConfig; + + Python::with_gil(|py| { + let io_config = io_config.unwrap_or_default(); + + let native_storage_config = NativeStorageConfig { + io_config: Some(io_config), + multithreaded_io, + }; + + let py_storage_config: PyStorageConfig = + Arc::new(StorageConfig::Native(Arc::new(native_storage_config))).into(); + + // let py_io_config = PyIOConfig { config: io_config }; + let delta_lake_scan = PyModule::import_bound(py, "daft.delta_lake.delta_lake_scan")?; + let delta_lake_scan_operator = + delta_lake_scan.getattr(pyo3::intern!(py, "DeltaLakeScanOperator"))?; + let delta_lake_operator = delta_lake_scan_operator + .call1((glob_path.as_ref(), py_storage_config))? + .to_object(py); + let scan_operator_handle = + ScanOperatorHandle::from_python_scan_operator(delta_lake_operator, py)?; + Self::table_scan(scan_operator_handle.into(), None) + }) + } + + #[cfg(not(feature = "python"))] + pub fn delta_scan( + glob_path: T, + io_config: Option, + multithreaded_io: bool, + ) -> DaftResult { + panic!("Delta Lake scan requires the 'python' feature to be enabled.") + } + pub fn table_scan( scan_operator: ScanOperatorRef, pushdowns: Option, @@ -142,6 +216,10 @@ impl LogicalPlanBuilder { Ok(Self::new(logical_plan.into(), None)) } + pub fn parquet_scan(glob_path: T) -> ParquetScanBuilder { + ParquetScanBuilder::new(glob_path) + } + pub fn select(&self, to_select: Vec) -> DaftResult { let logical_plan: LogicalPlan = logical_ops::Project::try_new(self.plan.clone(), to_select)?.into(); @@ -498,6 +576,95 @@ impl LogicalPlanBuilder { } } +pub struct ParquetScanBuilder { + pub glob_paths: Vec, + pub infer_schema: bool, + pub coerce_int96_timestamp_unit: TimeUnit, + pub field_id_mapping: Option>>, + pub row_groups: Option>>>, + pub chunk_size: Option, + pub io_config: Option, + pub multithreaded: bool, + pub schema: Option, +} + +impl ParquetScanBuilder { + pub fn new(glob_paths: T) -> Self { + let glob_paths = glob_paths.into_glob_path(); + Self::new_impl(glob_paths) + } + + // concrete implementation to reduce LLVM code duplication + fn new_impl(glob_paths: Vec) -> Self { + Self { + glob_paths, + infer_schema: true, + coerce_int96_timestamp_unit: TimeUnit::Nanoseconds, + field_id_mapping: None, + row_groups: None, + chunk_size: None, + multithreaded: true, + schema: None, + io_config: None, + } + } + pub fn infer_schema(mut self, infer_schema: bool) -> Self { + self.infer_schema = infer_schema; + self + } + pub fn coerce_int96_timestamp_unit(mut self, unit: TimeUnit) -> Self { + self.coerce_int96_timestamp_unit = unit; + self + } + pub fn field_id_mapping(mut self, field_id_mapping: Arc>) -> Self { + self.field_id_mapping = Some(field_id_mapping); + self + } + pub fn row_groups(mut self, row_groups: Vec>>) -> Self { + self.row_groups = Some(row_groups); + self + } + pub fn chunk_size(mut self, chunk_size: usize) -> Self { + self.chunk_size = Some(chunk_size); + self + } + + pub fn io_config(mut self, io_config: IOConfig) -> Self { + self.io_config = Some(io_config); + self + } + + pub fn multithreaded(mut self, multithreaded: bool) -> Self { + self.multithreaded = multithreaded; + self + } + pub fn schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + pub fn finish(self) -> DaftResult { + let cfg = ParquetSourceConfig { + coerce_int96_timestamp_unit: 
self.coerce_int96_timestamp_unit, + field_id_mapping: self.field_id_mapping, + row_groups: self.row_groups, + chunk_size: self.chunk_size, + }; + + let operator = Arc::new(GlobScanOperator::try_new( + self.glob_paths, + Arc::new(FileFormatConfig::Parquet(cfg)), + Arc::new(StorageConfig::Native(Arc::new( + NativeStorageConfig::new_internal(self.multithreaded, self.io_config), + ))), + self.infer_schema, + self.schema, + )?); + + LogicalPlanBuilder::table_scan(ScanOperatorRef(operator), None) + } +} + /// A Python-facing wrapper of the LogicalPlanBuilder. /// /// This lightweight proxy interface should hold as much of the Python-specific logic diff --git a/src/daft-plan/src/lib.rs b/src/daft-plan/src/lib.rs index 50e309916b..2541a143db 100644 --- a/src/daft-plan/src/lib.rs +++ b/src/daft-plan/src/lib.rs @@ -19,7 +19,7 @@ pub mod source_info; mod test; mod treenode; -pub use builder::{LogicalPlanBuilder, PyLogicalPlanBuilder}; +pub use builder::{LogicalPlanBuilder, ParquetScanBuilder, PyLogicalPlanBuilder}; pub use daft_core::join::{JoinStrategy, JoinType}; pub use logical_plan::{LogicalPlan, LogicalPlanRef}; pub use partitioning::ClusteringSpec; diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs index 2a68390066..8e6d0b005e 100644 --- a/src/daft-plan/src/logical_ops/join.rs +++ b/src/daft-plan/src/logical_ops/join.rs @@ -3,7 +3,7 @@ use std::{ sync::Arc, }; -use common_error::DaftError; +use common_error::{DaftError, DaftResult}; use daft_core::prelude::*; use daft_dsl::{ col, @@ -13,6 +13,7 @@ use daft_dsl::{ }; use itertools::Itertools; use snafu::ResultExt; +use uuid::Uuid; use crate::{ logical_ops::Project, @@ -54,14 +55,31 @@ impl Join { join_type: JoinType, join_strategy: Option, ) -> logical_plan::Result { - let (left_on, left_fields) = - resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?; - let (right_on, right_fields) = + let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?; + let (right_on, _) = resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?; - for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] { - let on_schema = Schema::new(on_fields).context(CreationSnafu)?; - for (field, expr) in on_schema.fields.values().zip(on_exprs.iter()) { + let (unique_left_on, unique_right_on) = + Self::rename_join_keys(left_on.clone(), right_on.clone()); + + let left_fields: Vec = unique_left_on + .iter() + .map(|e| e.to_field(&left.schema())) + .collect::>>() + .context(CreationSnafu)?; + + let right_fields: Vec = unique_right_on + .iter() + .map(|e| e.to_field(&right.schema())) + .collect::>>() + .context(CreationSnafu)?; + + for (on_exprs, on_fields) in [ + (&unique_left_on, &left_fields), + (&unique_right_on, &right_fields), + ] { + for (field, expr) in on_fields.iter().zip(on_exprs.iter()) { + // Null type check for both fields and expressions if matches!(field.dtype, DataType::Null) { return Err(DaftError::ValueError(format!( "Can't join on null type expressions: {expr}" @@ -167,6 +185,60 @@ impl Join { } } + /// Renames join keys for the given left and right expressions. This is required to + /// prevent errors when the join keys on the left and right expressions have the same key + /// name. + /// + /// This function takes two vectors of expressions (`left_exprs` and `right_exprs`) and + /// checks for pairs of column expressions that differ. 
If both expressions in a pair + /// are column expressions and they are not identical, it generates a unique identifier + /// and renames both expressions by appending this identifier to their original names. + /// + /// The function returns two vectors of expressions, where the renamed expressions are + /// substituted for the original expressions in the cases where renaming occurred. + /// + /// # Parameters + /// - `left_exprs`: A vector of expressions from the left side of a join. + /// - `right_exprs`: A vector of expressions from the right side of a join. + /// + /// # Returns + /// A tuple containing two vectors of expressions, one for the left side and one for the + /// right side, where expressions that needed to be renamed have been modified. + /// + /// # Example + /// ``` + /// let (renamed_left, renamed_right) = rename_join_keys(left_expressions, right_expressions); + /// ``` + /// + /// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649). + + fn rename_join_keys( + left_exprs: Vec>, + right_exprs: Vec>, + ) -> (Vec>, Vec>) { + left_exprs + .into_iter() + .zip(right_exprs) + .map( + |(left_expr, right_expr)| match (&*left_expr, &*right_expr) { + (Expr::Column(left_name), Expr::Column(right_name)) + if left_name == right_name => + { + (left_expr, right_expr) + } + _ => { + let unique_id = Uuid::new_v4().to_string(); + let renamed_left_expr = + left_expr.alias(format!("{}_{}", left_expr.name(), unique_id)); + let renamed_right_expr = + right_expr.alias(format!("{}_{}", right_expr.name(), unique_id)); + (renamed_left_expr, renamed_right_expr) + } + }, + ) + .unzip() + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("Join: Type = {}", self.join_type)); diff --git a/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs b/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs index 1d77502221..a44674db55 100644 --- a/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs +++ b/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs @@ -122,7 +122,7 @@ impl SplitActorPoolProjects { impl OptimizerRule for SplitActorPoolProjects { fn try_optimize(&self, plan: Arc) -> DaftResult>> { plan.transform_down(|node| match node.as_ref() { - LogicalPlan::Project(projection) => try_optimize_project(projection, node.clone(), 0), + LogicalPlan::Project(projection) => try_optimize_project(projection, node.clone()), _ => Ok(Transformed::no(node)), }) } @@ -370,8 +370,34 @@ fn split_projection( fn try_optimize_project( projection: &Project, plan: Arc, +) -> DaftResult>> { + // Add aliases to the expressions in the projection to preserve original names when splitting stateful UDFs. + // This is needed because when we split stateful UDFs, we create new names for intermediates, but we would like + // to have the same expression names as the original projection. + let aliased_projection_exprs = projection + .projection + .iter() + .map(|e| { + if has_stateful_udf(e) && !matches!(e.as_ref(), Expr::Alias(..)) { + e.alias(e.name()) + } else { + e.clone() + } + }) + .collect(); + + let aliased_projection = Project::try_new(projection.input.clone(), aliased_projection_exprs)?; + + recursive_optimize_project(&aliased_projection, plan, 0) +} + +fn recursive_optimize_project( + projection: &Project, + plan: Arc, recursive_count: usize, ) -> DaftResult>> { + // TODO: eliminate the need for recursive calls by doing a post-order traversal of the plan tree. 
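Stepping back to the `Join::rename_join_keys` helper added above: below is a standalone sketch of the same renaming scheme over plain strings, showing the behaviour the docstring describes (identical column/column pairs are left alone, everything else gets a shared per-pair UUID suffix). In the real helper the check is on column expressions rather than strings, and any non-column join expression is also renamed.

```rust
// Standalone sketch of the key-renaming scheme in `Join::rename_join_keys`,
// using plain strings instead of Daft expressions.
//
// Assumed dependency: uuid = { version = "1", features = ["v4"] }.
use uuid::Uuid;

fn rename_join_keys(left: Vec<String>, right: Vec<String>) -> (Vec<String>, Vec<String>) {
    left.into_iter()
        .zip(right)
        .map(|(l, r)| {
            if l == r {
                // Same column name on both sides: no rename needed.
                (l, r)
            } else {
                // Distinct keys: suffix both sides with the same fresh UUID so
                // the left and right key names can no longer collide.
                let id = Uuid::new_v4().to_string();
                (format!("{l}_{id}"), format!("{r}_{id}"))
            }
        })
        .unzip()
}

fn main() {
    let (l, r) = rename_join_keys(
        vec!["id".to_string(), "a".to_string()],
        vec!["id".to_string(), "b".to_string()],
    );
    assert_eq!(l[0], "id"); // identical pair left untouched
    assert_ne!(l[1], "a"); // both sides renamed with the same suffix
    assert_eq!(l[1].rsplit('_').next(), r[1].rsplit('_').next());
}
```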
+ // Base case: no stateful UDFs at all let has_stateful_udfs = projection.projection.iter().any(has_stateful_udf); if !has_stateful_udfs { @@ -416,8 +442,11 @@ fn try_optimize_project( // Recursively run the rule on the new child Project let new_project = Project::try_new(projection.input.clone(), remaining)?; let new_child_project = LogicalPlan::Project(new_project.clone()).arced(); - let optimized_child_plan = - try_optimize_project(&new_project, new_child_project.clone(), recursive_count + 1)?; + let optimized_child_plan = recursive_optimize_project( + &new_project, + new_child_project.clone(), + recursive_count + 1, + )?; optimized_child_plan.data.clone() }; @@ -785,6 +814,67 @@ mod tests { Ok(()) } + #[test] + fn test_multiple_with_column_serial_no_alias() -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Utf8)]); + let scan_plan = dummy_scan_node(scan_op); + let stacked_stateful_project_expr = + create_stateful_udf(vec![create_stateful_udf(vec![col("a")])]); + + // Add a Projection with StatefulUDF and resource request + let project_plan = scan_plan + .select(vec![stacked_stateful_project_expr.clone()])? + .build(); + + let intermediate_name = "__TruncateRootStatefulUDF_0-0-0__"; + + let expected = scan_plan.select(vec![col("a")])?.build(); + let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + expected, + vec![ + col("a"), + create_stateful_udf(vec![col("a")]) + .clone() + .alias(intermediate_name), + ], + )?) + .arced(); + let expected = + LogicalPlan::Project(Project::try_new(expected, vec![col(intermediate_name)])?).arced(); + let expected = + LogicalPlan::Project(Project::try_new(expected, vec![col(intermediate_name)])?).arced(); + let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + expected, + vec![ + col(intermediate_name), + create_stateful_udf(vec![col(intermediate_name)]) + .clone() + .alias("a"), + ], + )?) + .arced(); + let expected = LogicalPlan::Project(Project::try_new(expected, vec![col("a")])?).arced(); + assert_optimized_plan_eq(project_plan.clone(), expected.clone())?; + + let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + scan_plan.build(), + vec![create_stateful_udf(vec![col("a")]) + .clone() + .alias(intermediate_name)], + )?) + .arced(); + let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + expected, + vec![create_stateful_udf(vec![col(intermediate_name)]) + .clone() + .alias("a")], + )?) 
+ .arced(); + assert_optimized_plan_eq_with_projection_pushdown(project_plan.clone(), expected.clone())?; + + Ok(()) + } + #[test] fn test_multiple_with_column_serial_multiarg() -> DaftResult<()> { let scan_op = dummy_scan_operator(vec![ diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index 10cc0c6804..23191b1d11 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -22,7 +22,7 @@ use serde::{Deserialize, Serialize}; mod anonymous; pub use anonymous::AnonymousScanOperator; -mod glob; +pub mod glob; use common_daft_config::DaftExecutionConfig; pub mod scan_task_iters; diff --git a/src/daft-schema/src/time_unit.rs b/src/daft-schema/src/time_unit.rs index d4b17b0e7c..9b1afea2e5 100644 --- a/src/daft-schema/src/time_unit.rs +++ b/src/daft-schema/src/time_unit.rs @@ -1,4 +1,7 @@ +use std::str::FromStr; + use arrow2::datatypes::TimeUnit as ArrowTimeUnit; +use common_error::DaftError; use derive_more::Display; use serde::{Deserialize, Serialize}; @@ -33,6 +36,19 @@ impl TimeUnit { } } +impl FromStr for TimeUnit { + type Err = DaftError; + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "ns" | "nanoseconds" => Ok(Self::Nanoseconds), + "us" | "microseconds" => Ok(Self::Microseconds), + "ms" | "milliseconds" => Ok(Self::Milliseconds), + "s" | "seconds" => Ok(Self::Seconds), + _ => Err(DaftError::ValueError("Invalid time unit".to_string())), + } + } +} + impl From<&ArrowTimeUnit> for TimeUnit { fn from(tu: &ArrowTimeUnit) -> Self { match tu { diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index 2b80dda42c..81d7d36ff0 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -1,6 +1,7 @@ [dependencies] common-daft-config = {path = "../common/daft-config"} common-error = {path = "../common/error"} +common-io-config = {path = "../common/io-config", default-features = false} daft-core = {path = "../daft-core"} daft-dsl = {path = "../daft-dsl"} daft-functions = {path = "../daft-functions"} diff --git a/src/daft-sql/src/error.rs b/src/daft-sql/src/error.rs index 31f8a400ed..1fd9ae97e7 100644 --- a/src/daft-sql/src/error.rs +++ b/src/daft-sql/src/error.rs @@ -12,6 +12,8 @@ pub enum PlannerError { ParseError { message: String }, #[snafu(display("Invalid operation: {message}"))] InvalidOperation { message: String }, + #[snafu(display("Invalid argument ({message}) for function '{function}'"))] + InvalidFunctionArgument { message: String, function: String }, #[snafu(display("Table not found: {message}"))] TableNotFound { message: String }, #[snafu(display("Column {column_name} not found in {relation}"))] @@ -66,6 +68,13 @@ impl PlannerError { message: message.into(), } } + + pub fn invalid_argument, F: Into>(arg: S, function: F) -> Self { + Self::InvalidFunctionArgument { + message: arg.into(), + function: function.into(), + } + } } #[macro_export] diff --git a/src/daft-sql/src/functions.rs b/src/daft-sql/src/functions.rs index 8a82fcd80a..2a67d97c63 100644 --- a/src/daft-sql/src/functions.rs +++ b/src/daft-sql/src/functions.rs @@ -1,6 +1,8 @@ use std::{collections::HashMap, sync::Arc}; +use config::SQLModuleConfig; use daft_dsl::ExprRef; +use hashing::SQLModuleHashing; use once_cell::sync::Lazy; use sqlparser::ast::{ Function, FunctionArg, FunctionArgExpr, FunctionArgOperator, FunctionArguments, @@ -14,10 +16,11 @@ use crate::{ }; /// [SQL_FUNCTIONS] is a singleton that holds all the registered SQL functions. 
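The new `FromStr` implementation for `TimeUnit` above accepts both abbreviated and spelled-out unit names, case-insensitively. A hedged usage sketch, assuming `TimeUnit` remains reachable through `daft_core::prelude` as it is elsewhere in this diff:

```rust
// Usage sketch only; it relies solely on the FromStr impl added in this diff.
use std::str::FromStr;

use daft_core::prelude::TimeUnit;

fn main() {
    // Both short and long spellings are accepted, case-insensitively.
    assert!(matches!(TimeUnit::from_str("ms"), Ok(TimeUnit::Milliseconds)));
    assert!(matches!("NS".parse::<TimeUnit>(), Ok(TimeUnit::Nanoseconds)));

    // Anything else surfaces as a DaftError::ValueError.
    assert!("fortnights".parse::<TimeUnit>().is_err());
}
```

This is the kind of parsing a string-facing option such as `coerce_int96_timestamp_unit` can lean on instead of hand-rolled matching.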
-static SQL_FUNCTIONS: Lazy = Lazy::new(|| { +pub(crate) static SQL_FUNCTIONS: Lazy = Lazy::new(|| { let mut functions = SQLFunctions::new(); functions.register::(); functions.register::(); + functions.register::(); functions.register::(); functions.register::(); functions.register::(); @@ -29,6 +32,7 @@ static SQL_FUNCTIONS: Lazy = Lazy::new(|| { functions.register::(); functions.register::(); functions.register::(); + functions.register::(); functions }); @@ -82,13 +86,24 @@ pub trait SQLFunction: Send + Sync { } fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult; + + /// Produce the docstrings for this SQL function, parametrized by an alias which is the function name to invoke this in SQL + fn docstrings(&self, alias: &str) -> String { + format!("{alias}: No docstring available") + } + + /// Produce the docstrings for this SQL function, parametrized by an alias which is the function name to invoke this in SQL + fn arg_names(&self) -> &'static [&'static str] { + &["todo"] + } } /// TODOs /// - Use multimap for function variants. /// - Add more functions.. pub struct SQLFunctions { - map: HashMap>, + pub(crate) map: HashMap>, + pub(crate) docsmap: HashMap, } pub(crate) struct SQLFunctionArguments { @@ -97,12 +112,77 @@ pub(crate) struct SQLFunctionArguments { } impl SQLFunctionArguments { - pub fn get_unnamed(&self, idx: usize) -> Option<&ExprRef> { + pub fn get_positional(&self, idx: usize) -> Option<&ExprRef> { self.positional.get(&idx) } pub fn get_named(&self, name: &str) -> Option<&ExprRef> { self.named.get(name) } + + pub fn try_get_named(&self, name: &str) -> Result, PlannerError> { + self.named + .get(name) + .map(|expr| T::from_expr(expr)) + .transpose() + } + pub fn try_get_positional(&self, idx: usize) -> Result, PlannerError> { + self.positional + .get(&idx) + .map(|expr| T::from_expr(expr)) + .transpose() + } +} + +pub trait SQLLiteral { + fn from_expr(expr: &ExprRef) -> Result + where + Self: Sized; +} + +impl SQLLiteral for String { + fn from_expr(expr: &ExprRef) -> Result + where + Self: Sized, + { + let e = expr + .as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| PlannerError::invalid_operation("Expected a string literal"))?; + Ok(e.to_string()) + } +} + +impl SQLLiteral for i64 { + fn from_expr(expr: &ExprRef) -> Result + where + Self: Sized, + { + expr.as_literal() + .and_then(|lit| lit.as_i64()) + .ok_or_else(|| PlannerError::invalid_operation("Expected an integer literal")) + } +} + +impl SQLLiteral for usize { + fn from_expr(expr: &ExprRef) -> Result + where + Self: Sized, + { + expr.as_literal() + .and_then(|lit| lit.as_i64().map(|v| v as Self)) + .ok_or_else(|| PlannerError::invalid_operation("Expected an integer literal")) + } +} + +impl SQLLiteral for bool { + fn from_expr(expr: &ExprRef) -> Result + where + Self: Sized, + { + expr.as_literal() + .and_then(|lit| lit.as_bool()) + .ok_or_else(|| PlannerError::invalid_operation("Expected a boolean literal")) + } } impl SQLFunctions { @@ -110,6 +190,7 @@ impl SQLFunctions { pub fn new() -> Self { Self { map: HashMap::new(), + docsmap: HashMap::new(), } } @@ -120,6 +201,8 @@ impl SQLFunctions { /// Add a [FunctionExpr] to the [SQLFunctions] instance. pub fn add_fn(&mut self, name: &str, func: F) { + self.docsmap + .insert(name.to_string(), (func.docstrings(name), func.arg_names())); self.map.insert(name.to_string(), Arc::new(func)); } @@ -194,6 +277,15 @@ impl SQLPlanner { where T: TryFrom, { + self.parse_function_args(args, expected_named, expected_positional)? 
+ .try_into() + } + pub(crate) fn parse_function_args( + &self, + args: &[FunctionArg], + expected_named: &'static [&'static str], + expected_positional: usize, + ) -> SQLPlannerResult { let mut positional_args = HashMap::new(); let mut named_args = HashMap::new(); for (idx, arg) in args.iter().enumerate() { @@ -214,15 +306,14 @@ impl SQLPlanner { } positional_args.insert(idx, self.try_unwrap_function_arg_expr(arg)?); } - _ => unsupported_sql_err!("unsupported function argument type"), + other => unsupported_sql_err!("unsupported function argument type: {other}, valid function arguments for this function are: {expected_named:?}."), } } - SQLFunctionArguments { + Ok(SQLFunctionArguments { positional: positional_args, named: named_args, - } - .try_into() + }) } pub(crate) fn plan_function_arg( @@ -235,7 +326,10 @@ impl SQLPlanner { } } - fn try_unwrap_function_arg_expr(&self, expr: &FunctionArgExpr) -> SQLPlannerResult { + pub(crate) fn try_unwrap_function_arg_expr( + &self, + expr: &FunctionArgExpr, + ) -> SQLPlannerResult { match expr { FunctionArgExpr::Expr(expr) => self.plan_expr(expr), _ => unsupported_sql_err!("Wildcard function args not yet supported"), diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 310c256e27..6246e8b242 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -3,9 +3,9 @@ pub mod error; pub mod functions; mod modules; mod planner; - #[cfg(feature = "python")] pub mod python; +mod table_provider; #[cfg(feature = "python")] use pyo3::prelude::*; @@ -15,6 +15,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_class::()?; parent.add_function(wrap_pyfunction_bound!(python::sql, parent)?)?; parent.add_function(wrap_pyfunction_bound!(python::sql_expr, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(python::list_sql_functions, parent)?)?; Ok(()) } @@ -261,6 +262,7 @@ mod tests { JoinType::Inner, None, )? + .select(vec![col("*")])? 
.build(); assert_eq!(plan, expected); Ok(()) @@ -293,7 +295,7 @@ mod tests { #[case::starts_with("select starts_with(utf8, 'a') as starts_with from tbl1")] #[case::contains("select contains(utf8, 'a') as contains from tbl1")] #[case::split("select split(utf8, '.') as split from tbl1")] - #[case::replace("select replace(utf8, 'a', 'b') as replace from tbl1")] + #[case::replace("select regexp_replace(utf8, 'a', 'b') as replace from tbl1")] #[case::length("select length(utf8) as length from tbl1")] #[case::lower("select lower(utf8) as lower from tbl1")] #[case::upper("select upper(utf8) as upper from tbl1")] diff --git a/src/daft-sql/src/modules/aggs.rs b/src/daft-sql/src/modules/aggs.rs index 695d3c9c79..0fbd2f7067 100644 --- a/src/daft-sql/src/modules/aggs.rs +++ b/src/daft-sql/src/modules/aggs.rs @@ -41,6 +41,26 @@ impl SQLFunction for AggExpr { to_expr(self, inputs.as_slice()) } } + + fn docstrings(&self, alias: &str) -> String { + match self { + Self::Count(_, _) => static_docs::COUNT_DOCSTRING.to_string(), + Self::Sum(_) => static_docs::SUM_DOCSTRING.to_string(), + Self::Mean(_) => static_docs::AVG_DOCSTRING.replace("{}", alias), + Self::Min(_) => static_docs::MIN_DOCSTRING.to_string(), + Self::Max(_) => static_docs::MAX_DOCSTRING.to_string(), + e => unimplemented!("Need to implement docstrings for {e}"), + } + } + + fn arg_names(&self) -> &'static [&'static str] { + match self { + Self::Count(_, _) | Self::Sum(_) | Self::Mean(_) | Self::Min(_) | Self::Max(_) => { + &["input"] + } + e => unimplemented!("Need to implement arg names for {e}"), + } + } } fn handle_count(inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult { @@ -103,3 +123,201 @@ pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult unsupported_sql_err!("map_groups"), } } + +mod static_docs { + pub(crate) const COUNT_DOCSTRING: &str = + "Counts the number of non-null elements in the input expression. + +Example: + +.. code-block:: sql + :caption: SQL + + SELECT count(x) FROM tbl + +.. code-block:: text + :caption: Input + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ├╌╌╌╌╌╌╌┤ + │ 200 │ + ├╌╌╌╌╌╌╌┤ + │ null │ + ╰───────╯ + (Showing first 3 of 3 rows) + +.. code-block:: text + :caption: Output + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 1 │ + ╰───────╯ + (Showing first 1 of 1 rows)"; + + pub(crate) const SUM_DOCSTRING: &str = + "Calculates the sum of non-null elements in the input expression. + +Example: + +.. code-block:: sql + :caption: SQL + + SELECT sum(x) FROM tbl + +.. code-block:: text + :caption: Input + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ├╌╌╌╌╌╌╌┤ + │ 200 │ + ├╌╌╌╌╌╌╌┤ + │ null │ + ╰───────╯ + (Showing first 3 of 3 rows) + +.. code-block:: text + :caption: Output + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 300 │ + ╰───────╯ + (Showing first 1 of 1 rows)"; + + pub(crate) const AVG_DOCSTRING: &str = + "Calculates the average (mean) of non-null elements in the input expression. + +.. seealso:: + This SQL Function has aliases. + + * :func:`~daft.sql._sql_funcs.mean` + * :func:`~daft.sql._sql_funcs.avg` + +Example: + +.. code-block:: sql + :caption: SQL + + SELECT {}(x) FROM tbl + +.. code-block:: text + :caption: Input + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ├╌╌╌╌╌╌╌┤ + │ 200 │ + ├╌╌╌╌╌╌╌┤ + │ null │ + ╰───────╯ + (Showing first 3 of 3 rows) + +.. 
code-block:: text + :caption: Output + + ╭───────────╮ + │ x │ + │ --- │ + │ Float64 │ + ╞═══════════╡ + │ 150.0 │ + ╰───────────╯ + (Showing first 1 of 1 rows)"; + + pub(crate) const MIN_DOCSTRING: &str = + "Finds the minimum value among non-null elements in the input expression. + +Example: + +.. code-block:: sql + :caption: SQL + + SELECT min(x) FROM tbl + +.. code-block:: text + :caption: Input + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ├╌╌╌╌╌╌╌┤ + │ 200 │ + ├╌╌╌╌╌╌╌┤ + │ null │ + ╰───────╯ + (Showing first 3 of 3 rows) + +.. code-block:: text + :caption: Output + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ╰───────╯ + (Showing first 1 of 1 rows)"; + + pub(crate) const MAX_DOCSTRING: &str = + "Finds the maximum value among non-null elements in the input expression. + +Example: + +.. code-block:: sql + :caption: SQL + + SELECT max(x) FROM tbl + +.. code-block:: text + :caption: Input + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ├╌╌╌╌╌╌╌┤ + │ 200 │ + ├╌╌╌╌╌╌╌┤ + │ null │ + ╰───────╯ + (Showing first 3 of 3 rows) + +.. code-block:: text + :caption: Output + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 200 │ + ╰───────╯ + (Showing first 1 of 1 rows)"; +} diff --git a/src/daft-sql/src/modules/config.rs b/src/daft-sql/src/modules/config.rs new file mode 100644 index 0000000000..9a540d3025 --- /dev/null +++ b/src/daft-sql/src/modules/config.rs @@ -0,0 +1,391 @@ +use common_io_config::{AzureConfig, GCSConfig, HTTPConfig, IOConfig, S3Config}; +use daft_core::prelude::{DataType, Field}; +use daft_dsl::{literal_value, Expr, ExprRef, LiteralValue}; + +use super::SQLModule; +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::{SQLFunction, SQLFunctionArguments, SQLFunctions}, + unsupported_sql_err, +}; + +pub struct SQLModuleConfig; + +impl SQLModule for SQLModuleConfig { + fn register(parent: &mut SQLFunctions) { + parent.add_fn("S3Config", S3ConfigFunction); + parent.add_fn("HTTPConfig", HTTPConfigFunction); + parent.add_fn("AzureConfig", AzureConfigFunction); + parent.add_fn("GCSConfig", GCSConfigFunction); + } +} + +pub struct S3ConfigFunction; +macro_rules! item { + ($name:expr, $ty:ident) => { + ( + Field::new(stringify!($name), DataType::$ty), + literal_value($name), + ) + }; +} + +impl SQLFunction for S3ConfigFunction { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + // TODO(cory): Ideally we should use serde to deserialize the input arguments + let args: SQLFunctionArguments = planner.parse_function_args( + inputs, + &[ + "region_name", + "endpoint_url", + "key_id", + "session_token", + "access_key", + "credentials_provider", + "buffer_time", + "max_connections_per_io_thread", + "retry_initial_backoff_ms", + "connect_timeout_ms", + "read_timeout_ms", + "num_tries", + "retry_mode", + "anonymous", + "use_ssl", + "verify_ssl", + "check_hostname_ssl", + "requester_pays", + "force_virtual_addressing", + "profile_name", + ], + 0, + )?; + + let region_name = args.try_get_named::("region_name")?; + let endpoint_url = args.try_get_named::("endpoint_url")?; + let key_id = args.try_get_named::("key_id")?; + let session_token = args.try_get_named::("session_token")?; + + let access_key = args.try_get_named::("access_key")?; + let buffer_time = args.try_get_named("buffer_time")?.map(|t: i64| t as u64); + + let max_connections_per_io_thread = args + .try_get_named("max_connections_per_io_thread")? 
+ .map(|t: i64| t as u32); + + let retry_initial_backoff_ms = args + .try_get_named("retry_initial_backoff_ms")? + .map(|t: i64| t as u64); + + let connect_timeout_ms = args + .try_get_named("connect_timeout_ms")? + .map(|t: i64| t as u64); + + let read_timeout_ms = args + .try_get_named("read_timeout_ms")? + .map(|t: i64| t as u64); + + let num_tries = args.try_get_named("num_tries")?.map(|t: i64| t as u32); + let retry_mode = args.try_get_named::("retry_mode")?; + let anonymous = args.try_get_named::("anonymous")?; + let use_ssl = args.try_get_named::("use_ssl")?; + let verify_ssl = args.try_get_named::("verify_ssl")?; + let check_hostname_ssl = args.try_get_named::("check_hostname_ssl")?; + let requester_pays = args.try_get_named::("requester_pays")?; + let force_virtual_addressing = args.try_get_named::("force_virtual_addressing")?; + let profile_name = args.try_get_named::("profile_name")?; + + let entries = vec![ + (Field::new("variant", DataType::Utf8), literal_value("s3")), + item!(region_name, Utf8), + item!(endpoint_url, Utf8), + item!(key_id, Utf8), + item!(session_token, Utf8), + item!(access_key, Utf8), + item!(buffer_time, UInt64), + item!(max_connections_per_io_thread, UInt32), + item!(retry_initial_backoff_ms, UInt64), + item!(connect_timeout_ms, UInt64), + item!(read_timeout_ms, UInt64), + item!(num_tries, UInt32), + item!(retry_mode, Utf8), + item!(anonymous, Boolean), + item!(use_ssl, Boolean), + item!(verify_ssl, Boolean), + item!(check_hostname_ssl, Boolean), + item!(requester_pays, Boolean), + item!(force_virtual_addressing, Boolean), + item!(profile_name, Utf8), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +} + +pub struct HTTPConfigFunction; + +impl SQLFunction for HTTPConfigFunction { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + let args: SQLFunctionArguments = + planner.parse_function_args(inputs, &["user_agent", "bearer_token"], 0)?; + + let user_agent = args.try_get_named::("user_agent")?; + let bearer_token = args.try_get_named::("bearer_token")?; + + let entries = vec![ + (Field::new("variant", DataType::Utf8), literal_value("http")), + item!(user_agent, Utf8), + item!(bearer_token, Utf8), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +} +pub struct AzureConfigFunction; +impl SQLFunction for AzureConfigFunction { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + let args: SQLFunctionArguments = planner.parse_function_args( + inputs, + &[ + "storage_account", + "access_key", + "sas_token", + "bearer_token", + "tenant_id", + "client_id", + "client_secret", + "use_fabric_endpoint", + "anonymous", + "endpoint_url", + "use_ssl", + ], + 0, + )?; + + let storage_account = args.try_get_named::("storage_account")?; + let access_key = args.try_get_named::("access_key")?; + let sas_token = args.try_get_named::("sas_token")?; + let bearer_token = args.try_get_named::("bearer_token")?; + let tenant_id = args.try_get_named::("tenant_id")?; + let client_id = args.try_get_named::("client_id")?; + let client_secret = args.try_get_named::("client_secret")?; + let use_fabric_endpoint = args.try_get_named::("use_fabric_endpoint")?; + let anonymous = args.try_get_named::("anonymous")?; + let endpoint_url = args.try_get_named::("endpoint_url")?; + let use_ssl = 
args.try_get_named::("use_ssl")?; + + let entries = vec![ + ( + Field::new("variant", DataType::Utf8), + literal_value("azure"), + ), + item!(storage_account, Utf8), + item!(access_key, Utf8), + item!(sas_token, Utf8), + item!(bearer_token, Utf8), + item!(tenant_id, Utf8), + item!(client_id, Utf8), + item!(client_secret, Utf8), + item!(use_fabric_endpoint, Boolean), + item!(anonymous, Boolean), + item!(endpoint_url, Utf8), + item!(use_ssl, Boolean), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +} + +pub struct GCSConfigFunction; + +impl SQLFunction for GCSConfigFunction { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + let args: SQLFunctionArguments = planner.parse_function_args( + inputs, + &["project_id", "credentials", "token", "anonymous"], + 0, + )?; + + let project_id = args.try_get_named::("project_id")?; + let credentials = args.try_get_named::("credentials")?; + let token = args.try_get_named::("token")?; + let anonymous = args.try_get_named::("anonymous")?; + + let entries = vec![ + (Field::new("variant", DataType::Utf8), literal_value("gcs")), + item!(project_id, Utf8), + item!(credentials, Utf8), + item!(token, Utf8), + item!(anonymous, Boolean), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +} + +pub(crate) fn expr_to_iocfg(expr: &ExprRef) -> SQLPlannerResult { + // TODO(CORY): use serde to deserialize this + let Expr::Literal(LiteralValue::Struct(entries)) = expr.as_ref() else { + unsupported_sql_err!("Invalid IOConfig"); + }; + + macro_rules! get_value { + ($field:literal, $type:ident) => { + entries + .get(&Field::new($field, DataType::$type)) + .and_then(|s| match s { + LiteralValue::$type(s) => Some(Ok(s.clone())), + LiteralValue::Null => None, + _ => Some(Err(PlannerError::invalid_argument($field, "IOConfig"))), + }) + .transpose() + }; + } + + let variant = get_value!("variant", Utf8)? 
+ .expect("variant is required for IOConfig, this indicates a programming error"); + + match variant.as_ref() { + "s3" => { + let region_name = get_value!("region_name", Utf8)?; + let endpoint_url = get_value!("endpoint_url", Utf8)?; + let key_id = get_value!("key_id", Utf8)?; + let session_token = get_value!("session_token", Utf8)?.map(|s| s.into()); + let access_key = get_value!("access_key", Utf8)?.map(|s| s.into()); + let buffer_time = get_value!("buffer_time", UInt64)?; + let max_connections_per_io_thread = + get_value!("max_connections_per_io_thread", UInt32)?; + let retry_initial_backoff_ms = get_value!("retry_initial_backoff_ms", UInt64)?; + let connect_timeout_ms = get_value!("connect_timeout_ms", UInt64)?; + let read_timeout_ms = get_value!("read_timeout_ms", UInt64)?; + let num_tries = get_value!("num_tries", UInt32)?; + let retry_mode = get_value!("retry_mode", Utf8)?; + let anonymous = get_value!("anonymous", Boolean)?; + let use_ssl = get_value!("use_ssl", Boolean)?; + let verify_ssl = get_value!("verify_ssl", Boolean)?; + let check_hostname_ssl = get_value!("check_hostname_ssl", Boolean)?; + let requester_pays = get_value!("requester_pays", Boolean)?; + let force_virtual_addressing = get_value!("force_virtual_addressing", Boolean)?; + let profile_name = get_value!("profile_name", Utf8)?; + let default = S3Config::default(); + let s3_config = S3Config { + region_name, + endpoint_url, + key_id, + session_token, + access_key, + credentials_provider: None, + buffer_time, + max_connections_per_io_thread: max_connections_per_io_thread + .unwrap_or(default.max_connections_per_io_thread), + retry_initial_backoff_ms: retry_initial_backoff_ms + .unwrap_or(default.retry_initial_backoff_ms), + connect_timeout_ms: connect_timeout_ms.unwrap_or(default.connect_timeout_ms), + read_timeout_ms: read_timeout_ms.unwrap_or(default.read_timeout_ms), + num_tries: num_tries.unwrap_or(default.num_tries), + retry_mode, + anonymous: anonymous.unwrap_or(default.anonymous), + use_ssl: use_ssl.unwrap_or(default.use_ssl), + verify_ssl: verify_ssl.unwrap_or(default.verify_ssl), + check_hostname_ssl: check_hostname_ssl.unwrap_or(default.check_hostname_ssl), + requester_pays: requester_pays.unwrap_or(default.requester_pays), + force_virtual_addressing: force_virtual_addressing + .unwrap_or(default.force_virtual_addressing), + profile_name, + }; + + Ok(IOConfig { + s3: s3_config, + ..Default::default() + }) + } + "http" => { + let default = HTTPConfig::default(); + let user_agent = get_value!("user_agent", Utf8)?.unwrap_or(default.user_agent); + let bearer_token = get_value!("bearer_token", Utf8)?.map(|s| s.into()); + + Ok(IOConfig { + http: HTTPConfig { + user_agent, + bearer_token, + }, + ..Default::default() + }) + } + "azure" => { + let storage_account = get_value!("storage_account", Utf8)?; + let access_key = get_value!("access_key", Utf8)?; + let sas_token = get_value!("sas_token", Utf8)?; + let bearer_token = get_value!("bearer_token", Utf8)?; + let tenant_id = get_value!("tenant_id", Utf8)?; + let client_id = get_value!("client_id", Utf8)?; + let client_secret = get_value!("client_secret", Utf8)?; + let use_fabric_endpoint = get_value!("use_fabric_endpoint", Boolean)?; + let anonymous = get_value!("anonymous", Boolean)?; + let endpoint_url = get_value!("endpoint_url", Utf8)?; + let use_ssl = get_value!("use_ssl", Boolean)?; + + let default = AzureConfig::default(); + + Ok(IOConfig { + azure: AzureConfig { + storage_account, + access_key: access_key.map(|s| s.into()), + sas_token, + bearer_token, + 
tenant_id, + client_id, + client_secret: client_secret.map(|s| s.into()), + use_fabric_endpoint: use_fabric_endpoint.unwrap_or(default.use_fabric_endpoint), + anonymous: anonymous.unwrap_or(default.anonymous), + endpoint_url, + use_ssl: use_ssl.unwrap_or(default.use_ssl), + }, + ..Default::default() + }) + } + "gcs" => { + let project_id = get_value!("project_id", Utf8)?; + let credentials = get_value!("credentials", Utf8)?; + let token = get_value!("token", Utf8)?; + let anonymous = get_value!("anonymous", Boolean)?; + let default = GCSConfig::default(); + + Ok(IOConfig { + gcs: GCSConfig { + project_id, + credentials: credentials.map(|s| s.into()), + token, + anonymous: anonymous.unwrap_or(default.anonymous), + }, + ..Default::default() + }) + } + _ => { + unreachable!("variant is required for IOConfig, this indicates a programming error") + } + } +} diff --git a/src/daft-sql/src/modules/float.rs b/src/daft-sql/src/modules/float.rs index 4cfffe34b4..292a5c4d85 100644 --- a/src/daft-sql/src/modules/float.rs +++ b/src/daft-sql/src/modules/float.rs @@ -37,6 +37,14 @@ impl SQLFunction for SQLFillNan { _ => unsupported_sql_err!("Invalid arguments for 'fill_nan': '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::FILL_NAN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "fill_value"] + } } pub struct SQLIsInf {} @@ -52,6 +60,14 @@ impl SQLFunction for SQLIsInf { _ => unsupported_sql_err!("Invalid arguments for 'is_inf': '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::IS_INF_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLIsNan {} @@ -67,6 +83,14 @@ impl SQLFunction for SQLIsNan { _ => unsupported_sql_err!("Invalid arguments for 'is_nan': '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::IS_NAN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLNotNan {} @@ -82,4 +106,26 @@ impl SQLFunction for SQLNotNan { _ => unsupported_sql_err!("Invalid arguments for 'not_nan': '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::NOT_NAN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } +} + +mod static_docs { + pub(crate) const FILL_NAN_DOCSTRING: &str = + "Replaces NaN values in the input expression with a specified fill value."; + + pub(crate) const IS_INF_DOCSTRING: &str = + "Checks if the input expression is infinite (positive or negative infinity)."; + + pub(crate) const IS_NAN_DOCSTRING: &str = + "Checks if the input expression is NaN (Not a Number)."; + + pub(crate) const NOT_NAN_DOCSTRING: &str = + "Checks if the input expression is not NaN (Not a Number)."; } diff --git a/src/daft-sql/src/modules/hashing.rs b/src/daft-sql/src/modules/hashing.rs new file mode 100644 index 0000000000..4259ebd04a --- /dev/null +++ b/src/daft-sql/src/modules/hashing.rs @@ -0,0 +1,111 @@ +use daft_dsl::ExprRef; +use daft_functions::{ + hash::hash, + minhash::{minhash, MinHashFunction}, +}; +use sqlparser::ast::FunctionArg; + +use super::SQLModule; +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::{SQLFunction, SQLFunctionArguments, SQLFunctions}, + unsupported_sql_err, +}; + +pub struct SQLModuleHashing; + +impl SQLModule for SQLModuleHashing { + fn register(parent: &mut SQLFunctions) { + parent.add_fn("hash", SQLHash); + parent.add_fn("minhash", 
SQLMinhash); + } +} + +pub struct SQLHash; + +impl SQLFunction for SQLHash { + fn to_expr( + &self, + inputs: &[FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(hash(input, None)) + } + [input, seed] => { + let input = planner.plan_function_arg(input)?; + match seed { + FunctionArg::Named { name, arg, .. } if name.value == "seed" => { + let seed = planner.try_unwrap_function_arg_expr(arg)?; + Ok(hash(input, Some(seed))) + } + arg @ FunctionArg::Unnamed(_) => { + let seed = planner.plan_function_arg(arg)?; + Ok(hash(input, Some(seed))) + } + _ => unsupported_sql_err!("Invalid arguments for hash: '{inputs:?}'"), + } + } + _ => unsupported_sql_err!("Invalid arguments for hash: '{inputs:?}'"), + } + } +} + +pub struct SQLMinhash; + +impl TryFrom for MinHashFunction { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + let num_hashes = args + .get_named("num_hashes") + .ok_or_else(|| PlannerError::invalid_operation("num_hashes is required"))? + .as_literal() + .and_then(|lit| lit.as_i64()) + .ok_or_else(|| PlannerError::invalid_operation("num_hashes must be an integer"))? + as usize; + + let ngram_size = args + .get_named("ngram_size") + .ok_or_else(|| PlannerError::invalid_operation("ngram_size is required"))? + .as_literal() + .and_then(|lit| lit.as_i64()) + .ok_or_else(|| PlannerError::invalid_operation("ngram_size must be an integer"))? + as usize; + let seed = args + .get_named("seed") + .map(|arg| { + arg.as_literal() + .and_then(|lit| lit.as_i64()) + .ok_or_else(|| PlannerError::invalid_operation("num_hashes must be an integer")) + }) + .transpose()? + .unwrap_or(1) as u32; + Ok(Self { + num_hashes, + ngram_size, + seed, + }) + } +} + +impl SQLFunction for SQLMinhash { + fn to_expr( + &self, + inputs: &[FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, args @ ..] => { + let input = planner.plan_function_arg(input)?; + let args: MinHashFunction = + planner.plan_function_args(args, &["num_hashes", "ngram_size", "seed"], 0)?; + + Ok(minhash(input, args.num_hashes, args.ngram_size, args.seed)) + } + _ => unsupported_sql_err!("Invalid arguments for minhash: '{inputs:?}'"), + } + } +} diff --git a/src/daft-sql/src/modules/image/crop.rs b/src/daft-sql/src/modules/image/crop.rs index 36c72fcca3..286208889c 100644 --- a/src/daft-sql/src/modules/image/crop.rs +++ b/src/daft-sql/src/modules/image/crop.rs @@ -21,4 +21,12 @@ impl SQLFunction for SQLImageCrop { _ => unsupported_sql_err!("Invalid arguments for image_crop: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Crops an image to a specified bounding box. The bounding box is specified as [x, y, width, height].".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input_image", "bounding_box"] + } } diff --git a/src/daft-sql/src/modules/image/decode.rs b/src/daft-sql/src/modules/image/decode.rs index a6b95d538d..a896c67a05 100644 --- a/src/daft-sql/src/modules/image/decode.rs +++ b/src/daft-sql/src/modules/image/decode.rs @@ -61,4 +61,12 @@ impl SQLFunction for SQLImageDecode { _ => unsupported_sql_err!("Invalid arguments for image_decode: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Decodes an image from binary data. 
Optionally, you can specify the image mode and error handling behavior.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "mode", "on_error"] + } } diff --git a/src/daft-sql/src/modules/image/encode.rs b/src/daft-sql/src/modules/image/encode.rs index a902179f88..acf489c807 100644 --- a/src/daft-sql/src/modules/image/encode.rs +++ b/src/daft-sql/src/modules/image/encode.rs @@ -46,4 +46,12 @@ impl SQLFunction for SQLImageEncode { _ => unsupported_sql_err!("Invalid arguments for image_encode: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Encodes an image into the specified image file format, returning a binary column of encoded bytes.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input_image", "image_format"] + } } diff --git a/src/daft-sql/src/modules/image/resize.rs b/src/daft-sql/src/modules/image/resize.rs index 8ce37eb7f8..ac6c12fd50 100644 --- a/src/daft-sql/src/modules/image/resize.rs +++ b/src/daft-sql/src/modules/image/resize.rs @@ -16,7 +16,7 @@ impl TryFrom for ImageResize { fn try_from(args: SQLFunctionArguments) -> Result { let width = args .get_named("w") - .or_else(|| args.get_unnamed(0)) + .or_else(|| args.get_positional(0)) .map(|arg| match arg.as_ref() { Expr::Literal(LiteralValue::Int64(i)) => Ok(*i), _ => unsupported_sql_err!("Expected width to be a number"), @@ -28,7 +28,7 @@ impl TryFrom for ImageResize { let height = args .get_named("h") - .or_else(|| args.get_unnamed(1)) + .or_else(|| args.get_positional(1)) .map(|arg| match arg.as_ref() { Expr::Literal(LiteralValue::Int64(i)) => Ok(*i), _ => unsupported_sql_err!("Expected height to be a number"), @@ -64,4 +64,12 @@ impl SQLFunction for SQLImageResize { _ => unsupported_sql_err!("Invalid arguments for image_resize: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Resizes an image to the specified width and height.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input_image", "width", "height"] + } } diff --git a/src/daft-sql/src/modules/image/to_mode.rs b/src/daft-sql/src/modules/image/to_mode.rs index a02efb2d36..b5b9202d1f 100644 --- a/src/daft-sql/src/modules/image/to_mode.rs +++ b/src/daft-sql/src/modules/image/to_mode.rs @@ -41,4 +41,12 @@ impl SQLFunction for SQLImageToMode { _ => unsupported_sql_err!("Invalid arguments for image_encode: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Converts an image to the specified mode (e.g. 
RGB, RGBA, Grayscale).".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input_image", "mode"] + } } diff --git a/src/daft-sql/src/modules/json.rs b/src/daft-sql/src/modules/json.rs index f0d600daea..8dc9e617f5 100644 --- a/src/daft-sql/src/modules/json.rs +++ b/src/daft-sql/src/modules/json.rs @@ -35,4 +35,17 @@ impl SQLFunction for JsonQuery { ), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::JSON_QUERY_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "query"] + } +} + +mod static_docs { + pub(crate) const JSON_QUERY_DOCSTRING: &str = + "Extracts a JSON object from a JSON string using a JSONPath expression."; } diff --git a/src/daft-sql/src/modules/list.rs b/src/daft-sql/src/modules/list.rs index b9e52d9748..bd6db25990 100644 --- a/src/daft-sql/src/modules/list.rs +++ b/src/daft-sql/src/modules/list.rs @@ -55,6 +55,14 @@ impl SQLFunction for SQLListChunk { ), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_CHUNK_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "chunk_size"] + } } pub struct SQLListCount; @@ -86,6 +94,14 @@ impl SQLFunction for SQLListCount { _ => unsupported_sql_err!("invalid arguments for list_count. Expected either list_count(expr) or list_count(expr, mode)"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_COUNT_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "mode"] + } } pub struct SQLExplode; @@ -104,6 +120,14 @@ impl SQLFunction for SQLExplode { _ => unsupported_sql_err!("Expected 1 argument"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::EXPLODE_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLListJoin; @@ -125,6 +149,14 @@ impl SQLFunction for SQLListJoin { ), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_JOIN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "separator"] + } } pub struct SQLListMax; @@ -143,6 +175,14 @@ impl SQLFunction for SQLListMax { _ => unsupported_sql_err!("invalid arguments for list_max. Expected list_max(expr)"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_MAX_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLListMean; @@ -161,6 +201,14 @@ impl SQLFunction for SQLListMean { _ => unsupported_sql_err!("invalid arguments for list_mean. Expected list_mean(expr)"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_MEAN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLListMin; @@ -179,6 +227,14 @@ impl SQLFunction for SQLListMin { _ => unsupported_sql_err!("invalid arguments for list_min. Expected list_min(expr)"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_MIN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLListSum; @@ -197,6 +253,14 @@ impl SQLFunction for SQLListSum { _ => unsupported_sql_err!("invalid arguments for list_sum. 
Expected list_sum(expr)"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_SUM_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLListSlice; @@ -219,6 +283,14 @@ impl SQLFunction for SQLListSlice { ), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_SLICE_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "start", "end"] + } } pub struct SQLListSort; @@ -258,4 +330,38 @@ impl SQLFunction for SQLListSort { ), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_SORT_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "order"] + } +} + +mod static_docs { + pub(crate) const LIST_CHUNK_DOCSTRING: &str = "Splits a list into chunks of a specified size."; + + pub(crate) const LIST_COUNT_DOCSTRING: &str = "Counts the number of elements in a list."; + + pub(crate) const EXPLODE_DOCSTRING: &str = "Expands a list column into multiple rows."; + + pub(crate) const LIST_JOIN_DOCSTRING: &str = + "Joins elements of a list into a single string using a specified separator."; + + pub(crate) const LIST_MAX_DOCSTRING: &str = "Returns the maximum value in a list."; + + pub(crate) const LIST_MEAN_DOCSTRING: &str = + "Calculates the mean (average) of values in a list."; + + pub(crate) const LIST_MIN_DOCSTRING: &str = "Returns the minimum value in a list."; + + pub(crate) const LIST_SUM_DOCSTRING: &str = "Calculates the sum of values in a list."; + + pub(crate) const LIST_SLICE_DOCSTRING: &str = + "Extracts a portion of a list from a start index to an end index."; + + pub(crate) const LIST_SORT_DOCSTRING: &str = + "Sorts the elements of a list in ascending or descending order."; } diff --git a/src/daft-sql/src/modules/map.rs b/src/daft-sql/src/modules/map.rs index d3a328f3a4..0ae5aca2be 100644 --- a/src/daft-sql/src/modules/map.rs +++ b/src/daft-sql/src/modules/map.rs @@ -30,4 +30,23 @@ impl SQLFunction for MapGet { _ => invalid_operation_err!("Expected 2 input args"), } } + + fn docstrings(&self, alias: &str) -> String { + static_docs::MAP_GET_DOCSTRING.replace("{}", alias) + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "key"] + } +} + +mod static_docs { + pub(crate) const MAP_GET_DOCSTRING: &str = + "Retrieves the value associated with a given key from a map. + +.. 
seealso:: + + * :func:`~daft.sql._sql_funcs.map_get` + * :func:`~daft.sql._sql_funcs.map_extract` +"; } diff --git a/src/daft-sql/src/modules/mod.rs b/src/daft-sql/src/modules/mod.rs index 0f60ecbff9..af4cb731a3 100644 --- a/src/daft-sql/src/modules/mod.rs +++ b/src/daft-sql/src/modules/mod.rs @@ -1,7 +1,9 @@ use crate::functions::SQLFunctions; pub mod aggs; +pub mod config; pub mod float; +pub mod hashing; pub mod image; pub mod json; pub mod list; diff --git a/src/daft-sql/src/modules/numeric.rs b/src/daft-sql/src/modules/numeric.rs index 197d958860..21ac2a0873 100644 --- a/src/daft-sql/src/modules/numeric.rs +++ b/src/daft-sql/src/modules/numeric.rs @@ -88,6 +88,67 @@ impl SQLFunction for SQLNumericExpr { let inputs = self.args_to_expr_unnamed(inputs, planner)?; to_expr(self, inputs.as_slice()) } + + fn docstrings(&self, _alias: &str) -> String { + let docstring = match self { + Self::Abs => "Gets the absolute value of a number.", + Self::Ceil => "Rounds a number up to the nearest integer.", + Self::Exp => "Calculates the exponential of a number (e^x).", + Self::Floor => "Rounds a number down to the nearest integer.", + Self::Round => "Rounds a number to a specified number of decimal places.", + Self::Sign => "Returns the sign of a number (-1, 0, or 1).", + Self::Sqrt => "Calculates the square root of a number.", + Self::Sin => "Calculates the sine of an angle in radians.", + Self::Cos => "Calculates the cosine of an angle in radians.", + Self::Tan => "Calculates the tangent of an angle in radians.", + Self::Cot => "Calculates the cotangent of an angle in radians.", + Self::ArcSin => "Calculates the inverse sine (arc sine) of a number.", + Self::ArcCos => "Calculates the inverse cosine (arc cosine) of a number.", + Self::ArcTan => "Calculates the inverse tangent (arc tangent) of a number.", + Self::ArcTan2 => { + "Calculates the angle between the positive x-axis and the ray from (0,0) to (x,y)." 
+ } + Self::Radians => "Converts an angle from degrees to radians.", + Self::Degrees => "Converts an angle from radians to degrees.", + Self::Log => "Calculates the natural logarithm of a number.", + Self::Log2 => "Calculates the base-2 logarithm of a number.", + Self::Log10 => "Calculates the base-10 logarithm of a number.", + Self::Ln => "Calculates the natural logarithm of a number.", + Self::ArcTanh => "Calculates the inverse hyperbolic tangent of a number.", + Self::ArcCosh => "Calculates the inverse hyperbolic cosine of a number.", + Self::ArcSinh => "Calculates the inverse hyperbolic sine of a number.", + }; + docstring.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + match self { + Self::Abs + | Self::Ceil + | Self::Floor + | Self::Sign + | Self::Sqrt + | Self::Sin + | Self::Cos + | Self::Tan + | Self::Cot + | Self::ArcSin + | Self::ArcCos + | Self::ArcTan + | Self::Radians + | Self::Degrees + | Self::Log2 + | Self::Log10 + | Self::Ln + | Self::ArcTanh + | Self::ArcCosh + | Self::ArcSinh => &["input"], + Self::Log => &["input", "base"], + Self::Round => &["input", "precision"], + Self::Exp => &["input", "exponent"], + Self::ArcTan2 => &["y", "x"], + } + } } fn to_expr(expr: &SQLNumericExpr, args: &[ExprRef]) -> SQLPlannerResult { diff --git a/src/daft-sql/src/modules/partitioning.rs b/src/daft-sql/src/modules/partitioning.rs index e833edd51d..def20b2774 100644 --- a/src/daft-sql/src/modules/partitioning.rs +++ b/src/daft-sql/src/modules/partitioning.rs @@ -80,6 +80,28 @@ impl SQLFunction for PartitioningExpr { } } } + + fn docstrings(&self, _alias: &str) -> String { + match self { + Self::Years => "Extracts the number of years since epoch time from a datetime expression.".to_string(), + Self::Months => "Extracts the number of months since epoch time from a datetime expression.".to_string(), + Self::Days => "Extracts the number of days since epoch time from a datetime expression.".to_string(), + Self::Hours => "Extracts the number of hours since epoch time from a datetime expression.".to_string(), + Self::IcebergBucket(_) => "Computes a bucket number for the input expression based the specified number of buckets using an Iceberg-specific hash.".to_string(), + Self::IcebergTruncate(_) => "Truncates the input expression to a specified width.".to_string(), + } + } + + fn arg_names(&self) -> &'static [&'static str] { + match self { + Self::Years => &["input"], + Self::Months => &["input"], + Self::Days => &["input"], + Self::Hours => &["input"], + Self::IcebergBucket(_) => &["input", "num_buckets"], + Self::IcebergTruncate(_) => &["input", "width"], + } + } } fn partitioning_helper daft_dsl::ExprRef>( diff --git a/src/daft-sql/src/modules/structs.rs b/src/daft-sql/src/modules/structs.rs index 66be42d8e3..17fae85c9e 100644 --- a/src/daft-sql/src/modules/structs.rs +++ b/src/daft-sql/src/modules/structs.rs @@ -34,4 +34,12 @@ impl SQLFunction for StructGet { _ => invalid_operation_err!("Expected 2 input args"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Extracts a field from a struct expression by name.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "field"] + } } diff --git a/src/daft-sql/src/modules/temporal.rs b/src/daft-sql/src/modules/temporal.rs index 58687724fa..840c278765 100644 --- a/src/daft-sql/src/modules/temporal.rs +++ b/src/daft-sql/src/modules/temporal.rs @@ -50,6 +50,16 @@ macro_rules! 
temporal { ), } } + fn docstrings(&self, _alias: &str) -> String { + format!( + "Extracts the {} component from a datetime expression.", + stringify!($fn_name).replace("dt_", "") + ) + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } }; } diff --git a/src/daft-sql/src/modules/utf8.rs b/src/daft-sql/src/modules/utf8.rs index 263a8bd9e7..c31879cd82 100644 --- a/src/daft-sql/src/modules/utf8.rs +++ b/src/daft-sql/src/modules/utf8.rs @@ -1,12 +1,22 @@ +use daft_core::array::ops::Utf8NormalizeOptions; use daft_dsl::{ - functions::{self, utf8::Utf8Expr}, + functions::{ + self, + utf8::{normalize, Utf8Expr}, + }, ExprRef, LiteralValue, }; +use daft_functions::{ + count_matches::{utf8_count_matches, CountMatchesFunction}, + tokenize::{tokenize_decode, tokenize_encode, TokenizeDecodeFunction, TokenizeEncodeFunction}, +}; use super::SQLModule; use crate::{ - ensure, error::SQLPlannerResult, functions::SQLFunction, invalid_operation_err, - unsupported_sql_err, + ensure, + error::{PlannerError, SQLPlannerResult}, + functions::{SQLFunction, SQLFunctionArguments}, + invalid_operation_err, unsupported_sql_err, }; pub struct SQLModuleUtf8; @@ -17,16 +27,18 @@ impl SQLModule for SQLModuleUtf8 { parent.add_fn("ends_with", EndsWith); parent.add_fn("starts_with", StartsWith); parent.add_fn("contains", Contains); - parent.add_fn("split", Split(true)); + parent.add_fn("split", Split(false)); // TODO add split variants // parent.add("split", f(Split(false))); - parent.add_fn("match", Match); - parent.add_fn("extract", Extract(0)); - parent.add_fn("extract_all", ExtractAll(0)); - parent.add_fn("replace", Replace(true)); + parent.add_fn("regexp_match", Match); + parent.add_fn("regexp_extract", Extract(0)); + parent.add_fn("regexp_extract_all", ExtractAll(0)); + parent.add_fn("regexp_replace", Replace(true)); + parent.add_fn("regexp_split", Split(true)); // TODO add replace variants // parent.add("replace", f(Replace(false))); parent.add_fn("length", Length); + parent.add_fn("length_bytes", LengthBytes); parent.add_fn("lower", Lower); parent.add_fn("upper", Upper); parent.add_fn("lstrip", Lstrip); @@ -39,13 +51,13 @@ impl SQLModule for SQLModuleUtf8 { parent.add_fn("rpad", Rpad); parent.add_fn("lpad", Lpad); parent.add_fn("repeat", Repeat); - parent.add_fn("like", Like); - parent.add_fn("ilike", Ilike); - parent.add_fn("substr", Substr); + parent.add_fn("to_date", ToDate("".to_string())); parent.add_fn("to_datetime", ToDatetime("".to_string(), None)); - // TODO add normalization variants. 
- // parent.add("normalize", f(Normalize(Default::default()))); + parent.add_fn("count_matches", SQLCountMatches); + parent.add_fn("normalize", SQLNormalize); + parent.add_fn("tokenize_encode", SQLTokenizeEncode); + parent.add_fn("tokenize_decode", SQLTokenizeDecode); } } @@ -60,6 +72,72 @@ impl SQLFunction for Utf8Expr { let inputs = self.args_to_expr_unnamed(inputs, planner)?; to_expr(self, &inputs) } + + fn docstrings(&self, _alias: &str) -> String { + match self { + Self::EndsWith => "Returns true if the string ends with the specified substring".to_string(), + Self::StartsWith => "Returns true if the string starts with the specified substring".to_string(), + Self::Contains => "Returns true if the string contains the specified substring".to_string(), + Self::Split(_) => "Splits the string by the specified delimiter and returns an array of substrings".to_string(), + Self::Match => "Returns true if the string matches the specified regular expression pattern".to_string(), + Self::Extract(_) => "Extracts the first substring that matches the specified regular expression pattern".to_string(), + Self::ExtractAll(_) => "Extracts all substrings that match the specified regular expression pattern".to_string(), + Self::Replace(_) => "Replaces all occurrences of a substring with a new string".to_string(), + Self::Like => "Returns true if the string matches the specified SQL LIKE pattern".to_string(), + Self::Ilike => "Returns true if the string matches the specified SQL LIKE pattern (case-insensitive)".to_string(), + Self::Length => "Returns the length of the string".to_string(), + Self::Lower => "Converts the string to lowercase".to_string(), + Self::Upper => "Converts the string to uppercase".to_string(), + Self::Lstrip => "Removes leading whitespace from the string".to_string(), + Self::Rstrip => "Removes trailing whitespace from the string".to_string(), + Self::Reverse => "Reverses the order of characters in the string".to_string(), + Self::Capitalize => "Capitalizes the first character of the string".to_string(), + Self::Left => "Returns the specified number of leftmost characters from the string".to_string(), + Self::Right => "Returns the specified number of rightmost characters from the string".to_string(), + Self::Find => "Returns the index of the first occurrence of a substring within the string".to_string(), + Self::Rpad => "Pads the string on the right side with the specified string until it reaches the specified length".to_string(), + Self::Lpad => "Pads the string on the left side with the specified string until it reaches the specified length".to_string(), + Self::Repeat => "Repeats the string the specified number of times".to_string(), + Self::Substr => "Returns a substring of the string starting at the specified position and length".to_string(), + Self::ToDate(_) => "Parses the string as a date using the specified format.".to_string(), + Self::ToDatetime(_, _) => "Parses the string as a datetime using the specified format.".to_string(), + Self::LengthBytes => "Returns the length of the string in bytes".to_string(), + Self::Normalize(_) => unimplemented!("Normalize not implemented"), + } + } + + fn arg_names(&self) -> &'static [&'static str] { + match self { + Self::EndsWith => &["string_input", "substring"], + Self::StartsWith => &["string_input", "substring"], + Self::Contains => &["string_input", "substring"], + Self::Split(_) => &["string_input", "delimiter"], + Self::Match => &["string_input", "pattern"], + Self::Extract(_) => &["string_input", "pattern"], + Self::ExtractAll(_) 
=> &["string_input", "pattern"], + Self::Replace(_) => &["string_input", "pattern", "replacement"], + Self::Like => &["string_input", "pattern"], + Self::Ilike => &["string_input", "pattern"], + Self::Length => &["string_input"], + Self::Lower => &["string_input"], + Self::Upper => &["string_input"], + Self::Lstrip => &["string_input"], + Self::Rstrip => &["string_input"], + Self::Reverse => &["string_input"], + Self::Capitalize => &["string_input"], + Self::Left => &["string_input", "length"], + Self::Right => &["string_input", "length"], + Self::Find => &["string_input", "substring"], + Self::Rpad => &["string_input", "length", "pad"], + Self::Lpad => &["string_input", "length", "pad"], + Self::Repeat => &["string_input", "count"], + Self::Substr => &["string_input", "start", "length"], + Self::ToDate(_) => &["string_input", "format"], + Self::ToDatetime(_, _) => &["string_input", "format"], + Self::LengthBytes => &["string_input"], + Self::Normalize(_) => unimplemented!("Normalize not implemented"), + } + } } fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult { @@ -78,19 +156,44 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult { ensure!(args.len() == 2, "contains takes exactly two arguments"); Ok(contains(args[0].clone(), args[1].clone())) } - Split(_) => { + Split(true) => { + ensure!(args.len() == 2, "split takes exactly two arguments"); + Ok(split(args[0].clone(), args[1].clone(), true)) + } + Split(false) => { ensure!(args.len() == 2, "split takes exactly two arguments"); Ok(split(args[0].clone(), args[1].clone(), false)) } Match => { - unsupported_sql_err!("match") - } - Extract(_) => { - unsupported_sql_err!("extract") - } - ExtractAll(_) => { - unsupported_sql_err!("extract_all") + ensure!(args.len() == 2, "regexp_match takes exactly two arguments"); + Ok(match_(args[0].clone(), args[1].clone())) } + Extract(_) => match args { + [input, pattern] => Ok(extract(input.clone(), pattern.clone(), 0)), + [input, pattern, idx] => { + let idx = idx.as_literal().and_then(|lit| lit.as_i64()).ok_or_else(|| { + PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {:?}", idx)) + })?; + + Ok(extract(input.clone(), pattern.clone(), idx as usize)) + } + _ => { + invalid_operation_err!("regexp_extract takes exactly two or three arguments") + } + }, + ExtractAll(_) => match args { + [input, pattern] => Ok(extract_all(input.clone(), pattern.clone(), 0)), + [input, pattern, idx] => { + let idx = idx.as_literal().and_then(|lit| lit.as_i64()).ok_or_else(|| { + PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {:?}", idx)) + })?; + + Ok(extract_all(input.clone(), pattern.clone(), idx as usize)) + } + _ => { + invalid_operation_err!("regexp_extract_all takes exactly two or three arguments") + } + }, Replace(_) => { ensure!(args.len() == 3, "replace takes exactly three arguments"); Ok(replace( @@ -101,10 +204,10 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult { )) } Like => { - unsupported_sql_err!("like") + unreachable!("like should be handled by the parser") } Ilike => { - unsupported_sql_err!("ilike") + unreachable!("ilike should be handled by the parser") } Length => { ensure!(args.len() == 1, "length takes exactly one argument"); @@ -163,8 +266,7 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult { Ok(repeat(args[0].clone(), args[1].clone())) } Substr => { - ensure!(args.len() == 3, "substr takes 
exactly three arguments"); - Ok(substr(args[0].clone(), args[1].clone(), args[2].clone())) + unreachable!("substr should be handled by the parser") } ToDate(_) => { ensure!(args.len() == 2, "to_date takes exactly two arguments"); @@ -195,3 +297,233 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult { } } } + +pub struct SQLCountMatches; + +impl TryFrom for CountMatchesFunction { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + let whole_words = args.try_get_named("whole_words")?.unwrap_or(false); + let case_sensitive = args.try_get_named("case_sensitive")?.unwrap_or(true); + + Ok(Self { + whole_words, + case_sensitive, + }) + } +} + +impl SQLFunction for SQLCountMatches { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, pattern] => { + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + Ok(utf8_count_matches(input, pattern, false, true)) + } + [input, pattern, args @ ..] => { + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + let args: CountMatchesFunction = + planner.plan_function_args(args, &["whole_words", "case_sensitive"], 0)?; + + Ok(utf8_count_matches( + input, + pattern, + args.whole_words, + args.case_sensitive, + )) + } + _ => Err(PlannerError::invalid_operation( + "Invalid arguments for count_matches: '{inputs:?}'", + )), + } + } +} + +pub struct SQLNormalize; + +impl TryFrom for Utf8NormalizeOptions { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + let remove_punct = args.try_get_named("remove_punct")?.unwrap_or(false); + let lowercase = args.try_get_named("lowercase")?.unwrap_or(false); + let nfd_unicode = args.try_get_named("nfd_unicode")?.unwrap_or(false); + let white_space = args.try_get_named("white_space")?.unwrap_or(false); + + Ok(Self { + remove_punct, + lowercase, + nfd_unicode, + white_space, + }) + } +} + +impl SQLFunction for SQLNormalize { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(normalize(input, Utf8NormalizeOptions::default())) + } + [input, args @ ..] 
=> { + let input = planner.plan_function_arg(input)?; + let args: Utf8NormalizeOptions = planner.plan_function_args( + args, + &["remove_punct", "lowercase", "nfd_unicode", "white_space"], + 0, + )?; + Ok(normalize(input, args)) + } + _ => invalid_operation_err!("Invalid arguments for normalize"), + } + } +} + +pub struct SQLTokenizeEncode; +impl TryFrom for TokenizeEncodeFunction { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + if args.get_named("io_config").is_some() { + return Err(PlannerError::invalid_operation( + "io_config argument is not yet supported for tokenize_encode", + )); + } + + let tokens_path = args.try_get_named("tokens_path")?.ok_or_else(|| { + PlannerError::invalid_operation("tokens_path argument is required for tokenize_encode") + })?; + + let pattern = args.try_get_named("pattern")?; + let special_tokens = args.try_get_named("special_tokens")?; + let use_special_tokens = args.try_get_named("use_special_tokens")?.unwrap_or(false); + + Ok(Self { + tokens_path, + pattern, + special_tokens, + use_special_tokens, + io_config: None, + }) + } +} + +impl SQLFunction for SQLTokenizeEncode { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, tokens_path] => { + let input = planner.plan_function_arg(input)?; + let tokens_path = planner.plan_function_arg(tokens_path)?; + let tokens_path = tokens_path + .as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| { + PlannerError::invalid_operation("tokens_path argument must be a string") + })?; + Ok(tokenize_encode(input, tokens_path, None, None, None, false)) + } + [input, args @ ..] => { + let input = planner.plan_function_arg(input)?; + let args: TokenizeEncodeFunction = planner.plan_function_args( + args, + &[ + "tokens_path", + "pattern", + "special_tokens", + "use_special_tokens", + ], + 1, // tokens_path can be named or positional + )?; + Ok(tokenize_encode( + input, + &args.tokens_path, + None, + args.pattern.as_deref(), + args.special_tokens.as_deref(), + args.use_special_tokens, + )) + } + _ => invalid_operation_err!("Invalid arguments for tokenize_encode"), + } + } +} + +pub struct SQLTokenizeDecode; +impl TryFrom for TokenizeDecodeFunction { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + if args.get_named("io_config").is_some() { + return Err(PlannerError::invalid_operation( + "io_config argument is not yet supported for tokenize_decode", + )); + } + + let tokens_path = args.try_get_named("tokens_path")?.ok_or_else(|| { + PlannerError::invalid_operation("tokens_path argument is required for tokenize_encode") + })?; + + let pattern = args.try_get_named("pattern")?; + let special_tokens = args.try_get_named("special_tokens")?; + + Ok(Self { + tokens_path, + pattern, + special_tokens, + io_config: None, + }) + } +} +impl SQLFunction for SQLTokenizeDecode { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, tokens_path] => { + let input = planner.plan_function_arg(input)?; + let tokens_path = planner.plan_function_arg(tokens_path)?; + let tokens_path = tokens_path + .as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| { + PlannerError::invalid_operation("tokens_path argument must be a string") + })?; + Ok(tokenize_decode(input, tokens_path, None, None, None)) + } + [input, args @ ..] 
=> { + let input = planner.plan_function_arg(input)?; + let args: TokenizeDecodeFunction = planner.plan_function_args( + args, + &["tokens_path", "pattern", "special_tokens"], + 1, // tokens_path can be named or positional + )?; + Ok(tokenize_decode( + input, + &args.tokens_path, + None, + args.pattern.as_deref(), + args.special_tokens.as_deref(), + )) + } + _ => invalid_operation_err!("Invalid arguments for tokenize_decode"), + } + } +} diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 2aefab9f96..1be5b724a9 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -12,8 +12,8 @@ use daft_plan::{LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{ ArrayElemTypeDef, BinaryOperator, CastKind, ExactNumberInfo, GroupByExpr, Ident, Query, - SelectItem, StructField, Subscript, TableWithJoins, TimezoneInfo, UnaryOperator, Value, - WildcardAdditionalOptions, + SelectItem, Statement, StructField, Subscript, TableWithJoins, TimezoneInfo, UnaryOperator, + Value, WildcardAdditionalOptions, }, dialect::GenericDialect, parser::{Parser, ParserOptions}, @@ -88,9 +88,18 @@ impl SQLPlanner { let statements = parser.parse_statements()?; - match statements.as_slice() { - [sqlparser::ast::Statement::Query(query)] => Ok(self.plan_query(query)?.build()), - other => unsupported_sql_err!("{}", other[0]), + match statements.len() { + 1 => Ok(self.plan_statement(&statements[0])?), + other => { + unsupported_sql_err!("Only exactly one SQL statement allowed, found {}", other) + } + } + } + + fn plan_statement(&mut self, statement: &Statement) -> SQLPlannerResult { + match statement { + Statement::Query(query) => Ok(self.plan_query(query)?.build()), + other => unsupported_sql_err!("{}", other), } } @@ -388,7 +397,19 @@ impl SQLPlanner { fn plan_relation(&self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult { match rel { - sqlparser::ast::TableFactor::Table { name, .. } => { + sqlparser::ast::TableFactor::Table { + name, + args: Some(args), + alias, + .. + } => { + let tbl_fn = name.0.first().unwrap().value.as_str(); + + self.plan_table_function(tbl_fn, args, alias) + } + sqlparser::ast::TableFactor::Table { + name, args: None, .. + } => { let table_name = name.to_string(); let plan = self .catalog @@ -489,9 +510,9 @@ impl SQLPlanner { .collect::>() }) .map_err(|e| e.into()); + } else { + Ok(vec![col("*")]) } - - Ok(vec![]) } _ => todo!(), } @@ -617,11 +638,30 @@ impl SQLPlanner { SQLExpr::Ceil { expr, .. } => Ok(ceil(self.plan_expr(expr)?)), SQLExpr::Floor { expr, .. } => Ok(floor(self.plan_expr(expr)?)), SQLExpr::Position { .. } => unsupported_sql_err!("POSITION"), - SQLExpr::Substring { .. } => unsupported_sql_err!("SUBSTRING"), + SQLExpr::Substring { + expr, + substring_from, + substring_for, + special: true, // We only support SUBSTRING(expr, start, length) syntax + } => { + let (Some(substring_from), Some(substring_for)) = (substring_from, substring_for) + else { + unsupported_sql_err!("SUBSTRING") + }; + + let expr = self.plan_expr(expr)?; + let start = self.plan_expr(substring_from)?; + let length = self.plan_expr(substring_for)?; + + Ok(daft_dsl::functions::utf8::substr(expr, start, length)) + } + SQLExpr::Substring { special: false, .. } => { + unsupported_sql_err!("`SUBSTRING(expr [FROM start] [FOR len])` syntax") + } SQLExpr::Trim { .. } => unsupported_sql_err!("TRIM"), SQLExpr::Overlay { .. } => unsupported_sql_err!("OVERLAY"), SQLExpr::Collate { .. 
} => unsupported_sql_err!("COLLATE"), - SQLExpr::Nested(_) => unsupported_sql_err!("NESTED"), + SQLExpr::Nested(e) => self.plan_expr(e), SQLExpr::IntroducedString { .. } => unsupported_sql_err!("INTRODUCED STRING"), SQLExpr::TypedString { data_type, value } => match data_type { sqlparser::ast::DataType::Date => Ok(to_date(lit(value.as_str()), "%Y-%m-%d")), @@ -700,7 +740,22 @@ impl SQLPlanner { } SQLExpr::Struct { .. } => unsupported_sql_err!("STRUCT"), SQLExpr::Named { .. } => unsupported_sql_err!("NAMED"), - SQLExpr::Dictionary(_) => unsupported_sql_err!("DICTIONARY"), + SQLExpr::Dictionary(dict) => { + let entries = dict + .iter() + .map(|entry| { + let key = entry.key.value.clone(); + let value = self.plan_expr(&entry.value)?; + let value = value.as_literal().ok_or_else(|| { + PlannerError::invalid_operation("Dictionary value is not a literal") + })?; + let struct_field = Field::new(key, value.get_type()); + Ok((struct_field, value.clone())) + }) + .collect::>()?; + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } SQLExpr::Map(_) => unsupported_sql_err!("MAP"), SQLExpr::Subscript { expr, subscript } => self.plan_subscript(expr, subscript.as_ref()), SQLExpr::Array(_) => unsupported_sql_err!("ARRAY"), diff --git a/src/daft-sql/src/python.rs b/src/daft-sql/src/python.rs index 9201b7fccc..b61d3fedd2 100644 --- a/src/daft-sql/src/python.rs +++ b/src/daft-sql/src/python.rs @@ -3,7 +3,32 @@ use daft_dsl::python::PyExpr; use daft_plan::{LogicalPlanBuilder, PyLogicalPlanBuilder}; use pyo3::prelude::*; -use crate::{catalog::SQLCatalog, planner::SQLPlanner}; +use crate::{catalog::SQLCatalog, functions::SQL_FUNCTIONS, planner::SQLPlanner}; + +#[pyclass] +pub struct SQLFunctionStub { + name: String, + docstring: String, + arg_names: Vec<&'static str>, +} + +#[pymethods] +impl SQLFunctionStub { + #[getter] + fn name(&self) -> PyResult { + Ok(self.name.clone()) + } + + #[getter] + fn docstring(&self) -> PyResult { + Ok(self.docstring.clone()) + } + + #[getter] + fn arg_names(&self) -> PyResult> { + Ok(self.arg_names.clone()) + } +} #[pyfunction] pub fn sql( @@ -22,6 +47,23 @@ pub fn sql_expr(sql: &str) -> PyResult { Ok(PyExpr { expr }) } +#[pyfunction] +pub fn list_sql_functions() -> Vec { + SQL_FUNCTIONS + .map + .keys() + .cloned() + .map(|name| { + let (docstring, args) = SQL_FUNCTIONS.docsmap.get(&name).unwrap(); + SQLFunctionStub { + name, + docstring: docstring.to_string(), + arg_names: args.to_vec(), + } + }) + .collect() +} + /// PyCatalog is the Python interface to the Catalog. #[pyclass(module = "daft.daft")] #[derive(Debug, Clone)] diff --git a/src/daft-sql/src/table_provider/mod.rs b/src/daft-sql/src/table_provider/mod.rs new file mode 100644 index 0000000000..453fa4965f --- /dev/null +++ b/src/daft-sql/src/table_provider/mod.rs @@ -0,0 +1,119 @@ +pub mod read_parquet; +use std::{collections::HashMap, sync::Arc}; + +use daft_plan::LogicalPlanBuilder; +use once_cell::sync::Lazy; +use read_parquet::ReadParquetFunction; +use sqlparser::ast::{TableAlias, TableFunctionArgs}; + +use crate::{ + error::SQLPlannerResult, + modules::config::expr_to_iocfg, + planner::{Relation, SQLPlanner}, + unsupported_sql_err, +}; + +pub(crate) static SQL_TABLE_FUNCTIONS: Lazy = Lazy::new(|| { + let mut functions = SQLTableFunctions::new(); + functions.add_fn("read_parquet", ReadParquetFunction); + #[cfg(feature = "python")] + functions.add_fn("read_deltalake", ReadDeltalakeFunction); + + functions +}); + +/// TODOs +/// - Use multimap for function variants. +/// - Add more functions.. 
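Usage sketch for the new `read_parquet` SQL table function registered above, matching the test added in `tests/sql/test_table_funcs.py`; `read_deltalake` follows the same shape with a URI and an optional `io_config`, per the implementation below.

import daft

df = daft.sql("SELECT * FROM read_parquet('tests/assets/parquet-data/mvp.parquet')").collect()

The new `list_sql_functions` binding exposes the docstrings and argument names added throughout this PR. A minimal sketch of dumping them, assuming the native-module path used by the python bindings above:

from daft.daft import list_sql_functions

for fn in list_sql_functions():
    print(f"{fn.name}({', '.join(fn.arg_names)}): {fn.docstring}")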
+pub struct SQLTableFunctions { + pub(crate) map: HashMap>, +} + +impl SQLTableFunctions { + /// Create a new [SQLFunctions] instance. + pub fn new() -> Self { + Self { + map: HashMap::new(), + } + } + /// Add a [FunctionExpr] to the [SQLFunctions] instance. + pub(crate) fn add_fn(&mut self, name: &str, func: F) { + self.map.insert(name.to_string(), Arc::new(func)); + } + + /// Get a function by name from the [SQLFunctions] instance. + pub(crate) fn get(&self, name: &str) -> Option<&Arc> { + self.map.get(name) + } +} + +impl SQLPlanner { + pub(crate) fn plan_table_function( + &self, + fn_name: &str, + args: &TableFunctionArgs, + alias: &Option, + ) -> SQLPlannerResult { + let fns = &SQL_TABLE_FUNCTIONS; + + let Some(func) = fns.get(fn_name) else { + unsupported_sql_err!("Function `{}` not found", fn_name); + }; + + let builder = func.plan(self, args)?; + let name = alias + .as_ref() + .map(|a| a.name.value.clone()) + .unwrap_or_else(|| fn_name.to_string()); + + Ok(Relation::new(builder, name)) + } +} + +pub(crate) trait SQLTableFunction: Send + Sync { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult; +} + +pub struct ReadDeltalakeFunction; + +#[cfg(feature = "python")] +impl SQLTableFunction for ReadDeltalakeFunction { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult { + let (uri, io_config) = match args.args.as_slice() { + [uri] => (uri, None), + [uri, io_config] => { + let args = planner.parse_function_args(&[io_config.clone()], &["io_config"], 0)?; + let io_config = args.get_named("io_config").map(expr_to_iocfg).transpose()?; + + (uri, io_config) + } + _ => unsupported_sql_err!("Expected one or two arguments"), + }; + let uri = planner.plan_function_arg(uri)?; + + let Some(uri) = uri.as_literal().and_then(|lit| lit.as_str()) else { + unsupported_sql_err!("Expected a string literal for the first argument"); + }; + + LogicalPlanBuilder::delta_scan(uri, io_config, true).map_err(From::from) + } +} + +#[cfg(not(feature = "python"))] +impl SQLTableFunction for ReadDeltalakeFunction { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult { + unsupported_sql_err!("`read_deltalake` function is not supported. 
Enable the `python` feature to use this function.") + } +} diff --git a/src/daft-sql/src/table_provider/read_parquet.rs b/src/daft-sql/src/table_provider/read_parquet.rs new file mode 100644 index 0000000000..36f84507a0 --- /dev/null +++ b/src/daft-sql/src/table_provider/read_parquet.rs @@ -0,0 +1,77 @@ +use daft_core::prelude::TimeUnit; +use daft_plan::{LogicalPlanBuilder, ParquetScanBuilder}; +use sqlparser::ast::TableFunctionArgs; + +use super::SQLTableFunction; +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::SQLFunctionArguments, + modules::config::expr_to_iocfg, + planner::SQLPlanner, +}; + +pub(super) struct ReadParquetFunction; + +impl TryFrom for ParquetScanBuilder { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + let glob_paths: String = args.try_get_positional(0)?.ok_or_else(|| { + PlannerError::invalid_operation("path is required for `read_parquet`") + })?; + let infer_schema = args.try_get_named("infer_schema")?.unwrap_or(true); + let coerce_int96_timestamp_unit = + args.try_get_named::("coerce_int96_timestamp_unit")?; + let coerce_int96_timestamp_unit: TimeUnit = coerce_int96_timestamp_unit + .as_deref() + .unwrap_or("nanoseconds") + .parse::() + .map_err(|_| { + PlannerError::invalid_argument("coerce_int96_timestamp_unit", "read_parquet") + })?; + let chunk_size = args.try_get_named("chunk_size")?; + let multithreaded = args.try_get_named("multithreaded")?.unwrap_or(true); + + let field_id_mapping = None; // TODO + let row_groups = None; // TODO + let schema = None; // TODO + let io_config = args.get_named("io_config").map(expr_to_iocfg).transpose()?; + + Ok(Self { + glob_paths: vec![glob_paths], + infer_schema, + coerce_int96_timestamp_unit, + field_id_mapping, + row_groups, + chunk_size, + io_config, + multithreaded, + schema, + }) + } +} + +impl SQLTableFunction for ReadParquetFunction { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult { + let builder: ParquetScanBuilder = planner.plan_function_args( + args.args.as_slice(), + &[ + "infer_schema", + "coerce_int96_timestamp_unit", + "chunk_size", + "multithreaded", + // "schema", + // "field_id_mapping", + // "row_groups", + "io_config", + ], + 1, // 1 positional argument (path) + )?; + + builder.finish().map_err(From::from) + } +} diff --git a/tests/actor_pool/test_pyactor_pool.py b/tests/actor_pool/test_pyactor_pool.py index e95feec9ed..f34d91bd7b 100644 --- a/tests/actor_pool/test_pyactor_pool.py +++ b/tests/actor_pool/test_pyactor_pool.py @@ -69,5 +69,7 @@ def test_pyactor_pool_not_enough_resources(): assert isinstance(runner, PyRunner) with pytest.raises(RuntimeError, match=f"Requested {float(cpu_count + 1)} CPUs but found only"): - with runner.actor_pool_context("my-pool", ResourceRequest(num_cpus=1), cpu_count + 1, projection) as _: + with runner.actor_pool_context( + "my-pool", ResourceRequest(num_cpus=1), ResourceRequest(), cpu_count + 1, projection + ) as _: pass diff --git a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py index 46db61d47e..ddd2c9b040 100644 --- a/tests/cookbook/test_write.py +++ b/tests/cookbook/test_write.py @@ -199,6 +199,26 @@ def test_parquet_write_multifile_with_partitioning(tmp_path, smaller_parquet_tar assert readback["y"] == [y % 2 for y in data["x"]] +def test_parquet_write_with_some_empty_partitions(tmp_path): + data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} + output_files = daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path) + + assert len(output_files) == 
3 + + read_back = daft.read_parquet(tmp_path.as_posix() + "/**/*.parquet").sort("x").to_pydict() + assert read_back == data + + +def test_parquet_partitioned_write_with_some_empty_partitions(tmp_path): + data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} + output_files = daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path, partition_cols=["x"]) + + assert len(output_files) == 3 + + read_back = daft.read_parquet(tmp_path.as_posix() + "/**/*.parquet").sort("x").to_pydict() + assert read_back == data + + def test_csv_write(tmp_path): df = daft.read_csv(COOKBOOK_DATA_CSV) @@ -262,3 +282,23 @@ def test_empty_csv_write_with_partitioning(tmp_path): assert len(pd_df) == 1 assert len(pd_df._preview.preview_partition) == 1 + + +def test_csv_write_with_some_empty_partitions(tmp_path): + data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} + output_files = daft.from_pydict(data).into_partitions(4).write_csv(tmp_path) + + assert len(output_files) == 3 + + read_back = daft.read_csv(tmp_path.as_posix() + "/**/*.csv").sort("x").to_pydict() + assert read_back == data + + +def test_csv_partitioned_write_with_some_empty_partitions(tmp_path): + data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} + output_files = daft.from_pydict(data).into_partitions(4).write_csv(tmp_path, partition_cols=["x"]) + + assert len(output_files) == 3 + + read_back = daft.read_csv(tmp_path.as_posix() + "/**/*.csv").sort("x").to_pydict() + assert read_back == data diff --git a/tests/dataframe/test_creation.py b/tests/dataframe/test_creation.py index c43751f7d3..4816c767a4 100644 --- a/tests/dataframe/test_creation.py +++ b/tests/dataframe/test_creation.py @@ -15,18 +15,12 @@ import pytest import daft -from daft import context from daft.api_annotations import APITypeError from daft.dataframe import DataFrame from daft.datatype import DataType from daft.utils import pyarrow_supports_fixed_shape_tensor from tests.conftest import UuidType -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) @@ -1061,7 +1055,7 @@ def test_create_dataframe_parquet_mismatched_schemas_no_pushdown(): assert df.to_pydict() == {"x": [1, 2, 3, 4, None, None, None, None]} -def test_minio_parquet_read_mismatched_schemas_with_pushdown(minio_io_config): +def test_create_dataframe_parquet_read_mismatched_schemas_with_pushdown(): # When we read files, we infer schema from the first file # Then when we read subsequent files, we want to be able to read the data still but add nulls for columns # that don't exist @@ -1085,7 +1079,7 @@ def test_minio_parquet_read_mismatched_schemas_with_pushdown(minio_io_config): assert df.to_pydict() == {"x": [1, 2, 3, 4, 5, 6, 7, 8], "y": [1, 2, 3, 4, None, None, None, None]} -def test_minio_parquet_read_mismatched_schemas_with_pushdown_no_rows_read(minio_io_config): +def test_create_dataframe_parquet_read_mismatched_schemas_with_pushdown_no_rows_read(): # When we read files, we infer schema from the first file # Then when we read subsequent files, we want to be able to read the data still but add nulls for columns # that don't exist diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index b0bdbf9df4..5e79acf698 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -53,6 +53,17 @@ def test_columns_after_join(make_df): assert set(joined_df2.schema().column_names()) == set(["A", "B"]) +def 
test_rename_join_keys_in_dataframe(make_df): + df1 = make_df({"A": [1, 2], "B": [2, 2]}) + + df2 = make_df({"A": [1, 2]}) + joined_df1 = df1.join(df2, left_on=["A", "B"], right_on=["A", "A"]) + joined_df2 = df1.join(df2, left_on=["B", "A"], right_on=["A", "A"]) + + assert set(joined_df1.schema().column_names()) == set(["A", "B"]) + assert set(joined_df2.schema().column_names()) == set(["A", "B"]) + + @pytest.mark.parametrize("n_partitions", [1, 2, 4]) @pytest.mark.parametrize( "join_strategy", diff --git a/tests/expressions/test_udf.py b/tests/expressions/test_udf.py index 2572eb1adc..5ac09387e0 100644 --- a/tests/expressions/test_udf.py +++ b/tests/expressions/test_udf.py @@ -4,7 +4,9 @@ import pyarrow as pa import pytest +import daft from daft import col +from daft.context import get_context, set_planning_config from daft.datatype import DataType from daft.expressions import Expression from daft.expressions.testing import expr_structurally_equal @@ -13,6 +15,21 @@ from daft.udf import udf +@pytest.fixture(scope="function", params=[False, True]) +def actor_pool_enabled(request): + if request.param and get_context().daft_execution_config.enable_native_executor: + pytest.skip("Native executor does not support stateful UDFs") + + original_config = get_context().daft_planning_config + try: + set_planning_config( + config=get_context().daft_planning_config.with_config_values(enable_actor_pool_projections=request.param) + ) + yield request.param + finally: + set_planning_config(config=original_config) + + def test_udf(): table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) @@ -30,8 +47,8 @@ def repeat_n(data, n): @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10]) -def test_class_udf(batch_size): - table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) +def test_class_udf(batch_size, actor_pool_enabled): + df = daft.from_pydict({"a": ["foo", "bar", "baz"]}) @udf(return_dtype=DataType.string(), batch_size=batch_size) class RepeatN: @@ -41,18 +58,21 @@ def __init__(self): def __call__(self, data): return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()]) + if actor_pool_enabled: + RepeatN = RepeatN.with_concurrency(1) + expr = RepeatN(col("a")) - field = expr._to_field(table.schema()) + field = expr._to_field(df.schema()) assert field.name == "a" assert field.dtype == DataType.string() - result = table.eval_expression_list([expr]) + result = df.select(expr) assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]} @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10]) -def test_class_udf_init_args(batch_size): - table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) +def test_class_udf_init_args(batch_size, actor_pool_enabled): + df = daft.from_pydict({"a": ["foo", "bar", "baz"]}) @udf(return_dtype=DataType.string(), batch_size=batch_size) class RepeatN: @@ -62,24 +82,27 @@ def __init__(self, initial_n: int = 2): def __call__(self, data): return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()]) + if actor_pool_enabled: + RepeatN = RepeatN.with_concurrency(1) + expr = RepeatN(col("a")) - field = expr._to_field(table.schema()) + field = expr._to_field(df.schema()) assert field.name == "a" assert field.dtype == DataType.string() - result = table.eval_expression_list([expr]) + result = df.select(expr) assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]} expr = RepeatN.with_init_args(initial_n=3)(col("a")) - field = expr._to_field(table.schema()) + field = expr._to_field(df.schema()) assert field.name == 
"a" assert field.dtype == DataType.string() - result = table.eval_expression_list([expr]) + result = df.select(expr) assert result.to_pydict() == {"a": ["foofoofoo", "barbarbar", "bazbazbaz"]} @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10]) -def test_class_udf_init_args_no_default(batch_size): - table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) +def test_class_udf_init_args_no_default(batch_size, actor_pool_enabled): + df = daft.from_pydict({"a": ["foo", "bar", "baz"]}) @udf(return_dtype=DataType.string(), batch_size=batch_size) class RepeatN: @@ -89,18 +112,21 @@ def __init__(self, initial_n): def __call__(self, data): return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()]) + if actor_pool_enabled: + RepeatN = RepeatN.with_concurrency(1) + with pytest.raises(ValueError, match="Cannot call StatefulUDF without initialization arguments."): RepeatN(col("a")) expr = RepeatN.with_init_args(initial_n=2)(col("a")) - field = expr._to_field(table.schema()) + field = expr._to_field(df.schema()) assert field.name == "a" assert field.dtype == DataType.string() - result = table.eval_expression_list([expr]) + result = df.select(expr) assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]} -def test_class_udf_init_args_bad_args(): +def test_class_udf_init_args_bad_args(actor_pool_enabled): @udf(return_dtype=DataType.string()) class RepeatN: def __init__(self, initial_n): @@ -109,10 +135,37 @@ def __init__(self, initial_n): def __call__(self, data): return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()]) + if actor_pool_enabled: + RepeatN = RepeatN.with_concurrency(1) + with pytest.raises(TypeError, match="missing a required argument: 'initial_n'"): RepeatN.with_init_args(wrong=5) +@pytest.mark.parametrize("concurrency", [1, 2, 4]) +@pytest.mark.parametrize("actor_pool_enabled", [True], indirect=True) +def test_stateful_udf_concurrency(concurrency, actor_pool_enabled): + df = daft.from_pydict({"a": ["foo", "bar", "baz"]}) + + @udf(return_dtype=DataType.string(), batch_size=1) + class RepeatN: + def __init__(self): + self.n = 2 + + def __call__(self, data): + return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()]) + + RepeatN = RepeatN.with_concurrency(concurrency) + + expr = RepeatN(col("a")) + field = expr._to_field(df.schema()) + assert field.name == "a" + assert field.dtype == DataType.string() + + result = df.select(expr) + assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]} + + def test_udf_kwargs(): table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) @@ -208,8 +261,8 @@ def full_udf(e_arg, val, kwarg_val=None, kwarg_ex=None): full_udf() -def test_class_udf_initialization_error(): - table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) +def test_class_udf_initialization_error(actor_pool_enabled): + df = daft.from_pydict({"a": ["foo", "bar", "baz"]}) @udf(return_dtype=DataType.string()) class IdentityWithInitError: @@ -219,9 +272,16 @@ def __init__(self): def __call__(self, data): return data + if actor_pool_enabled: + IdentityWithInitError = IdentityWithInitError.with_concurrency(1) + expr = IdentityWithInitError(col("a")) - with pytest.raises(RuntimeError, match="UDF INIT ERROR"): - table.eval_expression_list([expr]) + if actor_pool_enabled: + with pytest.raises(Exception): + df.select(expr).collect() + else: + with pytest.raises(RuntimeError, match="UDF INIT ERROR"): + df.select(expr).collect() def test_udf_equality(): diff --git a/tests/io/delta_lake/test_table_write.py 
b/tests/io/delta_lake/test_table_write.py index 6519e85d0f..7a65d835cb 100644 --- a/tests/io/delta_lake/test_table_write.py +++ b/tests/io/delta_lake/test_table_write.py @@ -185,6 +185,21 @@ def test_deltalake_write_ignore(tmp_path): assert read_delta.to_pyarrow_table() == df1.to_arrow() +def test_deltalake_write_with_empty_partition(tmp_path, base_table): + deltalake = pytest.importorskip("deltalake") + path = tmp_path / "some_table" + df = daft.from_arrow(base_table).into_partitions(4) + result = df.write_deltalake(str(path)) + result = result.to_pydict() + assert result["operation"] == ["ADD", "ADD", "ADD"] + assert result["rows"] == [1, 1, 1] + + read_delta = deltalake.DeltaTable(str(path)) + expected_schema = Schema.from_pyarrow_schema(read_delta.schema().to_pyarrow()) + assert df.schema() == expected_schema + assert read_delta.to_pyarrow_table() == base_table + + def check_equal_both_daft_and_delta_rs(df: daft.DataFrame, path: Path, sort_order: list[tuple[str, str]]): deltalake = pytest.importorskip("deltalake") @@ -256,6 +271,16 @@ def test_deltalake_write_partitioned_empty(tmp_path): check_equal_both_daft_and_delta_rs(df, path, [("int", "ascending")]) +def test_deltalake_write_partitioned_some_empty(tmp_path): + path = tmp_path / "some_table" + + df = daft.from_pydict({"int": [1, 2, 3, None], "string": ["foo", "foo", "bar", None]}).into_partitions(5) + + df.write_deltalake(str(path), partition_cols=["int"]) + + check_equal_both_daft_and_delta_rs(df, path, [("int", "ascending")]) + + def test_deltalake_write_partitioned_existing_table(tmp_path): path = tmp_path / "some_table" diff --git a/tests/io/iceberg/test_iceberg_writes.py b/tests/io/iceberg/test_iceberg_writes.py index aab68a0c5c..7f85447dba 100644 --- a/tests/io/iceberg/test_iceberg_writes.py +++ b/tests/io/iceberg/test_iceberg_writes.py @@ -209,6 +209,18 @@ def test_read_after_write_nested_fields(local_catalog): assert as_arrow == read_back.to_arrow() +def test_read_after_write_with_empty_partition(local_catalog): + df = daft.from_pydict({"x": [1, 2, 3]}).into_partitions(4) + as_arrow = df.to_arrow() + table = local_catalog.create_table("default.test", as_arrow.schema) + result = df.write_iceberg(table) + as_dict = result.to_pydict() + assert as_dict["operation"] == ["ADD", "ADD", "ADD"] + assert as_dict["rows"] == [1, 1, 1] + read_back = daft.read_iceberg(table) + assert as_arrow == read_back.to_arrow() + + @pytest.fixture def complex_table() -> tuple[pa.Table, Schema]: table = pa.table( diff --git a/tests/io/test_parquet.py b/tests/io/test_parquet.py index c30ae8da1a..be2be34fca 100644 --- a/tests/io/test_parquet.py +++ b/tests/io/test_parquet.py @@ -12,7 +12,6 @@ import pytest import daft -from daft import context from daft.daft import NativeStorageConfig, PythonStorageConfig, StorageConfig from daft.datatype import DataType, TimeUnit from daft.expressions import col @@ -21,10 +20,6 @@ from ..integration.io.conftest import minio_create_bucket -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) PYARROW_GE_11_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (11, 0, 0) PYARROW_GE_13_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (13, 0, 0) diff --git a/tests/series/test_cast.py b/tests/series/test_cast.py index cd0e74bdfa..b552aee20a 100644 --- a/tests/series/test_cast.py +++ b/tests/series/test_cast.py @@ -262,7 +262,7 @@ def 
test_cast_binary_to_fixed_size_binary(): assert casted.to_pylist() == [b"abc", b"def", None, b"bcd", None] -def test_cast_binary_to_fixed_size_binary_fails_with_variable_lengths(): +def test_cast_binary_to_fixed_size_binary_fails_with_variable_length(): data = [b"abc", b"def", None, b"bcd", None, b"long"] input = Series.from_pylist(data) @@ -368,7 +368,7 @@ def test_series_cast_python_to_list(dtype) -> None: assert t.datatype() == target_dtype assert len(t) == len(data) - assert t.list.lengths().to_pylist() == [3, 3, 3, 3, 2, 2, None] + assert t.list.length().to_pylist() == [3, 3, 3, 3, 2, 2, None] pydata = t.to_pylist() assert pydata[-1] is None @@ -397,7 +397,7 @@ def test_series_cast_python_to_fixed_size_list(dtype) -> None: assert t.datatype() == target_dtype assert len(t) == len(data) - assert t.list.lengths().to_pylist() == [3, 3, 3, 3, 3, 3, None] + assert t.list.length().to_pylist() == [3, 3, 3, 3, 3, 3, None] pydata = t.to_pylist() assert pydata[-1] is None @@ -426,7 +426,7 @@ def test_series_cast_python_to_embedding(dtype) -> None: assert t.datatype() == target_dtype assert len(t) == len(data) - assert t.list.lengths().to_pylist() == [3, 3, 3, 3, 3, 3, None] + assert t.list.length().to_pylist() == [3, 3, 3, 3, 3, 3, None] pydata = t.to_pylist() assert pydata[-1] is None @@ -448,7 +448,7 @@ def test_series_cast_list_to_embedding(dtype) -> None: assert t.datatype() == target_dtype assert len(t) == len(data) - assert t.list.lengths().to_pylist() == [3, 3, 3, None] + assert t.list.length().to_pylist() == [3, 3, 3, None] pydata = t.to_pylist() assert pydata[-1] is None @@ -473,7 +473,7 @@ def test_series_cast_numpy_to_image() -> None: assert t.datatype() == target_dtype assert len(t) == len(data) - assert t.list.lengths().to_pylist() == [12, 27, None] + assert t.list.length().to_pylist() == [12, 27, None] pydata = t.to_pylist() assert pydata[-1] is None @@ -495,7 +495,7 @@ def test_series_cast_numpy_to_image_infer_mode() -> None: assert t.datatype() == target_dtype assert len(t) == len(data) - assert t.list.lengths().to_pylist() == [4, 27, None] + assert t.list.length().to_pylist() == [4, 27, None] pydata = t.to_arrow().to_pylist() assert pydata[0] == { @@ -536,7 +536,7 @@ def test_series_cast_python_to_fixed_shape_image() -> None: assert t.datatype() == target_dtype assert len(t) == len(data) - assert t.list.lengths().to_pylist() == [12, 12, None] + assert t.list.length().to_pylist() == [12, 12, None] pydata = t.to_pylist() assert pydata[-1] is None diff --git a/tests/sql/test_exprs.py b/tests/sql/test_exprs.py new file mode 100644 index 0000000000..e3ae320094 --- /dev/null +++ b/tests/sql/test_exprs.py @@ -0,0 +1,69 @@ +import pytest + +import daft +from daft import col + + +def test_nested(): + df = daft.from_pydict( + { + "A": [1, 2, 3, 4], + "B": [1.5, 2.5, 3.5, 4.5], + "C": [True, True, False, False], + "D": [None, None, None, None], + } + ) + + actual = daft.sql("SELECT (A + 1) AS try_this FROM df").collect() + expected = df.select((daft.col("A") + 1).alias("try_this")).collect() + + assert actual.to_pydict() == expected.to_pydict() + + actual = daft.sql("SELECT *, (A + 1) AS try_this FROM df").collect() + expected = df.with_column("try_this", df["A"] + 1).collect() + + assert actual.to_pydict() == expected.to_pydict() + + +def test_hash_exprs(): + df = daft.from_pydict( + { + "a": ["foo", "bar", "baz", "qux"], + "ints": [1, 2, 3, 4], + "floats": [1.5, 2.5, 3.5, 4.5], + } + ) + + actual = ( + daft.sql(""" + SELECT + hash(a) as hash_a, + hash(a, 0) as hash_a_0, + hash(a, 
seed:=0) as hash_a_seed_0, + minhash(a, num_hashes:=10, ngram_size:= 100, seed:=10) as minhash_a, + minhash(a, num_hashes:=10, ngram_size:= 100) as minhash_a_no_seed, + FROM df + """) + .collect() + .to_pydict() + ) + + expected = ( + df.select( + col("a").hash().alias("hash_a"), + col("a").hash(0).alias("hash_a_0"), + col("a").hash(seed=0).alias("hash_a_seed_0"), + col("a").minhash(num_hashes=10, ngram_size=100, seed=10).alias("minhash_a"), + col("a").minhash(num_hashes=10, ngram_size=100).alias("minhash_a_no_seed"), + ) + .collect() + .to_pydict() + ) + + assert actual == expected + + with pytest.raises(Exception, match="Invalid arguments for minhash"): + daft.sql("SELECT minhash() as hash_a FROM df").collect() + + with pytest.raises(Exception, match="num_hashes is required"): + daft.sql("SELECT minhash(a) as hash_a FROM df").collect() diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index 12384e34e0..3802a3de20 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -199,3 +199,9 @@ def test_sql_function_raises_when_cant_get_frame(monkeypatch): monkeypatch.setattr("inspect.currentframe", lambda: None) with pytest.raises(DaftCoreException, match="Cannot get caller environment"): daft.sql("SELECT * FROM df") + + +def test_sql_multi_statement_sql_error(): + catalog = SQLCatalog({}) + with pytest.raises(Exception, match="one SQL statement allowed"): + daft.sql("SELECT * FROM df; SELECT * FROM df", catalog) diff --git a/tests/sql/test_table_funcs.py b/tests/sql/test_table_funcs.py new file mode 100644 index 0000000000..16fbf040c7 --- /dev/null +++ b/tests/sql/test_table_funcs.py @@ -0,0 +1,7 @@ +import daft + + +def test_sql_read_parquet(): + df = daft.sql("SELECT * FROM read_parquet('tests/assets/parquet-data/mvp.parquet')").collect() + expected = daft.read_parquet("tests/assets/parquet-data/mvp.parquet").collect() + assert df.to_pydict() == expected.to_pydict() diff --git a/tests/sql/test_utf8_exprs.py b/tests/sql/test_utf8_exprs.py new file mode 100644 index 0000000000..12b53a9ebc --- /dev/null +++ b/tests/sql/test_utf8_exprs.py @@ -0,0 +1,113 @@ +import daft +from daft import col + + +def test_utf8_exprs(): + df = daft.from_pydict( + { + "a": [ + "a", + "df_daft", + "foo", + "bar", + "baz", + "lorém", + "ipsum", + "dolor", + "sit", + "amet", + "😊", + "🌟", + "🎉", + "This is a longer with some words", + "THIS is ", + "", + ], + } + ) + + sql = """ + SELECT + ends_with(a, 'a') as ends_with_a, + starts_with(a, 'a') as starts_with_a, + contains(a, 'a') as contains_a, + split(a, ' ') as split_a, + regexp_match(a, 'ba.') as match_a, + regexp_extract(a, 'ba.') as extract_a, + regexp_extract_all(a, 'ba.') as extract_all_a, + regexp_replace(a, 'ba.', 'foo') as replace_a, + regexp_split(a, '\\s+') as regexp_split_a, + length(a) as length_a, + length_bytes(a) as length_bytes_a, + lower(a) as lower_a, + lstrip(a) as lstrip_a, + rstrip(a) as rstrip_a, + reverse(a) as reverse_a, + capitalize(a) as capitalize_a, + left(a, 4) as left_a, + right(a, 4) as right_a, + find(a, 'a') as find_a, + rpad(a, 10, '<') as rpad_a, + lpad(a, 10, '>') as lpad_a, + repeat(a, 2) as repeat_a, + a like 'a%' as like_a, + a ilike 'a%' as ilike_a, + substring(a, 1, 3) as substring_a, + count_matches(a, 'a') as count_matches_a_0, + count_matches(a, 'a', case_sensitive := true) as count_matches_a_1, + count_matches(a, 'a', case_sensitive := false, whole_words := false) as count_matches_a_2, + count_matches(a, 'a', case_sensitive := true, whole_words := true) as count_matches_a_3, + normalize(a) as 
normalize_a, + normalize(a, remove_punct:=true) as normalize_remove_punct_a, + normalize(a, remove_punct:=true, lowercase:=true) as normalize_remove_punct_lower_a, + normalize(a, remove_punct:=true, lowercase:=true, white_space:=true) as normalize_remove_punct_lower_ws_a, + tokenize_encode(a, 'r50k_base') as tokenize_encode_a, + tokenize_decode(tokenize_encode(a, 'r50k_base'), 'r50k_base') as tokenize_decode_a + FROM df + """ + actual = daft.sql(sql).collect() + expected = ( + df.select( + col("a").str.endswith("a").alias("ends_with_a"), + col("a").str.startswith("a").alias("starts_with_a"), + col("a").str.contains("a").alias("contains_a"), + col("a").str.split(" ").alias("split_a"), + col("a").str.match("ba.").alias("match_a"), + col("a").str.extract("ba.").alias("extract_a"), + col("a").str.extract_all("ba.").alias("extract_all_a"), + col("a").str.split(r"\s+", regex=True).alias("regexp_split_a"), + col("a").str.replace("ba.", "foo").alias("replace_a"), + col("a").str.length().alias("length_a"), + col("a").str.length_bytes().alias("length_bytes_a"), + col("a").str.lower().alias("lower_a"), + col("a").str.lstrip().alias("lstrip_a"), + col("a").str.rstrip().alias("rstrip_a"), + col("a").str.reverse().alias("reverse_a"), + col("a").str.capitalize().alias("capitalize_a"), + col("a").str.left(4).alias("left_a"), + col("a").str.right(4).alias("right_a"), + col("a").str.find("a").alias("find_a"), + col("a").str.rpad(10, "<").alias("rpad_a"), + col("a").str.lpad(10, ">").alias("lpad_a"), + col("a").str.repeat(2).alias("repeat_a"), + col("a").str.like("a%").alias("like_a"), + col("a").str.ilike("a%").alias("ilike_a"), + col("a").str.substr(1, 3).alias("substring_a"), + col("a").str.count_matches("a").alias("count_matches_a_0"), + col("a").str.count_matches("a", case_sensitive=True).alias("count_matches_a_1"), + col("a").str.count_matches("a", case_sensitive=False, whole_words=False).alias("count_matches_a_2"), + col("a").str.count_matches("a", case_sensitive=True, whole_words=True).alias("count_matches_a_3"), + col("a").str.normalize().alias("normalize_a"), + col("a").str.normalize(remove_punct=True).alias("normalize_remove_punct_a"), + col("a").str.normalize(remove_punct=True, lowercase=True).alias("normalize_remove_punct_lower_a"), + col("a") + .str.normalize(remove_punct=True, lowercase=True, white_space=True) + .alias("normalize_remove_punct_lower_ws_a"), + col("a").str.tokenize_encode("r50k_base").alias("tokenize_encode_a"), + col("a").str.tokenize_encode("r50k_base").str.tokenize_decode("r50k_base").alias("tokenize_decode_a"), + ) + .collect() + .to_pydict() + ) + actual = actual.to_pydict() + assert actual == expected diff --git a/tests/table/list/test_list_count_lengths.py b/tests/table/list/test_list_count_length.py similarity index 83% rename from tests/table/list/test_list_count_lengths.py rename to tests/table/list/test_list_count_length.py index 321219d6da..7e48f3088d 100644 --- a/tests/table/list/test_list_count_lengths.py +++ b/tests/table/list/test_list_count_length.py @@ -50,3 +50,12 @@ def test_fixed_list_count(fixed_table): result = fixed_table.eval_expression_list([col("col").list.count(CountMode.Null)]) assert result.to_pydict() == {"col": [0, 0, 1, 2, None]} + + +def test_list_length(fixed_table): + with pytest.warns(DeprecationWarning): + lengths_result = fixed_table.eval_expression_list([col("col").list.lengths()]) + length_result = fixed_table.eval_expression_list([col("col").list.length()]) + + assert lengths_result.to_pydict() == {"col": [2, 2, 2, 2, None]} + assert 
length_result.to_pydict() == {"col": [2, 2, 2, 2, None]} diff --git a/tests/test_resource_requests.py b/tests/test_resource_requests.py index 20a2aadf21..ec867aada7 100644 --- a/tests/test_resource_requests.py +++ b/tests/test_resource_requests.py @@ -8,7 +8,7 @@ import daft from daft import context, udf -from daft.context import get_context +from daft.context import get_context, set_planning_config from daft.daft import SystemInfo from daft.expressions import col from daft.internal.gpu import cuda_device_count @@ -127,6 +127,19 @@ def test_requesting_too_much_memory(): ### +@pytest.fixture(scope="function", params=[True]) +def enable_actor_pool(): + try: + original_config = get_context().daft_planning_config + + set_planning_config( + config=get_context().daft_planning_config.with_config_values(enable_actor_pool_projections=True) + ) + yield + finally: + set_planning_config(config=original_config) + + @udf(return_dtype=daft.DataType.int64()) def assert_resources(c, num_cpus=None, num_gpus=None, memory=None): assigned_resources = ray.get_runtime_context().get_assigned_resources() @@ -141,6 +154,24 @@ def assert_resources(c, num_cpus=None, num_gpus=None, memory=None): return c +@udf(return_dtype=daft.DataType.int64()) +class AssertResourcesStateful: + def __init__(self): + pass + + def __call__(self, c, num_cpus=None, num_gpus=None, memory=None): + assigned_resources = ray.get_runtime_context().get_assigned_resources() + + for resource, ray_resource_key in [(num_cpus, "CPU"), (num_gpus, "GPU"), (memory, "memory")]: + if resource is None: + assert ray_resource_key not in assigned_resources or assigned_resources[ray_resource_key] is None + else: + assert ray_resource_key in assigned_resources + assert assigned_resources[ray_resource_key] == resource + + return c + + RAY_VERSION_LT_2 = int(ray.__version__.split(".")[0]) < 2 @@ -187,6 +218,51 @@ def test_with_column_folded_rayrunner(): df.collect() +@pytest.mark.skipif( + RAY_VERSION_LT_2, reason="The ray.get_runtime_context().get_assigned_resources() was only added in Ray >= 2.0" +) +@pytest.mark.skipif(get_context().runner_config.name not in {"ray"}, reason="requires RayRunner to be in use") +def test_with_column_rayrunner_class(enable_actor_pool): + assert_resources = AssertResourcesStateful.with_concurrency(1) + + df = daft.from_pydict(DATA).repartition(2) + + assert_resources_parametrized = assert_resources.override_options(num_cpus=1, memory_bytes=1_000_000, num_gpus=None) + df = df.with_column( + "resources_ok", + assert_resources_parametrized(col("id"), num_cpus=1, num_gpus=None, memory=1_000_000), + ) + + df.collect() + + +@pytest.mark.skipif( + RAY_VERSION_LT_2, reason="The ray.get_runtime_context().get_assigned_resources() was only added in Ray >= 2.0" +) +@pytest.mark.skipif(get_context().runner_config.name not in {"ray"}, reason="requires RayRunner to be in use") +def test_with_column_folded_rayrunner_class(enable_actor_pool): + assert_resources = AssertResourcesStateful.with_concurrency(1) + + df = daft.from_pydict(DATA).repartition(2) + + df = df.with_column( + "no_requests", + assert_resources(col("id"), num_cpus=1), # UDFs have 1 CPU by default + ) + + assert_resources_1 = assert_resources.override_options(num_cpus=1, memory_bytes=5_000_000) + df = df.with_column( + "more_memory_request", + assert_resources_1(col("id"), num_cpus=1, memory=5_000_000), + ) + assert_resources_2 = assert_resources.override_options(num_cpus=1, memory_bytes=None) + df = df.with_column( + "more_cpu_request", + assert_resources_2(col("id"), 
num_cpus=1), + ) + df.collect() + + ### # GPU tests - can only run if machine has a GPU ### diff --git a/tutorials/delta_lake/2-distributed-batch-inference.ipynb b/tutorials/delta_lake/2-distributed-batch-inference.ipynb index 8462a74e0f..478980b9d9 100644 --- a/tutorials/delta_lake/2-distributed-batch-inference.ipynb +++ b/tutorials/delta_lake/2-distributed-batch-inference.ipynb @@ -138,7 +138,7 @@ "\n", "# Prune data\n", "df = df.limit(NUM_ROWS)\n", - "df = df.where(df[\"object\"].list.lengths() == 1)" + "df = df.where(df[\"object\"].list.length() == 1)" ] }, {