From bd65b050a20e75913c4f643cbab47a7d994c59cc Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 13 Mar 2024 14:44:32 -0700 Subject: [PATCH] refactor and add limit pushdown --- daft/runners/partitioning.py | 2 - daft/sql/sql_reader.py | 17 +- daft/sql/sql_scan.py | 190 ++++++++++-------- daft/table/table_io.py | 28 ++- .../basic_concepts/read-and-write.rst | 17 ++ src/daft-dsl/src/expr.rs | 27 +-- src/daft-dsl/src/lit.rs | 23 ++- src/daft-micropartition/src/micropartition.rs | 6 - src/daft-micropartition/src/python.rs | 4 +- tests/integration/sql/test_sql.py | 6 + 10 files changed, 189 insertions(+), 131 deletions(-) diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index 2c42dae7ca..35f3f0fb6d 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -77,11 +77,9 @@ class TableReadSQLOptions: """Options for parsing SQL tables Args: - predicate_sql: SQL predicate to apply to the table predicate_expression: Expression predicate to apply to the table """ - predicate_sql: str | None = None predicate_expression: Expression | None = None diff --git a/daft/sql/sql_reader.py b/daft/sql/sql_reader.py index 111af64622..367fdcaeaf 100644 --- a/daft/sql/sql_reader.py +++ b/daft/sql/sql_reader.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import warnings from urllib.parse import urlparse import pyarrow as pa @@ -25,10 +26,15 @@ def __init__( self.predicate = predicate def read(self) -> pa.Table: - sql = self._construct_sql_query() - return self._execute_sql_query(sql) - - def _construct_sql_query(self) -> str: + try: + sql = self._construct_sql_query(apply_limit=True) + return self._execute_sql_query(sql) + except RuntimeError: + warnings.warn("Failed to execute the query with a limit, attempting to read the entire table.") + sql = self._construct_sql_query(apply_limit=False) + return self._execute_sql_query(sql) + + def _construct_sql_query(self, apply_limit: bool) -> str: clauses = [] if self.projection is not None: clauses.append(f"SELECT {', '.join(self.projection)}") @@ -40,7 +46,7 @@ def _construct_sql_query(self) -> str: if self.predicate is not None: clauses.append(f"WHERE {self.predicate}") - if self.limit is not None: + if self.limit is not None and apply_limit is True: clauses.append(f"LIMIT {self.limit}") return "\n".join(clauses) @@ -88,6 +94,7 @@ def _execute_sql_query_with_sqlalchemy(self, sql: str) -> pa.Table: rows = result.fetchall() pydict = {column_name: [row[i] for row in rows] for i, column_name in enumerate(result.keys())} + # TODO: Use type codes from cursor description to create pyarrow schema return pa.Table.from_pydict(pydict) except Exception as e: raise RuntimeError(f"Failed to execute sql: {sql} with url: {self.url}, error: {e}") from e diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 7ac4502419..93e1b0ab56 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -4,6 +4,7 @@ import math import warnings from collections.abc import Iterator +from enum import Enum, auto from typing import Any from daft.context import get_context @@ -22,6 +23,11 @@ logger = logging.getLogger(__name__) +class PartitionBoundStrategy(Enum): + PERCENTILE = auto() + MIN_MAX = auto() + + class SQLScanOperator(ScanOperator): def __init__( self, @@ -39,25 +45,6 @@ def __init__( self._num_partitions = num_partitions self._schema = self._attempt_schema_read() - def _attempt_schema_read(self) -> Schema: - try: - pa_table = SQLReader(self.sql, self.url, limit=1).read() - schema = 
Schema.from_pyarrow_schema(pa_table.schema) - return schema - except Exception: - # If limit fails, try to read the entire table - pa_table = SQLReader(self.sql, self.url).read() - schema = Schema.from_pyarrow_schema(pa_table.schema) - return schema - - def _get_num_rows(self) -> int: - pa_table = SQLReader( - self.sql, - self.url, - projection=["COUNT(*)"], - ).read() - return pa_table.column(0)[0].as_py() - def schema(self) -> Schema: return self._schema @@ -73,67 +60,6 @@ def multiline_display(self) -> list[str]: f"Schema = {self._schema}", ] - def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], str]: - if self._partition_col is None: - raise ValueError("Failed to get partition bounds: partition_col must be specified to partition the data.") - - if not ( - self._schema[self._partition_col].dtype._is_temporal_type() - or self._schema[self._partition_col].dtype._is_numeric_type() - ): - raise ValueError( - f"Failed to get partition bounds: {self._partition_col} is not a numeric or temporal type, and cannot be used for partitioning." - ) - - try: - # try to get percentiles using percentile_cont - percentiles = [i / num_scan_tasks for i in range(1, num_scan_tasks)] - pa_table = SQLReader( - self.sql, - self.url, - projection=[ - f"percentile_cont({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}" - for i, percentile in enumerate(percentiles) - ], - ).read() - bounds = [pa_table.column(i)[0].as_py() for i in range(num_scan_tasks - 1)] - return bounds, "percentile" - - except Exception as e: - # if the above fails, use the min and max of the partition column - logger.info("Failed to get percentiles using percentile_cont, falling back to min and max. Error: %s", e) - try: - pa_table = SQLReader( - self.sql, - self.url, - projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"], - ).read() - min_val = pa_table.column(0)[0].as_py() - max_val = pa_table.column(1)[0].as_py() - range_size = (max_val - min_val) / num_scan_tasks - return [min_val + range_size * i for i in range(1, num_scan_tasks)], "min_max" - - except Exception: - raise ValueError( - f"Failed to get partition bounds from {self._partition_col}. Please ensure that the column exists, and is numeric or temporal." - ) - - def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]: - file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) - return iter( - [ - ScanTask.sql_scan_task( - url=self.url, - file_format=file_format_config, - schema=self._schema._schema, - num_rows=total_rows, - storage_config=self.storage_config, - size_bytes=math.ceil(total_size), - pushdowns=pushdowns, - ) - ] - ) - def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: total_rows = self._get_num_rows() estimate_row_size_bytes = self.schema().estimate_row_size_bytes() @@ -154,7 +80,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: warnings.warn("Unable to partion the data using the specified column. 
Falling back to a single scan task.") return self._single_scan_task(pushdowns, total_rows, total_size) - size_bytes = None if strategy == "min_max" else math.ceil(total_size / num_scan_tasks) + size_bytes = math.ceil(total_size / num_scan_tasks) if strategy == PartitionBoundStrategy.PERCENTILE else None scan_tasks = [] for i in range(num_scan_tasks): if i == 0: @@ -188,3 +114,105 @@ def can_absorb_limit(self) -> bool: def can_absorb_select(self) -> bool: return False + + def _attempt_schema_read(self) -> Schema: + pa_table = SQLReader(self.sql, self.url, limit=1).read() + schema = Schema.from_pyarrow_schema(pa_table.schema) + return schema + + def _get_num_rows(self) -> int: + pa_table = SQLReader( + self.sql, + self.url, + projection=["COUNT(*)"], + ).read() + + if pa_table.num_rows != 1: + raise RuntimeError( + "Failed to get the number of rows: COUNT(*) query returned an unexpected number of rows." + ) + if pa_table.num_columns != 1: + raise RuntimeError( + "Failed to get the number of rows: COUNT(*) query returned an unexpected number of columns." + ) + + return pa_table.column(0)[0].as_py() + + def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]: + try: + # Try to get percentiles using percentile_cont + percentiles = [i / num_scan_tasks for i in range(1, num_scan_tasks)] + pa_table = SQLReader( + self.sql, + self.url, + projection=[ + f"percentile_cont({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}" + for i, percentile in enumerate(percentiles) + ], + ).read() + return pa_table, PartitionBoundStrategy.PERCENTILE + + except RuntimeError as e: + # If percentiles fails, use the min and max of the partition column + logger.info("Failed to get percentiles using percentile_cont, falling back to min and max. Error: %s", e) + + pa_table = SQLReader( + self.sql, + self.url, + projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"], + ).read() + return pa_table, PartitionBoundStrategy.MIN_MAX + + def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], PartitionBoundStrategy]: + if self._partition_col is None: + raise ValueError("Failed to get partition bounds: partition_col must be specified to partition the data.") + + if not ( + self._schema[self._partition_col].dtype._is_temporal_type() + or self._schema[self._partition_col].dtype._is_numeric_type() + ): + raise ValueError( + f"Failed to get partition bounds: {self._partition_col} is not a numeric or temporal type, and cannot be used for partitioning." + ) + + pa_table, strategy = self._attempt_partition_bounds_read(num_scan_tasks) + + if pa_table.num_rows != 1: + raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.") + + if strategy == PartitionBoundStrategy.PERCENTILE: + if pa_table.num_columns != num_scan_tasks - 1: + raise RuntimeError( + f"Failed to get partition bounds: expected {num_scan_tasks - 1} percentiles, but got {pa_table.num_columns}." + ) + + bounds = [pa_table.column(i)[0].as_py() for i in range(num_scan_tasks - 1)] + + elif strategy == PartitionBoundStrategy.MIN_MAX: + if pa_table.num_columns != 2: + raise RuntimeError( + f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}." 
+ ) + + min_val = pa_table.column(0)[0].as_py() + max_val = pa_table.column(1)[0].as_py() + range_size = (max_val - min_val) / num_scan_tasks + bounds = [min_val + range_size * i for i in range(1, num_scan_tasks)] + + return bounds, strategy + + def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]: + file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) + return iter( + [ + ScanTask.sql_scan_task( + url=self.url, + file_format=file_format_config, + schema=self._schema._schema, + num_rows=total_rows, + storage_config=self.storage_config, + size_bytes=math.ceil(total_size), + pushdowns=pushdowns, + ) + ] + ) diff --git a/daft/table/table_io.py b/daft/table/table_io.py index b492fae967..5eb0fcb332 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -226,24 +226,38 @@ def read_sql( sql (str): SQL query to execute url (str): URL to the database schema (Schema): Daft schema to read the SQL query into + sql_options (TableReadSQLOptions, optional): SQL-specific configs to apply when reading the file + read_options (TableReadOptions, optional): Options for reading the file Returns: MicroPartition: MicroPartition from SQL query """ + + if sql_options.predicate_expression is not None: + # If the predicate can be translated to SQL, we can apply all pushdowns to the SQL query + predicate_sql = sql_options.predicate_expression._to_sql() + apply_pushdowns_to_sql = predicate_sql is not None + else: + # If we don't have a predicate, we can still apply the limit and projection to the SQL query + predicate_sql = None + apply_pushdowns_to_sql = True + pa_table = SQLReader( sql, url, - # TODO(Colin): Enable pushdowns - projection=None, - predicate=None, + limit=read_options.num_rows if apply_pushdowns_to_sql else None, + projection=read_options.column_names if apply_pushdowns_to_sql else None, + predicate=predicate_sql, ).read() mp = MicroPartition.from_arrow(pa_table) - if sql_options.predicate_expression is not None: - mp = mp.filter(ExpressionsProjection([sql_options.predicate_expression])) + if len(mp) != 0 and apply_pushdowns_to_sql is False: + # If we have a non-empty table and we didn't apply pushdowns to SQL, we need to apply them in-memory + if sql_options.predicate_expression is not None: + mp = mp.filter(ExpressionsProjection([sql_options.predicate_expression])) - if read_options.num_rows is not None: - mp = mp.head(read_options.num_rows) + if read_options.num_rows is not None: + mp = mp.head(read_options.num_rows) return _cast_table_to_schema(mp, read_options=read_options, schema=schema) diff --git a/docs/source/user_guide/basic_concepts/read-and-write.rst b/docs/source/user_guide/basic_concepts/read-and-write.rst index e4528cd5d8..2eb04cde22 100644 --- a/docs/source/user_guide/basic_concepts/read-and-write.rst +++ b/docs/source/user_guide/basic_concepts/read-and-write.rst @@ -72,6 +72,23 @@ For testing, or small datasets that fit in memory, you may also create DataFrame To learn more, consult the API documentation on :ref:`creating DataFrames from in-memory data structures `. +From Databases +^^^^^^^^^^^^^^ + +Daft can also read data from a variety of databases, including PostgreSQL, MySQL, Trino, and SQLite using the :func:`daft.read_sql` method. +In order to partition the data, you can specify a partition column, which will allow Daft to read the data in parallel. + +.. 
code:: python
+
+    # Read from a PostgreSQL database
+    uri = "postgresql://user:password@host:port/database"
+    df = daft.read_sql("SELECT * FROM my_table", uri)
+
+    # Read with a partition column
+    df = daft.read_sql("SELECT * FROM my_table", uri, partition_col="date")
+
+To learn more, consult the API documentation on :func:`daft.read_sql`.
+
 Writing Data
 ------------
 
diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs
index 2ebdb02424..6ad61c9434 100644
--- a/src/daft-dsl/src/expr.rs
+++ b/src/daft-dsl/src/expr.rs
@@ -574,28 +574,19 @@ impl Expr {
         fn to_sql_inner<W: Write>(expr: &Expr, buffer: &mut W) -> io::Result<()> {
             match expr {
                 Expr::Column(name) => write!(buffer, "{}", name),
-                Expr::Literal(lit) => {
-                    if let Some(s) = lit.to_sql() {
-                        write!(buffer, "{}", s)
-                    } else {
-                        Err(io::Error::new(
-                            io::ErrorKind::Other,
-                            "Unsupported literal for SQL translation",
-                        ))
-                    }
-                }
+                Expr::Literal(lit) => lit.display_sql(buffer),
                 Expr::Alias(inner, ..) => to_sql_inner(inner, buffer),
                 Expr::BinaryOp { op, left, right } => {
                     to_sql_inner(left, buffer)?;
                     let op = match op {
-                        Operator::Eq => "=".to_string(),
-                        Operator::NotEq => "!=".to_string(),
-                        Operator::Lt => "<".to_string(),
-                        Operator::LtEq => "<=".to_string(),
-                        Operator::Gt => ">".to_string(),
-                        Operator::GtEq => ">=".to_string(),
-                        Operator::And => "AND".to_string(),
-                        Operator::Or => "OR".to_string(),
+                        Operator::Eq => "=",
+                        Operator::NotEq => "!=",
+                        Operator::Lt => "<",
+                        Operator::LtEq => "<=",
+                        Operator::Gt => ">",
+                        Operator::GtEq => ">=",
+                        Operator::And => "AND",
+                        Operator::Or => "OR",
                         _ => {
                             return Err(io::Error::new(
                                 io::ErrorKind::Other,
diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs
index bb1f2c60af..34c6d2ce03 100644
--- a/src/daft-dsl/src/lit.rs
+++ b/src/daft-dsl/src/lit.rs
@@ -13,6 +13,7 @@ use daft_core::{
     utils::display_table::{display_date32, display_series_literal, display_timestamp},
 };
 use serde::{Deserialize, Serialize};
+use std::io::{self, Write};
 use std::{
     fmt::{Display, Formatter, Result},
     hash::{Hash, Hasher},
@@ -213,33 +214,37 @@ impl LiteralValue {
         result
     }
 
-    pub fn to_sql(&self) -> Option<String> {
+    pub fn display_sql<W: Write>(&self, buffer: &mut W) -> io::Result<()> {
         use LiteralValue::*;
+        let display_sql_err = Err(io::Error::new(
+            io::ErrorKind::Other,
+            "Unsupported literal for SQL translation",
+        ));
         match self {
             Null | Boolean(..) | Int32(..) | UInt32(..) | Int64(..) | UInt64(..) | Float64(..) => {
-                self.to_string().into()
+                write!(buffer, "{}", self)
             }
-            Utf8(val) => format!("'{}'", val).into(),
-            Date(val) => format!("DATE '{}'", display_date32(*val)).into(),
+            Utf8(val) => write!(buffer, "'{}'", val),
+            Date(val) => write!(buffer, "DATE '{}'", display_date32(*val)),
             Timestamp(val, tu, tz) => {
                 // Different databases have different ways of handling timezones, so there's no reliable way to convert this to SQL.
                 if tz.is_some() {
-                    return None;
+                    return display_sql_err;
                 }
                 // Note: Our display_timestamp function returns a string in the ISO 8601 format "YYYY-MM-DDTHH:MM:SS.fffff".
                 // However, the ANSI SQL standard replaces the 'T' with a space. See: https://docs.actian.com/ingres/10s/index.html#page/SQLRef/Summary_of_ANSI_Date_2fTime_Data_Types.htm
                 // While many databases support the 'T', some such as Trino, do not. So we replace the 'T' with a space here.
                 // We also don't use the i64 unix timestamp directly, because different databases have different functions to convert unix timestamps to timestamps, e.g. Trino uses from_unixtime, while PostgreSQL uses to_timestamp.
-                format!(
+                write!(
+                    buffer,
                     "TIMESTAMP '{}'",
                     display_timestamp(*val, tu, tz).replace('T', " ")
                 )
-                .into()
             }
             // TODO(Colin): Implement the rest of the types in future work for SQL pushdowns.
-            Decimal(..) | Series(..) | Time(..) | Binary(..) => None,
+            Decimal(..) | Series(..) | Time(..) | Binary(..) => display_sql_err,
             #[cfg(feature = "python")]
-            Python(..) => None,
+            Python(..) => display_sql_err,
         }
     }
 }
diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs
index 090c095053..a7ff7e3cf7 100644
--- a/src/daft-micropartition/src/micropartition.rs
+++ b/src/daft-micropartition/src/micropartition.rs
@@ -309,11 +309,6 @@ fn materialize_scan_task(
                     .collect::>>()
             })?,
             FileFormatConfig::Database(DatabaseSourceConfig { sql }) => {
-                let predicate_sql = scan_task
-                    .pushdowns
-                    .filters
-                    .as_ref()
-                    .and_then(|p| p.to_sql());
                 let predicate_expr = scan_task
                     .pushdowns
                     .filters
@@ -325,7 +320,6 @@ fn materialize_scan_task(
                         py,
                         sql,
                         url,
-                        predicate_sql.as_deref(),
                         predicate_expr.clone(),
                         scan_task.schema.clone().into(),
                         scan_task
diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs
index 0137078e95..7ba80f7e89 100644
--- a/src/daft-micropartition/src/python.rs
+++ b/src/daft-micropartition/src/python.rs
@@ -787,12 +787,10 @@ pub(crate) fn read_parquet_into_py_table(
         .extract()
 }
 
-#[allow(clippy::too_many_arguments)]
 pub(crate) fn read_sql_into_py_table(
     py: Python,
     sql: &str,
     url: &str,
-    predicate_sql: Option<&str>,
     predicate_expr: Option<PyExpr>,
     schema: PySchema,
     include_columns: Option<Vec<String>>,
@@ -815,7 +813,7 @@ pub(crate) fn read_sql_into_py_table(
     let sql_options = py
         .import(pyo3::intern!(py, "daft.runners.partitioning"))?
         .getattr(pyo3::intern!(py, "TableReadSQLOptions"))?
-        .call1((predicate_sql, predicate_pyexpr))?;
+        .call1((predicate_pyexpr,))?;
     let read_options = py
         .import(pyo3::intern!(py, "daft.runners.partitioning"))?
         .getattr(pyo3::intern!(py, "TableReadOptions"))?
diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py
index b0ac17b6a3..d664c2ef80 100644
--- a/tests/integration/sql/test_sql.py
+++ b/tests/integration/sql/test_sql.py
@@ -12,6 +12,12 @@
 from tests.integration.sql.conftest import TEST_TABLE_NAME
 
 
+@pytest.mark.integration()
+def test_sql_show(test_db) -> None:
+    df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db)
+    df.show()
+
+
 @pytest.mark.integration()
 def test_sql_create_dataframe_ok(test_db) -> None:
     df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db)
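
For reference, a minimal usage sketch of the path this patch optimizes; the connection URI, table name, and column name below are placeholders, not part of the patch. When the filter can be translated to SQL (via Expr::to_sql and LiteralValue::display_sql), it is pushed into the generated query along with the projection and limit; if executing the query with a LIMIT fails, SQLReader retries without it and table_io.read_sql applies the remaining pushdowns in memory.

    import daft

    # Placeholder URI; any database reachable by daft.read_sql should work.
    uri = "postgresql://user:password@host:port/database"

    df = daft.read_sql("SELECT * FROM my_table", uri)

    # "id" is a placeholder column. This comparison can be lowered to a SQL WHERE
    # clause and the limit to a SQL LIMIT, so only matching rows are fetched.
    df = df.where(df["id"] > 100).limit(10)
    df.show()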