[CHORE] Expose read_sql partition bound strategy and default to min-max (#3246)

Currently, read_sql calculates partition bounds using the
`PERCENTILE_DISC` function. However, this does not scale well to large
tables, as it requires an expensive window + sort operation. A better
alternative is to take samples and estimate partition bounds from them, as
described in this issue:
#3245.

In the meantime, we should default to the min-max calculation instead,
which was previously the fallback option.
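For reference, a minimal usage sketch of the new parameter (the connection URI and table name below are placeholders, not part of this change):

```python
import daft

# Placeholder connection string; substitute your own database URI.
uri = "sqlite:///my_database.db"

# New default behaviour: bounds are equal ranges between MIN and MAX of the partition column.
df = daft.read_sql(
    "SELECT * FROM my_table",
    uri,
    partition_col="id",
    num_partitions=4,
    partition_bound_strategy="min-max",
)

# Opt back into the previous behaviour, which computes bounds with PERCENTILE_DISC.
df_percentile = daft.read_sql(
    "SELECT * FROM my_table",
    uri,
    partition_col="id",
    num_partitions=4,
    partition_bound_strategy="percentile",
)
```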

---------

Co-authored-by: Colin Ho <[email protected]>
Co-authored-by: Colin Ho <[email protected]>
3 people authored Nov 13, 2024
1 parent 1a4d259 commit bd4e944
Showing 3 changed files with 82 additions and 66 deletions.
11 changes: 8 additions & 3 deletions daft/io/_sql.py
@@ -10,7 +10,7 @@
from daft.datatype import DataType
from daft.logical.builder import LogicalPlanBuilder
from daft.sql.sql_connection import SQLConnection
from daft.sql.sql_scan import SQLScanOperator
from daft.sql.sql_scan import PartitionBoundStrategy, SQLScanOperator

if TYPE_CHECKING:
from sqlalchemy.engine import Connection
@@ -22,6 +22,7 @@ def read_sql(
conn: Union[Callable[[], "Connection"], str],
partition_col: Optional[str] = None,
num_partitions: Optional[int] = None,
partition_bound_strategy: str = "min-max",
disable_pushdowns_to_sql: bool = False,
infer_schema: bool = True,
infer_schema_length: int = 10,
@@ -35,6 +36,8 @@ def read_sql(
partition_col (Optional[str]): Column to partition the data by, defaults to None
num_partitions (Optional[int]): Number of partitions to read the data into,
defaults to None, which lets Daft determine the number of partitions.
If specified, `partition_col` must also be specified.
partition_bound_strategy (str): Strategy to determine partition bounds, either "min-max" or "percentile", defaults to "min-max"
disable_pushdowns_to_sql (bool): Whether to disable pushdowns to the SQL query, defaults to False
infer_schema (bool): Whether to turn on schema inference, defaults to True. If set to False, the schema parameter must be provided.
infer_schema_length (int): The number of rows to scan when inferring the schema, defaults to 10. If infer_schema is False, this parameter is ignored. Note that if Daft is able to use ConnectorX to infer the schema, this parameter is ignored as ConnectorX is an Arrow backed driver.
@@ -51,8 +54,9 @@ def read_sql(
#. Partitioning:
When `partition_col` is specified, the function partitions the query based on that column.
You can define `num_partitions` or leave it to Daft to decide.
Daft calculates the specified column's percentiles to determine partitions (e.g., for `num_partitions=3`, it uses the 33rd and 66th percentiles).
If the database or column type lacks percentile calculation support, Daft partitions the query using equal ranges between the column's minimum and maximum values.
Daft uses the `partition_bound_strategy` parameter to determine the partitioning strategy:
- `min-max`: Daft calculates the minimum and maximum values of the specified column, then partitions the query into equal ranges between them.
- `percentile`: Daft calculates the specified column's percentiles via a `PERCENTILE_DISC` function to determine partitions (e.g., for `num_partitions=3`, it uses the 33rd and 66th percentiles).
#. Execution:
Daft executes SQL queries using `ConnectorX <https://sfu-db.github.io/connector-x/intro.html>`_ or `SQLAlchemy <https://docs.sqlalchemy.org/en/20/orm/quickstart.html#create-an-engine>`_,
@@ -113,6 +117,7 @@ def read_sql(
schema,
partition_col=partition_col,
num_partitions=num_partitions,
partition_bound_strategy=PartitionBoundStrategy.from_str(partition_bound_strategy),
)
handle = ScanOperatorHandle.from_python_scan_operator(sql_operator)
builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
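To make the two strategies described in the updated `read_sql` docstring concrete, here is a rough standalone sketch (not Daft's internal API) of how each one derives bounds for a numeric partition column:

```python
# Illustrative only: mirrors the logic added in daft/sql/sql_scan.py, with
# simplified names and no database connection.

def min_max_bounds(min_val: float, max_val: float, num_partitions: int) -> list[float]:
    # "min-max": one MIN/MAX aggregation, then equal-width ranges between the two values.
    range_size = (max_val - min_val) / num_partitions
    return [min_val + range_size * i for i in range(num_partitions)] + [max_val]

def percentile_projection(partition_col: str, num_partitions: int) -> list[str]:
    # "percentile": exact percentile values computed server-side via PERCENTILE_DISC.
    percentiles = [i / num_partitions for i in range(num_partitions + 1)]
    return [
        f"percentile_disc({p}) WITHIN GROUP (ORDER BY {partition_col}) AS bound_{i}"
        for i, p in enumerate(percentiles)
    ]

print(min_max_bounds(0.0, 90.0, 3))
# [0.0, 30.0, 60.0, 90.0]
print(percentile_projection("id", 3))
# ['percentile_disc(0.0) WITHIN GROUP (ORDER BY id) AS bound_0', ...]
```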
123 changes: 63 additions & 60 deletions daft/sql/sql_scan.py
@@ -3,7 +3,7 @@
import logging
import math
import warnings
from enum import Enum, auto
from enum import Enum
from typing import TYPE_CHECKING, Any

from daft.context import get_context
@@ -31,8 +31,15 @@


class PartitionBoundStrategy(Enum):
PERCENTILE = auto()
MIN_MAX = auto()
PERCENTILE = "percentile"
MIN_MAX = "min-max"

@classmethod
def from_str(cls, value: str) -> PartitionBoundStrategy:
try:
return cls(value.lower())
except ValueError:
raise ValueError(f"Invalid PartitionBoundStrategy: {value}, must be either 'percentile' or 'min-max'")


class SQLScanOperator(ScanOperator):
@@ -47,6 +54,7 @@ def __init__(
schema: dict[str, DataType] | None,
partition_col: str | None = None,
num_partitions: int | None = None,
partition_bound_strategy: PartitionBoundStrategy | None = None,
) -> None:
super().__init__()
self.sql = sql
@@ -55,6 +63,7 @@ def __init__(
self._disable_pushdowns_to_sql = disable_pushdowns_to_sql
self._partition_col = partition_col
self._num_partitions = num_partitions
self._partition_bound_strategy = partition_bound_strategy
self._schema = self._attempt_schema_read(infer_schema, infer_schema_length, schema)

def schema(self) -> Schema:
@@ -79,7 +88,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
if num_scan_tasks == 1 or self._partition_col is None:
return self._single_scan_task(pushdowns, total_rows, total_size)

partition_bounds, strategy = self._get_partition_bounds_and_strategy(num_scan_tasks)
partition_bounds = self._get_partition_bounds(num_scan_tasks)
partition_bounds_sql = [lit(bound)._to_sql() for bound in partition_bounds]

if any(bound is None for bound in partition_bounds_sql):
@@ -88,7 +97,11 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
)
return self._single_scan_task(pushdowns, total_rows, total_size)

size_bytes = math.ceil(total_size / num_scan_tasks) if strategy == PartitionBoundStrategy.PERCENTILE else None
size_bytes = (
math.ceil(total_size / num_scan_tasks)
if self._partition_bound_strategy == PartitionBoundStrategy.PERCENTILE
else None
)
scan_tasks = []
for i in range(num_scan_tasks):
left_clause = f"{self._partition_col} >= {partition_bounds_sql[i]}"
@@ -159,39 +172,7 @@ def _get_num_rows(self) -> int:

return pa_table.column(0)[0].as_py()

def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]:
try:
# Try to get percentiles using percentile_disc.
# Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons.
percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
# Use the OVER clause for SQL Server
over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else ""
percentile_sql = self.conn.construct_sql_query(
self.sql,
projection=[
f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}"
for i, percentile in enumerate(percentiles)
],
limit=1,
)
pa_table = self.conn.execute_sql_query(percentile_sql)
return pa_table, PartitionBoundStrategy.PERCENTILE

except RuntimeError as e:
# If percentiles fails, use the min and max of the partition column
logger.info(
"Failed to get percentiles using percentile_cont, falling back to min and max. Error: %s",
e,
)

min_max_sql = self.conn.construct_sql_query(
self.sql, projection=[f"MIN({self._partition_col}) as min", f"MAX({self._partition_col}) as max"]
)
pa_table = self.conn.execute_sql_query(min_max_sql)

return pa_table, PartitionBoundStrategy.MIN_MAX

def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], PartitionBoundStrategy]:
def _get_partition_bounds(self, num_scan_tasks: int) -> list[Any]:
if self._partition_col is None:
raise ValueError("Failed to get partition bounds: partition_col must be specified to partition the data.")

@@ -203,35 +184,57 @@ def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[
f"Failed to get partition bounds: {self._partition_col} is not a numeric or temporal type, and cannot be used for partitioning."
)

pa_table, strategy = self._attempt_partition_bounds_read(num_scan_tasks)
if self._partition_bound_strategy == PartitionBoundStrategy.PERCENTILE:
try:
# Try to get percentiles using percentile_disc.
# Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons.
percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
# Use the OVER clause for SQL Server dialects
over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else ""
percentile_sql = self.conn.construct_sql_query(
self.sql,
projection=[
f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}"
for i, percentile in enumerate(percentiles)
],
limit=1,
)
pa_table = self.conn.execute_sql_query(percentile_sql)

if pa_table.num_rows != 1:
raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.")
if pa_table.num_rows != 1:
raise RuntimeError(f"Expected 1 row, but got {pa_table.num_rows}.")

if strategy == PartitionBoundStrategy.PERCENTILE:
if pa_table.num_columns != num_scan_tasks + 1:
raise RuntimeError(
f"Failed to get partition bounds: expected {num_scan_tasks + 1} percentiles, but got {pa_table.num_columns}."
)
if pa_table.num_columns != num_scan_tasks + 1:
raise RuntimeError(f"Expected {num_scan_tasks + 1} percentiles, but got {pa_table.num_columns}.")

pydict = Table.from_arrow(pa_table).to_pydict()
assert pydict.keys() == {f"bound_{i}" for i in range(num_scan_tasks + 1)}
bounds = [pydict[f"bound_{i}"][0] for i in range(num_scan_tasks + 1)]
pydict = Table.from_arrow(pa_table).to_pydict()
assert pydict.keys() == {f"bound_{i}" for i in range(num_scan_tasks + 1)}
return [pydict[f"bound_{i}"][0] for i in range(num_scan_tasks + 1)]

elif strategy == PartitionBoundStrategy.MIN_MAX:
if pa_table.num_columns != 2:
raise RuntimeError(
f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}."
except Exception as e:
warnings.warn(
f"Failed to calculate partition bounds for read_sql using percentile strategy: {str(e)}. "
"Falling back to MIN_MAX strategy."
)
self._partition_bound_strategy = PartitionBoundStrategy.MIN_MAX

pydict = Table.from_arrow(pa_table).to_pydict()
assert pydict.keys() == {"min", "max"}
min_val = pydict["min"][0]
max_val = pydict["max"][0]
range_size = (max_val - min_val) / num_scan_tasks
bounds = [min_val + range_size * i for i in range(num_scan_tasks)] + [max_val]
# Either MIN_MAX was explicitly specified or percentile calculation failed
min_max_sql = self.conn.construct_sql_query(
self.sql, projection=[f"MIN({self._partition_col}) as min", f"MAX({self._partition_col}) as max"]
)
pa_table = self.conn.execute_sql_query(min_max_sql)

return bounds, strategy
if pa_table.num_rows != 1:
raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.")
if pa_table.num_columns != 2:
raise RuntimeError(f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}.")

pydict = Table.from_arrow(pa_table).to_pydict()
assert pydict.keys() == {"min", "max"}
min_val = pydict["min"][0]
max_val = pydict["max"][0]
range_size = (max_val - min_val) / num_scan_tasks
return [min_val + range_size * i for i in range(num_scan_tasks)] + [max_val]

def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]:
return iter([self._construct_scan_task(pushdowns, num_rows=total_rows, size_bytes=math.ceil(total_size))])
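For context on how `to_scan_tasks` consumes the bounds returned above, here is a hedged sketch of turning bounds into per-partition predicates. The half-open ranges (closed on the right only for the last partition) and the helper name are illustrative assumptions, not Daft's internal API:

```python
# Illustrative assumption: each partition reads [bounds[i], bounds[i+1]), with the
# last partition closed on the right so the maximum value is not dropped.

def partition_predicates(partition_col: str, bounds: list[float]) -> list[str]:
    num_partitions = len(bounds) - 1
    predicates = []
    for i in range(num_partitions):
        left = f"{partition_col} >= {bounds[i]}"
        if i == num_partitions - 1:
            right = f"{partition_col} <= {bounds[i + 1]}"
        else:
            right = f"{partition_col} < {bounds[i + 1]}"
        predicates.append(f"({left}) AND ({right})")
    return predicates

print(partition_predicates("id", [0.0, 30.0, 60.0, 90.0]))
# ['(id >= 0.0) AND (id < 30.0)', '(id >= 30.0) AND (id < 60.0)', '(id >= 60.0) AND (id <= 90.0)']
```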
14 changes: 11 additions & 3 deletions tests/integration/sql/test_sql.py
@@ -32,24 +32,31 @@ def test_sql_create_dataframe_ok(test_db, pdf) -> None:

@pytest.mark.integration()
@pytest.mark.parametrize("num_partitions", [2, 3, 4])
def test_sql_partitioned_read(test_db, num_partitions, pdf) -> None:
@pytest.mark.parametrize("partition_bound_strategy", ["min-max", "percentile"])
def test_sql_partitioned_read(test_db, num_partitions, partition_bound_strategy, pdf) -> None:
row_size_bytes = daft.from_pandas(pdf).schema().estimate_row_size_bytes()
num_rows_per_partition = len(pdf) / num_partitions
with daft.execution_config_ctx(
read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition),
scan_tasks_min_size_bytes=0,
scan_tasks_max_size_bytes=0,
):
df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id")
df = daft.read_sql(
f"SELECT * FROM {TEST_TABLE_NAME}",
test_db,
partition_col="id",
partition_bound_strategy=partition_bound_strategy,
)
assert df.num_partitions() == num_partitions
assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")


@pytest.mark.integration()
@pytest.mark.parametrize("num_partitions", [1, 2, 3, 4])
@pytest.mark.parametrize("partition_col", ["id", "float_col", "date_col", "date_time_col"])
@pytest.mark.parametrize("partition_bound_strategy", ["min-max", "percentile"])
def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
test_db, num_partitions, partition_col, pdf
test_db, num_partitions, partition_col, partition_bound_strategy, pdf
) -> None:
with daft.execution_config_ctx(
scan_tasks_min_size_bytes=0,
@@ -60,6 +67,7 @@ def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
test_db,
partition_col=partition_col,
num_partitions=num_partitions,
partition_bound_strategy=partition_bound_strategy,
)
assert df.num_partitions() == num_partitions
assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
