Commit bd65b05

refactor and add limit pushdown

colin-ho committed Mar 13, 2024
1 parent 0a85439 commit bd65b05

Showing 10 changed files with 189 additions and 131 deletions.
2 changes: 0 additions & 2 deletions daft/runners/partitioning.py
@@ -77,11 +77,9 @@ class TableReadSQLOptions:
"""Options for parsing SQL tables
Args:
predicate_sql: SQL predicate to apply to the table
predicate_expression: Expression predicate to apply to the table
"""

predicate_sql: str | None = None
predicate_expression: Expression | None = None


17 changes: 12 additions & 5 deletions daft/sql/sql_reader.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import warnings
from urllib.parse import urlparse

import pyarrow as pa
@@ -25,10 +26,15 @@ def __init__(
self.predicate = predicate

def read(self) -> pa.Table:
sql = self._construct_sql_query()
return self._execute_sql_query(sql)

def _construct_sql_query(self) -> str:
try:
sql = self._construct_sql_query(apply_limit=True)
return self._execute_sql_query(sql)
except RuntimeError:
warnings.warn("Failed to execute the query with a limit, attempting to read the entire table.")
sql = self._construct_sql_query(apply_limit=False)
return self._execute_sql_query(sql)

def _construct_sql_query(self, apply_limit: bool) -> str:
clauses = []
if self.projection is not None:
clauses.append(f"SELECT {', '.join(self.projection)}")

@@ -40,7 +46,7 @@ def _construct_sql_query(self) -> str:
if self.predicate is not None:
clauses.append(f"WHERE {self.predicate}")

if self.limit is not None:
if self.limit is not None and apply_limit is True:
clauses.append(f"LIMIT {self.limit}")

return "\n".join(clauses)

@@ -88,6 +94,7 @@ def _execute_sql_query_with_sqlalchemy(self, sql: str) -> pa.Table:
rows = result.fetchall()
pydict = {column_name: [row[i] for row in rows] for i, column_name in enumerate(result.keys())}

# TODO: Use type codes from cursor description to create pyarrow schema
return pa.Table.from_pydict(pydict)
except Exception as e:
raise RuntimeError(f"Failed to execute sql: {sql} with url: {self.url}, error: {e}") from e

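A minimal sketch of what the new reader behavior does end to end, using a stand-in query builder rather than Daft's actual SQLReader; the FROM wrapping shown here is an assumption, since that part of the method is elided from the diff:

    def construct_sql_query(table_sql, projection=None, predicate=None, limit=None, apply_limit=True):
        # Mirrors _construct_sql_query: each clause is emitted only when set,
        # and LIMIT is skipped entirely when apply_limit is False.
        clauses = [f"SELECT {', '.join(projection)}" if projection else "SELECT *"]
        clauses.append(f"FROM ({table_sql}) AS subquery")  # assumed wrapping, not shown in the diff
        if predicate is not None:
            clauses.append(f"WHERE {predicate}")
        if limit is not None and apply_limit:
            clauses.append(f"LIMIT {limit}")
        return "\n".join(clauses)

    # read() first attempts the query with the limit pushed down; on RuntimeError
    # it retries with apply_limit=False and reads the entire table.
    print(construct_sql_query("SELECT * FROM events", projection=["id"], limit=10))
    # SELECT id
    # FROM (SELECT * FROM events) AS subquery
    # LIMIT 10
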
190 changes: 109 additions & 81 deletions daft/sql/sql_scan.py
@@ -4,6 +4,7 @@
import math
import warnings
from collections.abc import Iterator
from enum import Enum, auto
from typing import Any

from daft.context import get_context
@@ -22,6 +23,11 @@
logger = logging.getLogger(__name__)


class PartitionBoundStrategy(Enum):
PERCENTILE = auto()
MIN_MAX = auto()


class SQLScanOperator(ScanOperator):
def __init__(
self,
@@ -39,25 +45,6 @@ def __init__(
self._num_partitions = num_partitions
self._schema = self._attempt_schema_read()

def _attempt_schema_read(self) -> Schema:
try:
pa_table = SQLReader(self.sql, self.url, limit=1).read()
schema = Schema.from_pyarrow_schema(pa_table.schema)
return schema
except Exception:
# If limit fails, try to read the entire table
pa_table = SQLReader(self.sql, self.url).read()
schema = Schema.from_pyarrow_schema(pa_table.schema)
return schema

def _get_num_rows(self) -> int:
pa_table = SQLReader(
self.sql,
self.url,
projection=["COUNT(*)"],
).read()
return pa_table.column(0)[0].as_py()

def schema(self) -> Schema:
return self._schema

@@ -73,67 +60,6 @@ def multiline_display(self) -> list[str]:
f"Schema = {self._schema}",
]

def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], str]:
if self._partition_col is None:
raise ValueError("Failed to get partition bounds: partition_col must be specified to partition the data.")

if not (
self._schema[self._partition_col].dtype._is_temporal_type()
or self._schema[self._partition_col].dtype._is_numeric_type()
):
raise ValueError(
f"Failed to get partition bounds: {self._partition_col} is not a numeric or temporal type, and cannot be used for partitioning."
)

try:
# try to get percentiles using percentile_cont
percentiles = [i / num_scan_tasks for i in range(1, num_scan_tasks)]
pa_table = SQLReader(
self.sql,
self.url,
projection=[
f"percentile_cont({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}"
for i, percentile in enumerate(percentiles)
],
).read()
bounds = [pa_table.column(i)[0].as_py() for i in range(num_scan_tasks - 1)]
return bounds, "percentile"

except Exception as e:
# if the above fails, use the min and max of the partition column
logger.info("Failed to get percentiles using percentile_cont, falling back to min and max. Error: %s", e)
try:
pa_table = SQLReader(
self.sql,
self.url,
projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"],
).read()
min_val = pa_table.column(0)[0].as_py()
max_val = pa_table.column(1)[0].as_py()
range_size = (max_val - min_val) / num_scan_tasks
return [min_val + range_size * i for i in range(1, num_scan_tasks)], "min_max"

except Exception:
raise ValueError(
f"Failed to get partition bounds from {self._partition_col}. Please ensure that the column exists, and is numeric or temporal."
)

def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]:
file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql))
return iter(
[
ScanTask.sql_scan_task(
url=self.url,
file_format=file_format_config,
schema=self._schema._schema,
num_rows=total_rows,
storage_config=self.storage_config,
size_bytes=math.ceil(total_size),
pushdowns=pushdowns,
)
]
)

def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
total_rows = self._get_num_rows()
estimate_row_size_bytes = self.schema().estimate_row_size_bytes()
@@ -154,7 +80,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
warnings.warn("Unable to partition the data using the specified column. Falling back to a single scan task.")
return self._single_scan_task(pushdowns, total_rows, total_size)

size_bytes = None if strategy == "min_max" else math.ceil(total_size / num_scan_tasks)
size_bytes = math.ceil(total_size / num_scan_tasks) if strategy == PartitionBoundStrategy.PERCENTILE else None
scan_tasks = []
for i in range(num_scan_tasks):
if i == 0:
@@ -188,3 +114,105 @@ def can_absorb_limit(self) -> bool:

def can_absorb_select(self) -> bool:
return False

def _attempt_schema_read(self) -> Schema:
pa_table = SQLReader(self.sql, self.url, limit=1).read()
schema = Schema.from_pyarrow_schema(pa_table.schema)
return schema

def _get_num_rows(self) -> int:
pa_table = SQLReader(
self.sql,
self.url,
projection=["COUNT(*)"],
).read()

if pa_table.num_rows != 1:
raise RuntimeError(
"Failed to get the number of rows: COUNT(*) query returned an unexpected number of rows."
)
if pa_table.num_columns != 1:
raise RuntimeError(
"Failed to get the number of rows: COUNT(*) query returned an unexpected number of columns."
)

return pa_table.column(0)[0].as_py()

def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]:
try:
# Try to get percentiles using percentile_cont
percentiles = [i / num_scan_tasks for i in range(1, num_scan_tasks)]
pa_table = SQLReader(
self.sql,
self.url,
projection=[
f"percentile_cont({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}"
for i, percentile in enumerate(percentiles)
],
).read()
return pa_table, PartitionBoundStrategy.PERCENTILE

except RuntimeError as e:
# If percentiles fails, use the min and max of the partition column
logger.info("Failed to get percentiles using percentile_cont, falling back to min and max. Error: %s", e)

pa_table = SQLReader(
self.sql,
self.url,
projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"],
).read()
return pa_table, PartitionBoundStrategy.MIN_MAX

def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], PartitionBoundStrategy]:
if self._partition_col is None:
raise ValueError("Failed to get partition bounds: partition_col must be specified to partition the data.")

if not (
self._schema[self._partition_col].dtype._is_temporal_type()
or self._schema[self._partition_col].dtype._is_numeric_type()
):
raise ValueError(
f"Failed to get partition bounds: {self._partition_col} is not a numeric or temporal type, and cannot be used for partitioning."
)

pa_table, strategy = self._attempt_partition_bounds_read(num_scan_tasks)

if pa_table.num_rows != 1:
raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.")

if strategy == PartitionBoundStrategy.PERCENTILE:
if pa_table.num_columns != num_scan_tasks - 1:
raise RuntimeError(
f"Failed to get partition bounds: expected {num_scan_tasks - 1} percentiles, but got {pa_table.num_columns}."
)

bounds = [pa_table.column(i)[0].as_py() for i in range(num_scan_tasks - 1)]

elif strategy == PartitionBoundStrategy.MIN_MAX:
if pa_table.num_columns != 2:
raise RuntimeError(
f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}."
)

min_val = pa_table.column(0)[0].as_py()
max_val = pa_table.column(1)[0].as_py()
range_size = (max_val - min_val) / num_scan_tasks
bounds = [min_val + range_size * i for i in range(1, num_scan_tasks)]

return bounds, strategy

def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]:
file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql))
return iter(
[
ScanTask.sql_scan_task(
url=self.url,
file_format=file_format_config,
schema=self._schema._schema,
num_rows=total_rows,
storage_config=self.storage_config,
size_bytes=math.ceil(total_size),
pushdowns=pushdowns,
)
]
)
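
The two partition-bound strategies are easy to see with concrete numbers; a sketch of the MIN_MAX arithmetic with made-up values, illustrating the approach rather than Daft's actual code path:

    def min_max_bounds(min_val, max_val, num_scan_tasks):
        # Mirrors the MIN_MAX fallback: split [min_val, max_val] into
        # num_scan_tasks equal ranges and return the interior boundaries.
        range_size = (max_val - min_val) / num_scan_tasks
        return [min_val + range_size * i for i in range(1, num_scan_tasks)]

    print(min_max_bounds(0, 100, 4))  # [25.0, 50.0, 75.0]

    # The PERCENTILE strategy instead asks the database for the same boundaries
    # directly, e.g. for num_scan_tasks=4:
    #   SELECT percentile_cont(0.25) WITHIN GROUP (ORDER BY col) AS bound_0,
    #          percentile_cont(0.5)  WITHIN GROUP (ORDER BY col) AS bound_1,
    #          percentile_cont(0.75) WITHIN GROUP (ORDER BY col) AS bound_2
    # and falls back to MIN_MAX when the dialect rejects percentile_cont.
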
28 changes: 21 additions & 7 deletions daft/table/table_io.py
@@ -226,24 +226,38 @@ def read_sql(
sql (str): SQL query to execute
url (str): URL to the database
schema (Schema): Daft schema to read the SQL query into
sql_options (TableReadSQLOptions, optional): SQL-specific configs to apply when reading the file
read_options (TableReadOptions, optional): Options for reading the file
Returns:
MicroPartition: MicroPartition from SQL query
"""

if sql_options.predicate_expression is not None:
# If the predicate can be translated to SQL, we can apply all pushdowns to the SQL query
predicate_sql = sql_options.predicate_expression._to_sql()
apply_pushdowns_to_sql = predicate_sql is not None

else:
# If we don't have a predicate, we can still apply the limit and projection to the SQL query
predicate_sql = None
apply_pushdowns_to_sql = True

pa_table = SQLReader(
sql,
url,
# TODO(Colin): Enable pushdowns
projection=None,
predicate=None,
limit=read_options.num_rows if apply_pushdowns_to_sql else None,
projection=read_options.column_names if apply_pushdowns_to_sql else None,
predicate=predicate_sql,
).read()
mp = MicroPartition.from_arrow(pa_table)

if sql_options.predicate_expression is not None:
mp = mp.filter(ExpressionsProjection([sql_options.predicate_expression]))
if len(mp) != 0 and apply_pushdowns_to_sql is False:
# If we have a non-empty table and we didn't apply pushdowns to SQL, we need to apply them in-memory
if sql_options.predicate_expression is not None:
mp = mp.filter(ExpressionsProjection([sql_options.predicate_expression]))

if read_options.num_rows is not None:
mp = mp.head(read_options.num_rows)
if read_options.num_rows is not None:
mp = mp.head(read_options.num_rows)

return _cast_table_to_schema(mp, read_options=read_options, schema=schema)

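
Condensed, the pushdown decision in read_sql reduces to one boolean; a sketch of the control flow with a hypothetical plan_read helper (not part of Daft's API) standing in for the inline logic:

    def plan_read(has_predicate, predicate_sql, num_rows, column_names):
        # With a predicate, pushdowns go to SQL only if it translated to SQL;
        # with no predicate, the limit and projection can always be pushed down.
        apply_pushdowns_to_sql = predicate_sql is not None if has_predicate else True
        return {
            "limit": num_rows if apply_pushdowns_to_sql else None,
            "projection": column_names if apply_pushdowns_to_sql else None,
            "predicate": predicate_sql,
            "filter_in_memory": not apply_pushdowns_to_sql,
        }

    # An untranslatable predicate forces a full read, then an in-memory filter and head():
    print(plan_read(has_predicate=True, predicate_sql=None, num_rows=100, column_names=["a"]))
    # {'limit': None, 'projection': None, 'predicate': None, 'filter_in_memory': True}
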
17 changes: 17 additions & 0 deletions docs/source/user_guide/basic_concepts/read-and-write.rst
@@ -72,6 +72,23 @@ For testing, or small datasets that fit in memory, you may also create DataFrame
To learn more, consult the API documentation on :ref:`creating DataFrames from in-memory data structures <df-io-in-memory>`.

From Databases
^^^^^^^^^^^^^^

Daft can also read data from a variety of databases, including PostgreSQL, MySQL, Trino, and SQLite, using the :func:`daft.read_sql` method.
In order to partition the data, you can specify a partition column, which will allow Daft to read the data in parallel.

.. code:: python

    # Read from a PostgreSQL database
    uri = "postgresql://user:password@host:port/database"
    df = daft.read_sql("SELECT * FROM my_table", uri)

    # Read with a partition column
    df = daft.read_sql("SELECT * FROM my_table", uri, partition_col="date")

To learn more, consult the API documentation on :func:`daft.read_sql`.

Writing Data
------------

27 changes: 9 additions & 18 deletions src/daft-dsl/src/expr.rs
@@ -574,28 +574,19 @@ impl Expr {
fn to_sql_inner<W: Write>(expr: &Expr, buffer: &mut W) -> io::Result<()> {
match expr {
Expr::Column(name) => write!(buffer, "{}", name),
Expr::Literal(lit) => {
if let Some(s) = lit.to_sql() {
write!(buffer, "{}", s)
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"Unsupported literal for SQL translation",
))
}
}
Expr::Literal(lit) => lit.display_sql(buffer),
Expr::Alias(inner, ..) => to_sql_inner(inner, buffer),
Expr::BinaryOp { op, left, right } => {
to_sql_inner(left, buffer)?;
let op = match op {
Operator::Eq => "=".to_string(),
Operator::NotEq => "!=".to_string(),
Operator::Lt => "<".to_string(),
Operator::LtEq => "<=".to_string(),
Operator::Gt => ">".to_string(),
Operator::GtEq => ">=".to_string(),
Operator::And => "AND".to_string(),
Operator::Or => "OR".to_string(),
Operator::Eq => "=",
Operator::NotEq => "!=",
Operator::Lt => "<",
Operator::LtEq => "<=",
Operator::Gt => ">",
Operator::GtEq => ">=",
Operator::And => "AND",
Operator::Or => "OR",
_ => {
return Err(io::Error::new(
io::ErrorKind::Other,
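
The recursion in to_sql_inner is compact; the same idea can be sketched in Python over a toy expression tuple, as an illustration of the approach rather than Daft's implementation (which streams into a writer and returns io::Result):

    def to_sql(expr):
        # expr is a tiny hypothetical AST: ("col", name), ("lit", value),
        # or ("binop", op, left, right).
        kind = expr[0]
        if kind == "col":
            return expr[1]
        if kind == "lit":
            value = expr[1]
            if isinstance(value, str):
                return f"'{value}'"
            if isinstance(value, (int, float)):
                return str(value)
            raise ValueError("Unsupported literal for SQL translation")
        if kind == "binop":
            _, op, left, right = expr
            ops = {"eq": "=", "not_eq": "!=", "lt": "<", "lt_eq": "<=",
                   "gt": ">", "gt_eq": ">=", "and": "AND", "or": "OR"}
            if op not in ops:
                raise ValueError(f"Unsupported operator for SQL translation: {op}")
            # Parenthesize for unambiguous grouping; a simplification made here.
            return f"({to_sql(left)} {ops[op]} {to_sql(right)})"
        raise ValueError(f"Unsupported expression: {expr!r}")

    print(to_sql(("binop", "and",
                  ("binop", "gt", ("col", "x"), ("lit", 5)),
                  ("binop", "eq", ("col", "name"), ("lit", "a")))))
    # ((x > 5) AND (name = 'a'))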