diff --git a/daft/io/_sql.py b/daft/io/_sql.py
index 2cb3e3520d..47b8710435 100644
--- a/daft/io/_sql.py
+++ b/daft/io/_sql.py
@@ -10,7 +10,7 @@
 from daft.datatype import DataType
 from daft.logical.builder import LogicalPlanBuilder
 from daft.sql.sql_connection import SQLConnection
-from daft.sql.sql_scan import SQLScanOperator
+from daft.sql.sql_scan import PartitionBoundStrategy, SQLScanOperator

 if TYPE_CHECKING:
     from sqlalchemy.engine import Connection
@@ -22,6 +22,7 @@ def read_sql(
     conn: Union[Callable[[], "Connection"], str],
     partition_col: Optional[str] = None,
     num_partitions: Optional[int] = None,
+    partition_bound_strategy: str = "min-max",
     disable_pushdowns_to_sql: bool = False,
     infer_schema: bool = True,
     infer_schema_length: int = 10,
@@ -35,6 +36,8 @@ def read_sql(
         partition_col (Optional[str]): Column to partition the data by, defaults to None
         num_partitions (Optional[int]): Number of partitions to read the data into, defaults to None, which lets Daft determine the number of partitions.
+            If specified, `partition_col` must also be specified.
+        partition_bound_strategy (str): Strategy to determine partition bounds, either "min-max" or "percentile", defaults to "min-max"
         disable_pushdowns_to_sql (bool): Whether to disable pushdowns to the SQL query, defaults to False
         infer_schema (bool): Whether to turn on schema inference, defaults to True. If set to False, the schema parameter must be provided.
         infer_schema_length (int): The number of rows to scan when inferring the schema, defaults to 10. If infer_schema is False, this parameter is ignored. Note that if Daft is able to use ConnectorX to infer the schema, this parameter is ignored as ConnectorX is an Arrow backed driver.
@@ -51,8 +54,9 @@ def read_sql(
     #. Partitioning: When `partition_col` is specified, the function partitions the query based on that column.
         You can define `num_partitions` or leave it to Daft to decide.
-        Daft calculates the specified column's percentiles to determine partitions (e.g., for `num_partitions=3`, it uses the 33rd and 66th percentiles).
-        If the database or column type lacks percentile calculation support, Daft partitions the query using equal ranges between the column's minimum and maximum values.
+        Daft uses the `partition_bound_strategy` parameter to determine the partitioning strategy:
+        - `min-max`: Daft calculates the minimum and maximum values of the specified column, then partitions the query using equal ranges between the minimum and maximum values.
+        - `percentile`: Daft calculates the specified column's percentiles via a `PERCENTILE_DISC` function to determine partitions (e.g., for `num_partitions=3`, it uses the 33rd and 66th percentiles).

     #. Execution: Daft executes SQL queries using `ConnectorX`_ or `SQLAlchemy`_,
@@ -113,6 +117,7 @@ def read_sql(
         schema,
         partition_col=partition_col,
         num_partitions=num_partitions,
+        partition_bound_strategy=PartitionBoundStrategy.from_str(partition_bound_strategy),
     )
     handle = ScanOperatorHandle.from_python_scan_operator(sql_operator)
     builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py
index 4d3156ae80..892475e676 100644
--- a/daft/sql/sql_scan.py
+++ b/daft/sql/sql_scan.py
@@ -3,7 +3,7 @@
 import logging
 import math
 import warnings
-from enum import Enum, auto
+from enum import Enum
 from typing import TYPE_CHECKING, Any

 from daft.context import get_context
@@ -31,8 +31,15 @@
 class PartitionBoundStrategy(Enum):
-    PERCENTILE = auto()
-    MIN_MAX = auto()
+    PERCENTILE = "percentile"
+    MIN_MAX = "min-max"
+
+    @classmethod
+    def from_str(cls, value: str) -> PartitionBoundStrategy:
+        try:
+            return cls(value.lower())
+        except ValueError:
+            raise ValueError(f"Invalid PartitionBoundStrategy: {value}, must be either 'percentile' or 'min-max'")


 class SQLScanOperator(ScanOperator):
@@ -47,6 +54,7 @@ def __init__(
         schema: dict[str, DataType] | None,
         partition_col: str | None = None,
         num_partitions: int | None = None,
+        partition_bound_strategy: PartitionBoundStrategy | None = None,
     ) -> None:
         super().__init__()
         self.sql = sql
@@ -55,6 +63,7 @@ def __init__(
         self._disable_pushdowns_to_sql = disable_pushdowns_to_sql
         self._partition_col = partition_col
         self._num_partitions = num_partitions
+        self._partition_bound_strategy = partition_bound_strategy
         self._schema = self._attempt_schema_read(infer_schema, infer_schema_length, schema)

     def schema(self) -> Schema:
@@ -79,7 +88,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
         if num_scan_tasks == 1 or self._partition_col is None:
             return self._single_scan_task(pushdowns, total_rows, total_size)

-        partition_bounds, strategy = self._get_partition_bounds_and_strategy(num_scan_tasks)
+        partition_bounds = self._get_partition_bounds(num_scan_tasks)
         partition_bounds_sql = [lit(bound)._to_sql() for bound in partition_bounds]

         if any(bound is None for bound in partition_bounds_sql):
@@ -88,7 +97,11 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
             )
             return self._single_scan_task(pushdowns, total_rows, total_size)

-        size_bytes = math.ceil(total_size / num_scan_tasks) if strategy == PartitionBoundStrategy.PERCENTILE else None
+        size_bytes = (
+            math.ceil(total_size / num_scan_tasks)
+            if self._partition_bound_strategy == PartitionBoundStrategy.PERCENTILE
+            else None
+        )
         scan_tasks = []
         for i in range(num_scan_tasks):
             left_clause = f"{self._partition_col} >= {partition_bounds_sql[i]}"
@@ -159,39 +172,7 @@ def _get_num_rows(self) -> int:

         return pa_table.column(0)[0].as_py()

-    def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]:
-        try:
-            # Try to get percentiles using percentile_disc.
-            # Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons.
-            percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
-            # Use the OVER clause for SQL Server
-            over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else ""
-            percentile_sql = self.conn.construct_sql_query(
-                self.sql,
-                projection=[
-                    f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}"
-                    for i, percentile in enumerate(percentiles)
-                ],
-                limit=1,
-            )
-            pa_table = self.conn.execute_sql_query(percentile_sql)
-            return pa_table, PartitionBoundStrategy.PERCENTILE
-
-        except RuntimeError as e:
-            # If percentiles fails, use the min and max of the partition column
-            logger.info(
-                "Failed to get percentiles using percentile_cont, falling back to min and max. Error: %s",
-                e,
-            )
-
-            min_max_sql = self.conn.construct_sql_query(
-                self.sql, projection=[f"MIN({self._partition_col}) as min", f"MAX({self._partition_col}) as max"]
-            )
-            pa_table = self.conn.execute_sql_query(min_max_sql)
-
-            return pa_table, PartitionBoundStrategy.MIN_MAX
-
-    def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], PartitionBoundStrategy]:
+    def _get_partition_bounds(self, num_scan_tasks: int) -> list[Any]:
         if self._partition_col is None:
             raise ValueError("Failed to get partition bounds: partition_col must be specified to partition the data.")
@@ -203,35 +184,57 @@ def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[
                 f"Failed to get partition bounds: {self._partition_col} is not a numeric or temporal type, and cannot be used for partitioning."
             )

-        pa_table, strategy = self._attempt_partition_bounds_read(num_scan_tasks)
+        if self._partition_bound_strategy == PartitionBoundStrategy.PERCENTILE:
+            try:
+                # Try to get percentiles using percentile_disc.
+                # Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons.
+                percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
+                # Use the OVER clause for SQL Server dialects
+                over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else ""
+                percentile_sql = self.conn.construct_sql_query(
+                    self.sql,
+                    projection=[
+                        f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}"
+                        for i, percentile in enumerate(percentiles)
+                    ],
+                    limit=1,
+                )
+                pa_table = self.conn.execute_sql_query(percentile_sql)

-        if pa_table.num_rows != 1:
-            raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.")
+                if pa_table.num_rows != 1:
+                    raise RuntimeError(f"Expected 1 row, but got {pa_table.num_rows}.")

-        if strategy == PartitionBoundStrategy.PERCENTILE:
-            if pa_table.num_columns != num_scan_tasks + 1:
-                raise RuntimeError(
-                    f"Failed to get partition bounds: expected {num_scan_tasks + 1} percentiles, but got {pa_table.num_columns}."
-                )
+                if pa_table.num_columns != num_scan_tasks + 1:
+                    raise RuntimeError(f"Expected {num_scan_tasks + 1} percentiles, but got {pa_table.num_columns}.")

-            pydict = Table.from_arrow(pa_table).to_pydict()
-            assert pydict.keys() == {f"bound_{i}" for i in range(num_scan_tasks + 1)}
-            bounds = [pydict[f"bound_{i}"][0] for i in range(num_scan_tasks + 1)]
+                pydict = Table.from_arrow(pa_table).to_pydict()
+                assert pydict.keys() == {f"bound_{i}" for i in range(num_scan_tasks + 1)}
+                return [pydict[f"bound_{i}"][0] for i in range(num_scan_tasks + 1)]

-        elif strategy == PartitionBoundStrategy.MIN_MAX:
-            if pa_table.num_columns != 2:
-                raise RuntimeError(
-                    f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}."
+            except Exception as e:
+                warnings.warn(
+                    f"Failed to calculate partition bounds for read_sql using percentile strategy: {str(e)}. "
+                    "Falling back to MIN_MAX strategy."
                 )
+                self._partition_bound_strategy = PartitionBoundStrategy.MIN_MAX

-            pydict = Table.from_arrow(pa_table).to_pydict()
-            assert pydict.keys() == {"min", "max"}
-            min_val = pydict["min"][0]
-            max_val = pydict["max"][0]
-            range_size = (max_val - min_val) / num_scan_tasks
-            bounds = [min_val + range_size * i for i in range(num_scan_tasks)] + [max_val]
+        # Either MIN_MAX was explicitly specified or percentile calculation failed
+        min_max_sql = self.conn.construct_sql_query(
+            self.sql, projection=[f"MIN({self._partition_col}) as min", f"MAX({self._partition_col}) as max"]
+        )
+        pa_table = self.conn.execute_sql_query(min_max_sql)

-        return bounds, strategy
+        if pa_table.num_rows != 1:
+            raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.")
+        if pa_table.num_columns != 2:
+            raise RuntimeError(f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}.")
+
+        pydict = Table.from_arrow(pa_table).to_pydict()
+        assert pydict.keys() == {"min", "max"}
+        min_val = pydict["min"][0]
+        max_val = pydict["max"][0]
+        range_size = (max_val - min_val) / num_scan_tasks
+        return [min_val + range_size * i for i in range(num_scan_tasks)] + [max_val]

     def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]:
         return iter([self._construct_scan_task(pushdowns, num_rows=total_rows, size_bytes=math.ceil(total_size))])
diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py
index 7983be00c7..85484419f0 100644
--- a/tests/integration/sql/test_sql.py
+++ b/tests/integration/sql/test_sql.py
@@ -32,7 +32,8 @@ def test_sql_create_dataframe_ok(test_db, pdf) -> None:

 @pytest.mark.integration()
 @pytest.mark.parametrize("num_partitions", [2, 3, 4])
-def test_sql_partitioned_read(test_db, num_partitions, pdf) -> None:
+@pytest.mark.parametrize("partition_bound_strategy", ["min-max", "percentile"])
+def test_sql_partitioned_read(test_db, num_partitions, partition_bound_strategy, pdf) -> None:
     row_size_bytes = daft.from_pandas(pdf).schema().estimate_row_size_bytes()
     num_rows_per_partition = len(pdf) / num_partitions
     with daft.execution_config_ctx(
@@ -40,7 +41,12 @@ def test_sql_partitioned_read(test_db, num_partitions, pdf) -> None:
         scan_tasks_min_size_bytes=0,
         scan_tasks_max_size_bytes=0,
     ):
-        df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id")
+        df = daft.read_sql(
+            f"SELECT * FROM {TEST_TABLE_NAME}",
+            test_db,
+            partition_col="id",
+            partition_bound_strategy=partition_bound_strategy,
+        )
         assert df.num_partitions() == num_partitions

         assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
@@ -48,8 +54,9 @@ def test_sql_partitioned_read(test_db, num_partitions, pdf) -> None:
 @pytest.mark.integration()
 @pytest.mark.parametrize("num_partitions", [1, 2, 3, 4])
 @pytest.mark.parametrize("partition_col", ["id", "float_col", "date_col", "date_time_col"])
+@pytest.mark.parametrize("partition_bound_strategy", ["min-max", "percentile"])
 def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
-    test_db, num_partitions, partition_col, pdf
+    test_db, num_partitions, partition_col, partition_bound_strategy, pdf
 ) -> None:
     with daft.execution_config_ctx(
         scan_tasks_min_size_bytes=0,
         scan_tasks_max_size_bytes=0,
@@ -60,6 +67,7 @@ def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
             test_db,
             partition_col=partition_col,
             num_partitions=num_partitions,
+            partition_bound_strategy=partition_bound_strategy,
         )
         assert df.num_partitions() == num_partitions
         assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
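The integration tests above exercise the new parameter end to end; as a quick reference, here is a minimal usage sketch of `daft.read_sql` with `partition_bound_strategy`. The connection URI and the `events` table are hypothetical placeholders, not part of this diff; the keyword arguments match the signature added in `daft/io/_sql.py`.

```python
import daft

# Hypothetical connection string and table name, for illustration only.
URI = "postgresql://user:password@localhost:5432/mydb"

# Default strategy: one MIN/MAX query, then equal-width ranges over the column.
df_min_max = daft.read_sql(
    "SELECT * FROM events",
    URI,
    partition_col="id",
    num_partitions=4,
    partition_bound_strategy="min-max",
)

# Percentile strategy: bounds come from PERCENTILE_DISC, so skewed columns split
# more evenly; per the diff, it falls back to min-max with a warning on failure.
df_percentile = daft.read_sql(
    "SELECT * FROM events",
    URI,
    partition_col="id",
    num_partitions=4,
    partition_bound_strategy="percentile",
)
```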
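For intuition about what the strategies compute, the sketch below replays the min-max bound arithmetic from `SQLScanOperator._get_partition_bounds` outside of Daft; the 0..100 column range and the four-way split are made-up numbers.

```python
# Mirrors the MIN_MAX arithmetic in _get_partition_bounds:
#   range_size = (max_val - min_val) / num_scan_tasks
#   bounds = [min_val + range_size * i for i in range(num_scan_tasks)] + [max_val]
def min_max_bounds(min_val, max_val, num_scan_tasks):
    range_size = (max_val - min_val) / num_scan_tasks
    return [min_val + range_size * i for i in range(num_scan_tasks)] + [max_val]

print(min_max_bounds(0, 100, 4))
# [0.0, 25.0, 50.0, 75.0, 100]
# Each scan task i then gets a lower-bound predicate `partition_col >= bounds[i]`
# (the `left_clause` built in to_scan_tasks), so the tasks cover contiguous,
# equal-width ranges of the partition column.
```

With the `percentile` strategy the same bound positions are instead filled from the `PERCENTILE_DISC` results, so bucket widths adapt to the data distribution rather than being equal-width.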