From dbad9a302ebc3d78a67c71a057b3835b57a9bbdc Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 22 Feb 2024 12:22:26 -0800 Subject: [PATCH 01/30] init --- daft/__init__.py | 2 + daft/daft.pyi | 31 +++- daft/iceberg/iceberg_scan.py | 1 + daft/io/__init__.py | 2 + daft/io/_sql.py | 60 ++++++++ daft/runners/partitioning.py | 13 ++ daft/sql/__init__.py | 0 daft/sql/sql_scan.py | 135 ++++++++++++++++++ daft/table/table_io.py | 54 ++++++- src/daft-micropartition/src/micropartition.rs | 39 ++++- src/daft-micropartition/src/python.rs | 50 +++++++ src/daft-plan/src/lib.rs | 3 + src/daft-plan/src/planner.rs | 3 +- src/daft-scan/src/file_format.rs | 53 +++++++ src/daft-scan/src/glob.rs | 5 + src/daft-scan/src/lib.rs | 36 ++++- src/daft-scan/src/python.rs | 27 ++++ src/daft-scan/src/scan_task_iters.rs | 8 +- tests/dataframe/test_creation.py | 41 ++++++ 19 files changed, 552 insertions(+), 11 deletions(-) create mode 100644 daft/io/_sql.py create mode 100644 daft/sql/__init__.py create mode 100644 daft/sql/sql_scan.py diff --git a/daft/__init__.py b/daft/__init__.py index f2470e7a31..9069810c42 100644 --- a/daft/__init__.py +++ b/daft/__init__.py @@ -76,6 +76,7 @@ def get_build_type() -> str: read_iceberg, read_json, read_parquet, + read_sql, ) from daft.series import Series from daft.udf import udf @@ -94,6 +95,7 @@ def get_build_type() -> str: "read_parquet", "read_iceberg", "read_delta_lake", + "read_sql", "DataCatalogType", "DataCatalogTable", "DataFrame", diff --git a/daft/daft.pyi b/daft/daft.pyi index c66d34e83d..55e2f81320 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -217,12 +217,23 @@ class JsonSourceConfig: chunk_size: int | None = None, ): ... +class DatabaseSourceConfig: + """ + Configuration of a database data source. + """ + + sql: str + limit: int | None + offset: int | None + + def __init__(self, sql: str, limit: int | None = None, offset: int | None = None): ... + class FileFormatConfig: """ Configuration for parsing a particular file format (Parquet, CSV, JSON). """ - config: ParquetSourceConfig | CsvSourceConfig | JsonSourceConfig + config: ParquetSourceConfig | CsvSourceConfig | JsonSourceConfig | DatabaseSourceConfig @staticmethod def from_parquet_config(config: ParquetSourceConfig) -> FileFormatConfig: @@ -242,6 +253,12 @@ class FileFormatConfig: Create a JSON file format config. """ ... + @staticmethod + def from_database_config(config: DatabaseSourceConfig) -> FileFormatConfig: + """ + Create a database file format config. + """ + ... def file_format(self) -> FileFormat: """ Get the file format for this config. @@ -583,6 +600,18 @@ class ScanTask: Create a Catalog Scan Task """ ... + @staticmethod + def sql_scan_task( + url: str, + file_format: FileFormatConfig, + schema: PySchema, + storage_config: StorageConfig, + pushdowns: Pushdowns | None, + ) -> ScanTask: + """ + Create a SQL Scan Task + """ + ... 
class ScanOperatorHandle: """ diff --git a/daft/iceberg/iceberg_scan.py b/daft/iceberg/iceberg_scan.py index 4b583bcce0..54eb6f7da0 100644 --- a/daft/iceberg/iceberg_scan.py +++ b/daft/iceberg/iceberg_scan.py @@ -170,6 +170,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: continue rows_left -= record_count scan_tasks.append(st) + return iter(scan_tasks) def can_absorb_filter(self) -> bool: diff --git a/daft/io/__init__.py b/daft/io/__init__.py index 18330faaf3..ce21b86cf5 100644 --- a/daft/io/__init__.py +++ b/daft/io/__init__.py @@ -14,6 +14,7 @@ from daft.io._iceberg import read_iceberg from daft.io._json import read_json from daft.io._parquet import read_parquet +from daft.io._sql import read_sql from daft.io.catalog import DataCatalogTable, DataCatalogType from daft.io.file_path import from_glob_path @@ -39,6 +40,7 @@ def _set_linux_cert_paths(): "read_parquet", "read_iceberg", "read_delta_lake", + "read_sql", "IOConfig", "S3Config", "AzureConfig", diff --git a/daft/io/_sql.py b/daft/io/_sql.py new file mode 100644 index 0000000000..803093294a --- /dev/null +++ b/daft/io/_sql.py @@ -0,0 +1,60 @@ +# isort: dont-add-import: from __future__ import annotations + + +from daft import context +from daft.api_annotations import PublicAPI +from daft.daft import ( + NativeStorageConfig, + PythonStorageConfig, + ScanOperatorHandle, + StorageConfig, +) +from daft.dataframe import DataFrame +from daft.logical.builder import LogicalPlanBuilder +from daft.sql.sql_scan import SQLScanOperator + + +def native_downloader_available(url: str) -> bool: + # TODO: We should be able to support native downloads via ConnectorX for compatible databases + return False + + +@PublicAPI +def read_sql( + sql: str, + url: str, + use_native_downloader: bool = False, + # schema_hints: Optional[Dict[str, DataType]] = None, +) -> DataFrame: + """Creates a DataFrame from a SQL query + + Example: + >>> def create_connection(): + return sqlite3.connect("example.db") + >>> df = daft.read_sql("SELECT * FROM my_table", create_connection) + + Args: + sql (str): SQL query to execute + connection_factory (Callable[[], Connection]): A callable that returns a connection to the database. + _multithreaded_io: Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing + the amount of system resources (number of connections and thread contention) when running in the Ray runner. + Defaults to None, which will let Daft decide based on the runner it is currently using. 
+ + returns: + DataFrame: parsed DataFrame + """ + + io_config = context.get_context().daft_planning_config.default_io_config + + multithreaded_io = not context.get_context().is_ray_runner + + if use_native_downloader and native_downloader_available(url): + storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config)) + else: + storage_config = StorageConfig.python(PythonStorageConfig(io_config)) + + sql_operator = SQLScanOperator(sql, url, storage_config=storage_config) + + handle = ScanOperatorHandle.from_python_scan_operator(sql_operator) + builder = LogicalPlanBuilder.from_tabular_scan_with_scan_operator(scan_operator=handle) + return DataFrame(builder) diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index 56fda08a3a..fbd90a655a 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -72,6 +72,19 @@ class TableParseParquetOptions: coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns() +@dataclass(frozen=True) +class TableParseSQLOptions: + """Options for parsing SQL tables + + Args: + limit: Number of rows to read, or None to read all rows + offset: Number of rows to skip before reading + """ + + limit: int | None = None + offset: int | None = None + + @dataclass(frozen=True) class PartialPartitionMetadata: num_rows: None | int diff --git a/daft/sql/__init__.py b/daft/sql/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py new file mode 100644 index 0000000000..e5986b6670 --- /dev/null +++ b/daft/sql/sql_scan.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import logging +import math +from collections.abc import Iterator + +import pyarrow as pa +from sqlalchemy import create_engine, text + +from daft.daft import ( + DatabaseSourceConfig, + FileFormatConfig, + Pushdowns, + ScanTask, + StorageConfig, +) +from daft.io.scan import PartitionField, ScanOperator +from daft.logical.schema import Schema + +logger = logging.getLogger(__name__) + + +class SQLScanOperator(ScanOperator): + MIN_ROWS_PER_SCAN_TASK = 50 # Would be better to have a memory limit instead of a row limit + + def __init__( + self, + sql: str, + url: str, + storage_config: StorageConfig, + ) -> None: + super().__init__() + self.sql = sql + self.url = url + self.storage_config = storage_config + self._limit_supported = self._check_limit_supported() + self._schema = self._get_schema() + + def _check_limit_supported(self) -> bool: + try: + with create_engine(self.url).connect() as connection: + connection.execute(text(f"SELECT * FROM ({self.sql}) AS subquery LIMIT 1 OFFSET 0")) + + return True + except Exception: + return False + + def _get_schema(self) -> Schema: + with create_engine(self.url).connect() as connection: + sql = f"SELECT * FROM ({self.sql}) AS subquery" + if self._limit_supported: + sql += " LIMIT 1 OFFSET 0" + + result = connection.execute(text(sql)) + + # Fetch the cursor from the result proxy to access column descriptions + cursor = result.cursor + + rows = cursor.fetchall() + columns = [column_description[0] for column_description in cursor.description] + + pydict = {column: [row[i] for row in rows] for i, column in enumerate(columns)} + pa_table = pa.Table.from_pydict(pydict) + + return Schema.from_pyarrow_schema(pa_table.schema) + + def _get_num_rows(self) -> int: + with create_engine(self.url).connect() as connection: + result = connection.execute(text(f"SELECT COUNT(*) FROM ({self.sql}) AS subquery")) + cursor = result.cursor + return cursor.fetchone()[0] 
+ + def schema(self) -> Schema: + return self._schema + + def display_name(self) -> str: + return f"SQLScanOperator({self.sql})" + + def partitioning_keys(self) -> list[PartitionField]: + return [] + + def multiline_display(self) -> list[str]: + return [ + self.display_name(), + f"Schema = {self._schema}", + ] + + def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: + if not self._limit_supported: + file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) + return iter( + [ + ScanTask.sql_scan_task( + url=self.url, + file_format=file_format_config, + schema=self._schema._schema, + storage_config=self.storage_config, + pushdowns=pushdowns, + ) + ] + ) + + total_rows = self._get_num_rows() + num_scan_tasks = math.ceil(total_rows / self.MIN_ROWS_PER_SCAN_TASK) + num_rows_per_scan_task = total_rows // num_scan_tasks + + scan_tasks = [] + offset = 0 + for _ in range(num_scan_tasks): + limit = min(num_rows_per_scan_task, total_rows - offset) + file_format_config = FileFormatConfig.from_database_config( + DatabaseSourceConfig(self.sql, limit=limit, offset=offset) + ) + + scan_tasks.append( + ScanTask.sql_scan_task( + url=self.url, + file_format=file_format_config, + schema=self._schema._schema, + storage_config=self.storage_config, + pushdowns=pushdowns, + ) + ) + offset += limit + + return iter(scan_tasks) + + def can_absorb_filter(self) -> bool: + return False + + def can_absorb_limit(self) -> bool: + return False + + def can_absorb_select(self) -> bool: + return True diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 4097b58dce..e5ae11276f 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -12,6 +12,7 @@ from pyarrow import dataset as pads from pyarrow import json as pajson from pyarrow import parquet as papq +from sqlalchemy import create_engine, text from daft.context import get_context from daft.daft import ( @@ -29,6 +30,7 @@ ) from daft.datatype import DataType from daft.expressions import ExpressionsProjection +from daft.expressions.expressions import Expression from daft.filesystem import ( _resolve_paths_and_filesystem, canonicalize_protocol, @@ -73,7 +75,6 @@ def _cast_table_to_schema(table: MicroPartition, read_options: TableReadOptions, their corresponding dtype in `schema`, and column ordering/inclusion matches `read_options.column_names` (if provided). 
""" pruned_schema = schema - # If reading only a subset of fields, prune the schema if read_options.column_names is not None: pruned_schema = Schema._from_fields([schema[name] for name in read_options.column_names]) @@ -212,6 +213,57 @@ def read_parquet( return _cast_table_to_schema(MicroPartition.from_arrow(table), read_options=read_options, schema=schema) +def read_sql( + sql: str, + url: str, + schema: Schema, + limit: int | None = None, + offset: int | None = None, + read_options: TableReadOptions = TableReadOptions(), + predicate: Expression | None = None, +) -> MicroPartition: + """Reads a MicroPartition from a SQL query + + Args: + sql (str): SQL query to execute + url (str): URL to the database + schema (Schema): Daft schema to read the SQL query into + + Returns: + MicroPartition: MicroPartition from SQL query + """ + columns = read_options.column_names + if columns is not None: + sql = f"SELECT {', '.join(columns)} FROM ({sql})" + else: + sql = f"SELECT * FROM ({sql})" + + if limit is not None: + sql = f"{sql} LIMIT {limit}" + + if offset is not None: + sql = f"{sql} OFFSET {offset}" + + with create_engine(url).connect() as connection: + result = connection.execute(text(sql)) + cursor = result.cursor + + rows = cursor.fetchall() + columns = [column_description[0] for column_description in cursor.description] + + pydict = {column: [row[i] for row in rows] for i, column in enumerate(columns)} + + mp = MicroPartition.from_pydict(pydict) + if predicate is not None: + mp = mp.filter(ExpressionsProjection([predicate])) + + num_rows = read_options.num_rows + if num_rows is not None: + mp = mp.head(num_rows) + + return _cast_table_to_schema(mp, read_options=read_options, schema=schema) + + class PACSVStreamHelper: def __init__(self, stream: pa.CSVStreamReader) -> None: self.stream = stream diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 654fe2695f..78d6d6c0d7 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -15,7 +15,9 @@ use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions}; use daft_parquet::read::{ read_parquet_bulk, read_parquet_metadata_bulk, ParquetSchemaInferenceOptions, }; -use daft_scan::file_format::{CsvSourceConfig, FileFormatConfig, ParquetSourceConfig}; +use daft_scan::file_format::{ + CsvSourceConfig, DatabaseSourceConfig, FileFormatConfig, ParquetSourceConfig, +}; use daft_scan::storage_config::{NativeStorageConfig, StorageConfig}; use daft_scan::{ChunkSpec, DataFileSource, Pushdowns, ScanTask}; use daft_table::Table; @@ -225,6 +227,12 @@ fn materialize_scan_task( ) .context(DaftCoreComputeSnafu)? 
} + FileFormatConfig::Database(_) => { + return Err(common_error::DaftError::TypeError( + "Native reads for Database file format not yet implemented".to_string(), + )) + .context(DaftCoreComputeSnafu); + } } } #[cfg(feature = "python")] @@ -300,6 +308,35 @@ fn materialize_scan_task( }) .collect::>>() })?, + FileFormatConfig::Database(DatabaseSourceConfig { sql, limit, offset }) => { + let py_expr = scan_task + .pushdowns + .filters + .as_ref() + .map(|p| (*p.as_ref()).clone().into()); + Python::with_gil(|py| { + urls.map(|url| { + crate::python::read_sql_into_py_table( + py, + sql, + url, + *limit, + *offset, + scan_task.schema.clone().into(), + scan_task + .pushdowns + .columns + .as_ref() + .map(|cols| cols.as_ref().clone()), + scan_task.pushdowns.limit, + py_expr.clone(), + ) + .map(|t| t.into()) + .context(PyIOSnafu) + }) + .collect::>>() + })? + } } } }; diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 2a7c487f27..178ff701e4 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -787,6 +787,56 @@ pub(crate) fn read_parquet_into_py_table( .extract() } +#[allow(clippy::too_many_arguments)] +pub(crate) fn read_sql_into_py_table( + py: Python, + sql: &str, + url: &str, + limit: Option, + offset: Option, + schema: PySchema, + include_columns: Option>, + num_rows: Option, + predicate: Option, +) -> PyResult { + let py_schema = py + .import(pyo3::intern!(py, "daft.logical.schema"))? + .getattr(pyo3::intern!(py, "Schema"))? + .getattr(pyo3::intern!(py, "_from_pyschema"))? + .call1((schema,))?; + let py_predicate = match predicate { + Some(p) => { + let expressions_mod = py.import(pyo3::intern!(py, "daft.expressions.expressions"))?; + Some( + expressions_mod + .getattr(pyo3::intern!(py, "Expression"))? + .getattr(pyo3::intern!(py, "_from_pyexpr"))? + .call1((p,))?, + ) + } + None => None, + }; + let read_options = py + .import(pyo3::intern!(py, "daft.runners.partitioning"))? + .getattr(pyo3::intern!(py, "TableReadOptions"))? + .call1((num_rows, include_columns))?; + py.import(pyo3::intern!(py, "daft.table.table_io"))? + .getattr(pyo3::intern!(py, "read_sql"))? + .call1(( + sql, + url, + py_schema, + limit, + offset, + read_options, + py_predicate, + ))? + .getattr(pyo3::intern!(py, "to_table"))? + .call0()? + .getattr(pyo3::intern!(py, "_table"))? 
+ .extract() +} + impl From for PyMicroPartition { fn from(value: MicroPartition) -> Self { PyMicroPartition { diff --git a/src/daft-plan/src/lib.rs b/src/daft-plan/src/lib.rs index 74bc2ccc7a..4c02484b7b 100644 --- a/src/daft-plan/src/lib.rs +++ b/src/daft-plan/src/lib.rs @@ -40,12 +40,15 @@ use { #[cfg(feature = "python")] pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { + use daft_scan::file_format::DatabaseSourceConfig; + parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; + parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 2f8ae4cc15..4e2778757d 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -6,7 +6,7 @@ use std::{ }; use common_daft_config::DaftExecutionConfig; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use daft_core::count_mode::CountMode; use daft_core::DataType; use daft_dsl::Expr; @@ -715,6 +715,7 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftRe input_physical.into(), ))) } + _ => unimplemented!(), } } } diff --git a/src/daft-scan/src/file_format.rs b/src/daft-scan/src/file_format.rs index bf9d0175f1..1a6c314b84 100644 --- a/src/daft-scan/src/file_format.rs +++ b/src/daft-scan/src/file_format.rs @@ -19,6 +19,7 @@ pub enum FileFormat { Parquet, Csv, Json, + Database, } impl FromStr for FileFormat { @@ -33,6 +34,8 @@ impl FromStr for FileFormat { Ok(Csv) } else if file_format.trim().eq_ignore_ascii_case("json") { Ok(Json) + } else if file_format.trim().eq_ignore_ascii_case("database") { + Ok(Database) } else { Err(DaftError::TypeError(format!( "FileFormat {} not supported!", @@ -50,6 +53,7 @@ impl From<&FileFormatConfig> for FileFormat { FileFormatConfig::Parquet(_) => Self::Parquet, FileFormatConfig::Csv(_) => Self::Csv, FileFormatConfig::Json(_) => Self::Json, + FileFormatConfig::Database(_) => Self::Database, } } } @@ -60,6 +64,7 @@ pub enum FileFormatConfig { Parquet(ParquetSourceConfig), Csv(CsvSourceConfig), Json(JsonSourceConfig), + Database(DatabaseSourceConfig), } impl FileFormatConfig { @@ -70,6 +75,7 @@ impl FileFormatConfig { Parquet(_) => "Parquet", Csv(_) => "Csv", Json(_) => "Json", + Database(_) => "Database", } } @@ -78,6 +84,7 @@ impl FileFormatConfig { Self::Parquet(source) => source.multiline_display(), Self::Csv(source) => source.multiline_display(), Self::Json(source) => source.multiline_display(), + Self::Database(source) => source.multiline_display(), } } } @@ -251,6 +258,45 @@ impl JsonSourceConfig { impl_bincode_py_state_serialization!(JsonSourceConfig); +/// Configuration for a Database data source. 
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] +#[cfg_attr(feature = "python", pyclass(module = "daft.daft", get_all))] +pub struct DatabaseSourceConfig { + pub sql: String, + pub limit: Option, + pub offset: Option, +} + +impl DatabaseSourceConfig { + pub fn new_internal(sql: String, limit: Option, offset: Option) -> Self { + Self { sql, limit, offset } + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + res.push(format!("SQL = {}", self.sql)); + if let Some(limit) = self.limit { + res.push(format!("Limit = {}", limit)); + } + if let Some(offset) = self.offset { + res.push(format!("Offset = {}", offset)); + } + res + } +} + +#[cfg(feature = "python")] +#[pymethods] +impl DatabaseSourceConfig { + /// Create a config for a Database data source. + #[new] + fn new(sql: &str, limit: Option, offset: Option) -> Self { + Self::new_internal(sql.to_string(), limit, offset) + } +} + +impl_bincode_py_state_serialization!(DatabaseSourceConfig); + /// Configuration for parsing a particular file format. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(transparent)] @@ -281,6 +327,12 @@ impl PyFileFormatConfig { Self(Arc::new(FileFormatConfig::Json(config))) } + /// Create a Database file format config. + #[staticmethod] + fn from_database_config(config: DatabaseSourceConfig) -> Self { + Self(Arc::new(FileFormatConfig::Database(config))) + } + /// Get the underlying data source config. #[getter] fn get_config(&self, py: Python) -> PyObject { @@ -290,6 +342,7 @@ impl PyFileFormatConfig { Parquet(config) => config.clone().into_py(py), Csv(config) => config.clone().into_py(py), Json(config) => config.clone().into_py(py), + Database(config) => config.clone().into_py(py), } } diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 26e9773df8..426fe289c8 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -201,6 +201,11 @@ impl GlobScanOperator { io_client, Some(io_stats), )?, + FileFormatConfig::Database(_) => { + return Err(DaftError::ValueError( + "Cannot glob a database source".to_string(), + )) + } }; let schema = match schema_hint { diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index 1f7023a2b8..d275179f3f 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -135,32 +135,46 @@ pub enum DataFileSource { partition_spec: PartitionSpec, statistics: Option, }, + DatabaseDataFile { + path: String, + chunk_spec: Option, + size_bytes: Option, + metadata: Option, + partition_spec: Option, + statistics: Option, + }, } impl DataFileSource { pub fn get_path(&self) -> &str { match self { - Self::AnonymousDataFile { path, .. } | Self::CatalogDataFile { path, .. } => path, + Self::AnonymousDataFile { path, .. } + | Self::CatalogDataFile { path, .. } + | Self::DatabaseDataFile { path, .. } => path, } } pub fn get_chunk_spec(&self) -> Option<&ChunkSpec> { match self { Self::AnonymousDataFile { chunk_spec, .. } - | Self::CatalogDataFile { chunk_spec, .. } => chunk_spec.as_ref(), + | Self::CatalogDataFile { chunk_spec, .. } + | Self::DatabaseDataFile { chunk_spec, .. } => chunk_spec.as_ref(), } } pub fn get_size_bytes(&self) -> Option { match self { Self::AnonymousDataFile { size_bytes, .. } - | Self::CatalogDataFile { size_bytes, .. } => *size_bytes, + | Self::CatalogDataFile { size_bytes, .. } + | Self::DatabaseDataFile { size_bytes, .. } => *size_bytes, } } pub fn get_metadata(&self) -> Option<&TableMetadata> { match self { - Self::AnonymousDataFile { metadata, .. 
} => metadata.as_ref(), + Self::AnonymousDataFile { metadata, .. } | Self::DatabaseDataFile { metadata, .. } => { + metadata.as_ref() + } Self::CatalogDataFile { metadata, .. } => Some(metadata), } } @@ -168,13 +182,15 @@ impl DataFileSource { pub fn get_statistics(&self) -> Option<&TableStatistics> { match self { Self::AnonymousDataFile { statistics, .. } - | Self::CatalogDataFile { statistics, .. } => statistics.as_ref(), + | Self::CatalogDataFile { statistics, .. } + | Self::DatabaseDataFile { statistics, .. } => statistics.as_ref(), } } pub fn get_partition_spec(&self) -> Option<&PartitionSpec> { match self { - Self::AnonymousDataFile { partition_spec, .. } => partition_spec.as_ref(), + Self::AnonymousDataFile { partition_spec, .. } + | Self::DatabaseDataFile { partition_spec, .. } => partition_spec.as_ref(), Self::CatalogDataFile { partition_spec, .. } => Some(partition_spec), } } @@ -189,6 +205,14 @@ impl DataFileSource { metadata, partition_spec, statistics, + } + | Self::DatabaseDataFile { + path, + chunk_spec, + size_bytes, + metadata, + partition_spec, + statistics, } => { res.push(format!("Path = {}", path)); if let Some(chunk_spec) = chunk_spec { diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index 1c28639905..5edb252719 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -319,6 +319,33 @@ pub mod pylib { Ok(Some(PyScanTask(scan_task.into()))) } + #[staticmethod] + pub fn sql_scan_task( + url: String, + file_format: PyFileFormatConfig, + schema: PySchema, + storage_config: PyStorageConfig, + pushdowns: Option, + ) -> PyResult { + let data_source = DataFileSource::DatabaseDataFile { + path: url, + chunk_spec: None, + size_bytes: None, + metadata: None, + partition_spec: None, + statistics: None, + }; + + let scan_task = ScanTask::new( + vec![data_source], + file_format.into(), + schema.schema, + storage_config.into(), + pushdowns.map(|p| p.0.as_ref().clone()).unwrap_or_default(), + ); + Ok(PyScanTask(scan_task.into())) + } + pub fn __repr__(&self) -> PyResult { Ok(format!("{:?}", self.0)) } diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs index fe3d470a2f..a627d89105 100644 --- a/src/daft-scan/src/scan_task_iters.rs +++ b/src/daft-scan/src/scan_task_iters.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use daft_io::IOStatsContext; use daft_parquet::read::read_parquet_metadata; @@ -191,6 +191,12 @@ pub fn split_by_row_groups( curr_row_groups = Vec::new(); curr_size_bytes = 0; } + DataFileSource::DatabaseDataFile { .. 
} => { + return Err(DaftError::ValueError( + "Cannot split by row groups for database sources" + .to_string(), + )); + } }; new_tasks.push(Ok(ScanTask::new( diff --git a/tests/dataframe/test_creation.py b/tests/dataframe/test_creation.py index b330925848..d5c986918d 100644 --- a/tests/dataframe/test_creation.py +++ b/tests/dataframe/test_creation.py @@ -5,8 +5,10 @@ import decimal import json import os +import sqlite3 import tempfile import uuid +from typing import Generator import numpy as np import pandas as pd @@ -1032,3 +1034,42 @@ def test_create_dataframe_parquet_schema_hints_ignore_random_hint(valid_data: li pd_df = df.to_pandas() assert list(pd_df.columns) == COL_NAMES assert len(pd_df) == len(valid_data) + + +### +# SQL tests +### + + +@pytest.fixture(name="temp_database") +def temp_database() -> Generator[str, None, None]: + with tempfile.NamedTemporaryFile(suffix=".db") as file: + yield file.name + + +def test_create_dataframe_sql(valid_data: list[dict[str, float]], temp_database: str) -> None: + connection = sqlite3.connect(temp_database) + connection.execute( + f""" + CREATE TABLE iris ( + sepal_length REAL, + sepal_width REAL, + petal_length REAL, + petal_width REAL, + variety TEXT + ) + """ + ) + connection.executemany( + "INSERT INTO iris VALUES (?, ?, ?, ?, ?)", + [(d["sepal_length"], d["sepal_width"], d["petal_length"], d["petal_width"], d["variety"]) for d in valid_data], + ) + connection.commit() + connection.close() + + df = daft.read_sql("SELECT * FROM iris", f"sqlite:///{temp_database}") + assert df.column_names == COL_NAMES + + pd_df = df.to_pandas() + assert list(pd_df.columns) == COL_NAMES + assert len(pd_df) == len(valid_data) From 1e86d0f3e74ae3389ea649b9e0d3db07d8a7b036 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Feb 2024 10:39:34 -0800 Subject: [PATCH 02/30] int tests --- .github/workflows/python-package.yml | 71 +++++++++ daft/io/_sql.py | 32 +--- daft/sql/sql_scan.py | 42 ++--- daft/table/table_io.py | 13 +- daft/utils.py | 46 ++++++ pyproject.toml | 3 +- requirements-dev.txt | 5 + tests/dataframe/test_creation.py | 41 ----- tests/integration/sql/__init__.py | 0 .../sql/docker-compose/docker-compose.yml | 8 + tests/integration/sql/test_sql_lite.py | 149 ++++++++++++++++++ tests/integration/sql/test_trino.py | 15 ++ 12 files changed, 313 insertions(+), 112 deletions(-) create mode 100644 tests/integration/sql/__init__.py create mode 100644 tests/integration/sql/docker-compose/docker-compose.yml create mode 100644 tests/integration/sql/test_sql_lite.py create mode 100644 tests/integration/sql/test_trino.py diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index af281d5f58..3b8f98dafe 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -491,6 +491,77 @@ jobs: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} SLACK_WEBHOOK_TYPE: INCOMING_WEBHOOK + integration-test-sql: + runs-on: ubuntu-latest + timeout-minutes: 30 + needs: + - integration-test-build + env: + package-name: getdaft + strategy: + fail-fast: false + matrix: + python-version: ['3.8'] # can't use 3.7 due to requiring anon mode for adlfs + daft-runner: [py, ray] + micropartitions: [1] + steps: + - uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: pip + cache-dependency-path: | + pyproject.toml + requirements-dev.txt + - name: Download built wheels + uses: 
actions/download-artifact@v3 + with: + name: wheels + path: dist + - name: Setup Virtual Env + run: | + python -m venv venv + echo "$GITHUB_WORKSPACE/venv/bin" >> $GITHUB_PATH + - name: Install Daft and dev dependencies + run: | + pip install --upgrade pip + pip install -r requirements-dev.txt dist/${{ env.package-name }}-*x86_64*.whl --force-reinstall + rm -rf daft + - name: Spin up services + run: | + pushd ./tests/integration/sql/docker-compose/ + docker-compose -f ./docker-compose.yml up -d + popd + - name: Run sql integration tests + run: | + pytest tests/integration/sql -m 'integration' --durations=50 + env: + DAFT_RUNNER: ${{ matrix.daft-runner }} + DAFT_MICROPARTITIONS: ${{ matrix.micropartitions }} + - name: Send Slack notification on failure + uses: slackapi/slack-github-action@v1.24.0 + if: ${{ failure() && (github.ref == 'refs/heads/main') }} + with: + payload: | + { + "blocks": [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": ":rotating_light: [CI] Iceberg Integration Tests <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|workflow> *FAILED on main* :rotating_light:" + } + } + ] + } + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} + SLACK_WEBHOOK_TYPE: INCOMING_WEBHOOK + rust-tests: runs-on: ${{ matrix.os }}-latest timeout-minutes: 30 diff --git a/daft/io/_sql.py b/daft/io/_sql.py index 803093294a..3d2c786c8b 100644 --- a/daft/io/_sql.py +++ b/daft/io/_sql.py @@ -3,58 +3,34 @@ from daft import context from daft.api_annotations import PublicAPI -from daft.daft import ( - NativeStorageConfig, - PythonStorageConfig, - ScanOperatorHandle, - StorageConfig, -) +from daft.daft import PythonStorageConfig, ScanOperatorHandle, StorageConfig from daft.dataframe import DataFrame from daft.logical.builder import LogicalPlanBuilder from daft.sql.sql_scan import SQLScanOperator -def native_downloader_available(url: str) -> bool: - # TODO: We should be able to support native downloads via ConnectorX for compatible databases - return False - - @PublicAPI def read_sql( sql: str, url: str, - use_native_downloader: bool = False, - # schema_hints: Optional[Dict[str, DataType]] = None, ) -> DataFrame: """Creates a DataFrame from a SQL query Example: - >>> def create_connection(): - return sqlite3.connect("example.db") - >>> df = daft.read_sql("SELECT * FROM my_table", create_connection) + >>> df = daft.read_sql("SELECT * FROM my_table", "sqlite:///my_database.db") Args: sql (str): SQL query to execute - connection_factory (Callable[[], Connection]): A callable that returns a connection to the database. - _multithreaded_io: Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing - the amount of system resources (number of connections and thread contention) when running in the Ray runner. - Defaults to None, which will let Daft decide based on the runner it is currently using. 
+ url (str): URL to the database returns: DataFrame: parsed DataFrame """ io_config = context.get_context().daft_planning_config.default_io_config - - multithreaded_io = not context.get_context().is_ray_runner - - if use_native_downloader and native_downloader_available(url): - storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config)) - else: - storage_config = StorageConfig.python(PythonStorageConfig(io_config)) + storage_config = StorageConfig.python(PythonStorageConfig(io_config)) sql_operator = SQLScanOperator(sql, url, storage_config=storage_config) - handle = ScanOperatorHandle.from_python_scan_operator(sql_operator) builder = LogicalPlanBuilder.from_tabular_scan_with_scan_operator(scan_operator=handle) return DataFrame(builder) diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index e5986b6670..fa2c35edf7 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -1,12 +1,8 @@ from __future__ import annotations -import logging import math from collections.abc import Iterator -import pyarrow as pa -from sqlalchemy import create_engine, text - from daft.daft import ( DatabaseSourceConfig, FileFormatConfig, @@ -16,8 +12,7 @@ ) from daft.io.scan import PartitionField, ScanOperator from daft.logical.schema import Schema - -logger = logging.getLogger(__name__) +from daft.utils import execute_sql_query_to_pyarrow class SQLScanOperator(ScanOperator): @@ -38,37 +33,23 @@ def __init__( def _check_limit_supported(self) -> bool: try: - with create_engine(self.url).connect() as connection: - connection.execute(text(f"SELECT * FROM ({self.sql}) AS subquery LIMIT 1 OFFSET 0")) - + execute_sql_query_to_pyarrow(f"SELECT * FROM ({self.sql}) AS subquery LIMIT 1 OFFSET 0", self.url) return True except Exception: return False def _get_schema(self) -> Schema: - with create_engine(self.url).connect() as connection: - sql = f"SELECT * FROM ({self.sql}) AS subquery" - if self._limit_supported: - sql += " LIMIT 1 OFFSET 0" - - result = connection.execute(text(sql)) + sql = f"SELECT * FROM ({self.sql}) AS subquery" + if self._limit_supported: + sql += " LIMIT 1 OFFSET 0" - # Fetch the cursor from the result proxy to access column descriptions - cursor = result.cursor - - rows = cursor.fetchall() - columns = [column_description[0] for column_description in cursor.description] - - pydict = {column: [row[i] for row in rows] for i, column in enumerate(columns)} - pa_table = pa.Table.from_pydict(pydict) - - return Schema.from_pyarrow_schema(pa_table.schema) + pa_table = execute_sql_query_to_pyarrow(sql, self.url) + return Schema.from_pyarrow_schema(pa_table.schema) def _get_num_rows(self) -> int: - with create_engine(self.url).connect() as connection: - result = connection.execute(text(f"SELECT COUNT(*) FROM ({self.sql}) AS subquery")) - cursor = result.cursor - return cursor.fetchone()[0] + sql = f"SELECT COUNT(*) FROM ({self.sql}) AS subquery" + pa_table = execute_sql_query_to_pyarrow(sql, self.url) + return pa_table.column(0)[0].as_py() def schema(self) -> Schema: return self._schema @@ -107,11 +88,10 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: scan_tasks = [] offset = 0 for _ in range(num_scan_tasks): - limit = min(num_rows_per_scan_task, total_rows - offset) + limit = max(num_rows_per_scan_task, total_rows - offset) file_format_config = FileFormatConfig.from_database_config( DatabaseSourceConfig(self.sql, limit=limit, offset=offset) ) - scan_tasks.append( ScanTask.sql_scan_task( url=self.url, diff --git a/daft/table/table_io.py 
b/daft/table/table_io.py index e5ae11276f..19683a1487 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -12,7 +12,6 @@ from pyarrow import dataset as pads from pyarrow import json as pajson from pyarrow import parquet as papq -from sqlalchemy import create_engine, text from daft.context import get_context from daft.daft import ( @@ -44,6 +43,7 @@ ) from daft.series import Series from daft.table import MicroPartition +from daft.utils import execute_sql_query_to_pyarrow FileInput = Union[pathlib.Path, str, IO[bytes]] @@ -244,16 +244,7 @@ def read_sql( if offset is not None: sql = f"{sql} OFFSET {offset}" - with create_engine(url).connect() as connection: - result = connection.execute(text(sql)) - cursor = result.cursor - - rows = cursor.fetchall() - columns = [column_description[0] for column_description in cursor.description] - - pydict = {column: [row[i] for row in rows] for i, column in enumerate(columns)} - - mp = MicroPartition.from_pydict(pydict) + mp = MicroPartition.from_arrow(execute_sql_query_to_pyarrow(sql, url)) if predicate is not None: mp = mp.filter(ExpressionsProjection([predicate])) diff --git a/daft/utils.py b/daft/utils.py index a8efc9bf2e..2791826613 100644 --- a/daft/utils.py +++ b/daft/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import pickle import random import statistics @@ -7,6 +8,8 @@ import pyarrow as pa +logger = logging.getLogger(__name__) + ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) @@ -111,3 +114,46 @@ def pyarrow_supports_fixed_shape_tensor() -> bool: from daft.context import get_context return hasattr(pa, "fixed_shape_tensor") and (not get_context().is_ray_runner or ARROW_VERSION >= (13, 0, 0)) + + +def execute_sql_query_to_pyarrow_with_connectorx(sql: str, url: str) -> pa.Table: + import connectorx as cx + + logger.info(f"Using connectorx to execute sql: {sql}") + try: + table = cx.read_sql(conn=url, query=sql, return_type="arrow") + return table + except Exception as e: + raise RuntimeError(f"Failed to execute sql: {sql} with url: {url}") from e + + +def execute_sql_query_to_pyarrow_with_sqlalchemy(sql: str, url: str) -> pa.Table: + import pandas as pd + from sqlalchemy import create_engine, text + + logger.info(f"Using sqlalchemy to execute sql: {sql}") + try: + with create_engine(url).connect() as connection: + result = connection.execute(text(sql)) + df = pd.DataFrame(result.fetchall(), columns=result.keys()) + table = pa.Table.from_pandas(df) + return table + except Exception as e: + raise RuntimeError(f"Failed to execute sql: {sql} with url: {url}") from e + + +def execute_sql_query_to_pyarrow(sql: str, url: str) -> pa.Table: + # Supported DBs extracted from here https://github.com/sfu-db/connector-x/tree/7b3147436b7e20b96691348143d605e2249d6119?tab=readme-ov-file#sources + if ( + url.startswith("postgres") + or url.startswith("mysql") + or url.startswith("mssql") + or url.startswith("oracle") + or url.startswith("bigquery") + or url.startswith("sqlite") + or url.startswith("clickhouse") + or url.startswith("redshift") + ): + return execute_sql_query_to_pyarrow_with_connectorx(sql, url) + else: + return execute_sql_query_to_pyarrow_with_sqlalchemy(sql, url) diff --git a/pyproject.toml b/pyproject.toml index 965dbf5a14..fe21174dd2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ readme = "README.rst" requires-python = ">=3.7" [project.optional-dependencies] -all = ["getdaft[aws, azure, gcp, ray, pandas, numpy, iceberg, deltalake]"] +all = 
["getdaft[aws, azure, gcp, ray, pandas, numpy, iceberg, deltalake, sql]"] aws = ["boto3"] azure = [] deltalake = ["deltalake"] @@ -39,6 +39,7 @@ ray = [ # Explicitly install packaging. See issue: https://github.com/ray-project/ray/issues/34806 "packaging" ] +sql = ["connectorx", "sqlalchemy", "trino[sqlalchemy]"] viz = [] [project.urls] diff --git a/requirements-dev.txt b/requirements-dev.txt index 42f49446c8..6ec1b6b0fa 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -50,6 +50,11 @@ deltalake==0.13.0; platform_system != "Windows" and python_version < '3.8' # Databricks databricks-sdk==0.12.0 +#SQL +sqlalchemy==2.0.25; python_version < '3.8' +connectorx==0.3.2; python_version >= '3.8' +trino[sqlalchemy]==0.328.0; python_version >= '3.8' + # AWS s3fs==2023.1.0; python_version < '3.8' s3fs==2023.12.0; python_version >= '3.8' diff --git a/tests/dataframe/test_creation.py b/tests/dataframe/test_creation.py index d5c986918d..b330925848 100644 --- a/tests/dataframe/test_creation.py +++ b/tests/dataframe/test_creation.py @@ -5,10 +5,8 @@ import decimal import json import os -import sqlite3 import tempfile import uuid -from typing import Generator import numpy as np import pandas as pd @@ -1034,42 +1032,3 @@ def test_create_dataframe_parquet_schema_hints_ignore_random_hint(valid_data: li pd_df = df.to_pandas() assert list(pd_df.columns) == COL_NAMES assert len(pd_df) == len(valid_data) - - -### -# SQL tests -### - - -@pytest.fixture(name="temp_database") -def temp_database() -> Generator[str, None, None]: - with tempfile.NamedTemporaryFile(suffix=".db") as file: - yield file.name - - -def test_create_dataframe_sql(valid_data: list[dict[str, float]], temp_database: str) -> None: - connection = sqlite3.connect(temp_database) - connection.execute( - f""" - CREATE TABLE iris ( - sepal_length REAL, - sepal_width REAL, - petal_length REAL, - petal_width REAL, - variety TEXT - ) - """ - ) - connection.executemany( - "INSERT INTO iris VALUES (?, ?, ?, ?, ?)", - [(d["sepal_length"], d["sepal_width"], d["petal_length"], d["petal_width"], d["variety"]) for d in valid_data], - ) - connection.commit() - connection.close() - - df = daft.read_sql("SELECT * FROM iris", f"sqlite:///{temp_database}") - assert df.column_names == COL_NAMES - - pd_df = df.to_pandas() - assert list(pd_df.columns) == COL_NAMES - assert len(pd_df) == len(valid_data) diff --git a/tests/integration/sql/__init__.py b/tests/integration/sql/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/sql/docker-compose/docker-compose.yml b/tests/integration/sql/docker-compose/docker-compose.yml new file mode 100644 index 0000000000..e463e51eeb --- /dev/null +++ b/tests/integration/sql/docker-compose/docker-compose.yml @@ -0,0 +1,8 @@ +version: '3.7' +services: + trino: + image: trinodb/trino + container_name: trino + ports: + - 8080:8080 + restart: unless-stopped diff --git a/tests/integration/sql/test_sql_lite.py b/tests/integration/sql/test_sql_lite.py new file mode 100644 index 0000000000..8670b49e3f --- /dev/null +++ b/tests/integration/sql/test_sql_lite.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import random +import sqlite3 +import tempfile + +import numpy as np +import pandas as pd +import pytest + +import daft + +COL_NAMES = ["sepal_length", "sepal_width", "petal_length", "petal_width", "variety"] +VARIETIES = ["Setosa", "Versicolor", "Virginica"] +CREATE_TABLE_SQL = """ +CREATE TABLE iris ( + sepal_length REAL, + sepal_width REAL, + petal_length REAL, + petal_width 
REAL, + variety TEXT +) +""" +INSERT_SQL = "INSERT INTO iris VALUES (?, ?, ?, ?, ?)" +NUM_ITEMS = 200 + + +def generate_test_items(num_items): + np.random.seed(42) + data = { + "sepal_length": np.round(np.random.uniform(4.3, 7.9, num_items), 1), + "sepal_width": np.round(np.random.uniform(2.0, 4.4, num_items), 1), + "petal_length": np.round(np.random.uniform(1.0, 6.9, num_items), 1), + "petal_width": np.round(np.random.uniform(0.1, 2.5, num_items), 1), + "variety": [random.choice(VARIETIES) for _ in range(num_items)], + } + return [ + ( + data["sepal_length"][i], + data["sepal_width"][i], + data["petal_length"][i], + data["petal_width"][i], + data["variety"][i], + ) + for i in range(num_items) + ] + + +# Fixture for temporary SQLite database +@pytest.fixture(scope="module") +def temp_sqllite_db(): + test_items = generate_test_items(NUM_ITEMS) + with tempfile.NamedTemporaryFile(suffix=".db") as file: + connection = sqlite3.connect(file.name) + connection.execute(CREATE_TABLE_SQL) + connection.executemany(INSERT_SQL, test_items) + connection.commit() + connection.close() + yield file.name + + +@pytest.mark.integration() +def test_sqllite_create_dataframe_ok(temp_sqllite_db) -> None: + df = daft.read_sql( + "SELECT * FROM iris", f"sqlite://{temp_sqllite_db}" + ) # path here only has 2 slashes instead of 3 because connectorx is used + pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") + + assert df.to_pandas().equals(pd_df) + + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2, 3]) +def test_sqllite_partitioned_read(temp_sqllite_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM iris LIMIT {50 * num_partitions}", f"sqlite://{temp_sqllite_db}") + assert df.num_partitions() == num_partitions + df = df.collect() + assert len(df) == 50 * num_partitions + + # test with a number of rows that is not a multiple of 50 + df = daft.read_sql(f"SELECT * FROM iris LIMIT {50 * num_partitions + 1}", f"sqlite://{temp_sqllite_db}") + assert df.num_partitions() == num_partitions + 1 + df = df.collect() + assert len(df) == 50 * num_partitions + 1 + + +@pytest.mark.integration() +def test_sqllite_read_with_filter_pushdowns(temp_sqllite_db) -> None: + df = daft.read_sql("SELECT * FROM iris", f"sqlite://{temp_sqllite_db}") + df = df.where(df["sepal_length"] > 5.0) + df = df.where(df["sepal_width"] > 3.0) + + pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") + pd_df = pd_df[pd_df["sepal_length"] > 5.0] + pd_df = pd_df[pd_df["sepal_width"] > 3.0] + + df = df.to_pandas().sort_values("sepal_length", ascending=False).reset_index(drop=True) + pd_df = pd_df.sort_values("sepal_length", ascending=False).reset_index(drop=True) + assert df.equals(pd_df) + + +@pytest.mark.integration() +def test_sqllite_read_with_limit_pushdown(temp_sqllite_db) -> None: + df = daft.read_sql("SELECT * FROM iris", f"sqlite://{temp_sqllite_db}") + df = df.limit(100) + + pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") + pd_df = pd_df.head(100) + + df = df.to_pandas() + pd_df = pd_df.reset_index(drop=True) + assert df.equals(pd_df) + + +@pytest.mark.integration() +def test_sqllite_read_with_projection_pushdown(temp_sqllite_db) -> None: + df = daft.read_sql("SELECT * FROM iris", f"sqlite://{temp_sqllite_db}") + df = df.select(df["sepal_length"], df["variety"]) + + pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") + pd_df = pd_df[["sepal_length", "variety"]] + + df = df.to_pandas() + assert df.equals(pd_df) + + 
+@pytest.mark.integration() +def test_sqllite_read_with_all_pushdowns(temp_sqllite_db) -> None: + df = daft.read_sql("SELECT * FROM iris", f"sqlite://{temp_sqllite_db}") + df = df.where(df["sepal_length"] > 5.0) + df = df.where(df["sepal_width"] > 3.0) + df = df.limit(100) + df = df.select(df["sepal_length"]) + + pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") + pd_df = pd_df[pd_df["sepal_length"] > 5.0] + pd_df = pd_df[pd_df["sepal_width"] > 3.0] + pd_df = pd_df.head(100) + pd_df = pd_df[["sepal_length"]] + + df = df.to_pandas().sort_values("sepal_length", ascending=False).reset_index(drop=True) + pd_df = pd_df.sort_values("sepal_length", ascending=False).reset_index(drop=True) + assert df.equals(pd_df) + + +@pytest.mark.integration() +def test_sqllite_bad_url() -> None: + with pytest.raises(RuntimeError, match="Unable to execute sql"): + daft.read_sql("SELECT * FROM iris", "sqlite://") diff --git a/tests/integration/sql/test_trino.py b/tests/integration/sql/test_trino.py new file mode 100644 index 0000000000..53df3e564b --- /dev/null +++ b/tests/integration/sql/test_trino.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import pandas as pd +import pytest + +import daft + +URL = "trino://user@localhost:8080/tpch" + + +@pytest.mark.integration() +def test_trino_create_dataframe_ok() -> None: + df = daft.read_sql("SELECT * FROM tpch.sf1.nation", URL) + pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", URL) + assert df.to_pandas().equals(pd_df) From 7d05a830affa8fa99d7f2b083416c00929edd3bf Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Feb 2024 10:58:44 -0800 Subject: [PATCH 03/30] sql alchemy version --- requirements-dev.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 6ec1b6b0fa..3d01d45eca 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -51,7 +51,8 @@ deltalake==0.13.0; platform_system != "Windows" and python_version < '3.8' databricks-sdk==0.12.0 #SQL -sqlalchemy==2.0.25; python_version < '3.8' +sqlalchemy==2.0.25; python_version >= '3.8' +sqlalchemy==1.4.51; python_version < '3.8' connectorx==0.3.2; python_version >= '3.8' trino[sqlalchemy]==0.328.0; python_version >= '3.8' From dd9ebbfc688e6f49078e1ac981d495ea213a207e Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Feb 2024 11:15:35 -0800 Subject: [PATCH 04/30] fix test --- daft/utils.py | 4 ++-- tests/integration/sql/test_sql_lite.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/daft/utils.py b/daft/utils.py index 2791826613..640eb07655 100644 --- a/daft/utils.py +++ b/daft/utils.py @@ -124,7 +124,7 @@ def execute_sql_query_to_pyarrow_with_connectorx(sql: str, url: str) -> pa.Table table = cx.read_sql(conn=url, query=sql, return_type="arrow") return table except Exception as e: - raise RuntimeError(f"Failed to execute sql: {sql} with url: {url}") from e + raise RuntimeError(f"Failed to execute sql: {sql} with url: {url}, error: {e}") from e def execute_sql_query_to_pyarrow_with_sqlalchemy(sql: str, url: str) -> pa.Table: @@ -139,7 +139,7 @@ def execute_sql_query_to_pyarrow_with_sqlalchemy(sql: str, url: str) -> pa.Table table = pa.Table.from_pandas(df) return table except Exception as e: - raise RuntimeError(f"Failed to execute sql: {sql} with url: {url}") from e + raise RuntimeError(f"Failed to execute sql: {sql} with url: {url}, error: {e}") from e def execute_sql_query_to_pyarrow(sql: str, url: str) -> pa.Table: diff --git a/tests/integration/sql/test_sql_lite.py 
b/tests/integration/sql/test_sql_lite.py index 8670b49e3f..0b4e1b5ed8 100644 --- a/tests/integration/sql/test_sql_lite.py +++ b/tests/integration/sql/test_sql_lite.py @@ -145,5 +145,5 @@ def test_sqllite_read_with_all_pushdowns(temp_sqllite_db) -> None: @pytest.mark.integration() def test_sqllite_bad_url() -> None: - with pytest.raises(RuntimeError, match="Unable to execute sql"): + with pytest.raises(RuntimeError, match="Failed to execute sql"): daft.read_sql("SELECT * FROM iris", "sqlite://") From f501dd1a55d9861edc8bd14345bbae8ef57a8ed5 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Feb 2024 12:57:39 -0800 Subject: [PATCH 05/30] retry --- tests/integration/sql/conftest.py | 19 +++++++++++++++++++ tests/integration/sql/test_trino.py | 9 ++++----- 2 files changed, 23 insertions(+), 5 deletions(-) create mode 100644 tests/integration/sql/conftest.py diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py new file mode 100644 index 0000000000..ba15c2e42e --- /dev/null +++ b/tests/integration/sql/conftest.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import pytest +import sqlalchemy +import tenacity + +TRINO_URL = "trino://user@localhost:8080/tpch" + + +@tenacity.retry( + stop=tenacity.stop_after_delay(60), + retry=tenacity.retry_if_exception_type(sqlalchemy.exc.DBAPIError), + wait=tenacity.wait_fixed(5), + reraise=True, +) +@pytest.fixture(scope="session") +def check_db_server_initialized() -> None: + with sqlalchemy.create_engine(TRINO_URL).connect() as conn: + conn.execute(sqlalchemy.text("SELECT 1")) diff --git a/tests/integration/sql/test_trino.py b/tests/integration/sql/test_trino.py index 53df3e564b..932d0ed8ff 100644 --- a/tests/integration/sql/test_trino.py +++ b/tests/integration/sql/test_trino.py @@ -4,12 +4,11 @@ import pytest import daft - -URL = "trino://user@localhost:8080/tpch" +from tests.integration.sql.conftest import TRINO_URL @pytest.mark.integration() -def test_trino_create_dataframe_ok() -> None: - df = daft.read_sql("SELECT * FROM tpch.sf1.nation", URL) - pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", URL) +def test_trino_create_dataframe_ok(check_db_server_initialized) -> None: + df = daft.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL) + pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL) assert df.to_pandas().equals(pd_df) From fa085de850eafb1d4d382cae8d768625d0268789 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Feb 2024 14:19:40 -0800 Subject: [PATCH 06/30] retry all --- tests/integration/sql/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index ba15c2e42e..6a862fd766 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -9,7 +9,6 @@ @tenacity.retry( stop=tenacity.stop_after_delay(60), - retry=tenacity.retry_if_exception_type(sqlalchemy.exc.DBAPIError), wait=tenacity.wait_fixed(5), reraise=True, ) From 71af666ced5f2abac4faf8e92d1f6b608c059982 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Feb 2024 14:30:09 -0800 Subject: [PATCH 07/30] add try block --- tests/integration/sql/conftest.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index 6a862fd766..f06bbb35c8 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -14,5 +14,9 @@ ) @pytest.fixture(scope="session") def check_db_server_initialized() -> None: - with 
sqlalchemy.create_engine(TRINO_URL).connect() as conn: - conn.execute(sqlalchemy.text("SELECT 1")) + try: + with sqlalchemy.create_engine(TRINO_URL).connect() as conn: + conn.execute(sqlalchemy.text("SELECT 1")) + except Exception as e: + print(f"Connection failed with exception: {e}") + raise From c8368197c9a6e2ffa0a60bb300981c5026322536 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Feb 2024 18:49:49 -0800 Subject: [PATCH 08/30] move retries out of fixture --- tests/integration/sql/conftest.py | 33 +++++++++++++--- tests/integration/sql/test_sql_lite.py | 55 +++++++------------------- tests/integration/sql/test_trino.py | 7 ++-- 3 files changed, 46 insertions(+), 49 deletions(-) diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index f06bbb35c8..82ca03e8d7 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -1,22 +1,45 @@ from __future__ import annotations +import random + +import numpy as np import pytest import sqlalchemy import tenacity TRINO_URL = "trino://user@localhost:8080/tpch" +NUM_TEST_ROWS = 200 + + +@pytest.fixture(scope="session") +def test_items(): + np.random.seed(42) + data = { + "sepal_length": np.round(np.random.uniform(4.3, 7.9, NUM_TEST_ROWS), 1), + "sepal_width": np.round(np.random.uniform(2.0, 4.4, NUM_TEST_ROWS), 1), + "petal_length": np.round(np.random.uniform(1.0, 6.9, NUM_TEST_ROWS), 1), + "petal_width": np.round(np.random.uniform(0.1, 2.5, NUM_TEST_ROWS), 1), + "variety": [random.choice(["Setosa", "Versicolor", "Virginica"]) for _ in range(NUM_TEST_ROWS)], + } + return data + @tenacity.retry( stop=tenacity.stop_after_delay(60), wait=tenacity.wait_fixed(5), reraise=True, ) +def check_database_connection(url) -> None: + with sqlalchemy.create_engine(url).connect() as conn: + conn.execute("SELECT 1") + + @pytest.fixture(scope="session") -def check_db_server_initialized() -> None: +@pytest.mark.parametrize("url", [TRINO_URL]) +def check_db_server_initialized(url) -> bool: try: - with sqlalchemy.create_engine(TRINO_URL).connect() as conn: - conn.execute(sqlalchemy.text("SELECT 1")) + check_database_connection(url) + return True except Exception as e: - print(f"Connection failed with exception: {e}") - raise + pytest.fail(f"Failed to connect to {url}: {e}") diff --git a/tests/integration/sql/test_sql_lite.py b/tests/integration/sql/test_sql_lite.py index 0b4e1b5ed8..990ce87603 100644 --- a/tests/integration/sql/test_sql_lite.py +++ b/tests/integration/sql/test_sql_lite.py @@ -1,59 +1,32 @@ from __future__ import annotations -import random import sqlite3 import tempfile -import numpy as np import pandas as pd import pytest import daft -COL_NAMES = ["sepal_length", "sepal_width", "petal_length", "petal_width", "variety"] -VARIETIES = ["Setosa", "Versicolor", "Virginica"] -CREATE_TABLE_SQL = """ -CREATE TABLE iris ( - sepal_length REAL, - sepal_width REAL, - petal_length REAL, - petal_width REAL, - variety TEXT -) -""" -INSERT_SQL = "INSERT INTO iris VALUES (?, ?, ?, ?, ?)" -NUM_ITEMS = 200 - - -def generate_test_items(num_items): - np.random.seed(42) - data = { - "sepal_length": np.round(np.random.uniform(4.3, 7.9, num_items), 1), - "sepal_width": np.round(np.random.uniform(2.0, 4.4, num_items), 1), - "petal_length": np.round(np.random.uniform(1.0, 6.9, num_items), 1), - "petal_width": np.round(np.random.uniform(0.1, 2.5, num_items), 1), - "variety": [random.choice(VARIETIES) for _ in range(num_items)], - } - return [ - ( - data["sepal_length"][i], - data["sepal_width"][i], - 
data["petal_length"][i], - data["petal_width"][i], - data["variety"][i], - ) - for i in range(num_items) - ] - # Fixture for temporary SQLite database @pytest.fixture(scope="module") -def temp_sqllite_db(): - test_items = generate_test_items(NUM_ITEMS) +def temp_sqllite_db(test_items): + data = list( + zip( + test_items["sepal_length"], + test_items["sepal_width"], + test_items["petal_length"], + test_items["petal_width"], + test_items["variety"], + ) + ) with tempfile.NamedTemporaryFile(suffix=".db") as file: connection = sqlite3.connect(file.name) - connection.execute(CREATE_TABLE_SQL) - connection.executemany(INSERT_SQL, test_items) + connection.execute( + "CREATE TABLE iris (sepal_length REAL, sepal_width REAL, petal_length REAL, petal_width REAL, variety TEXT)" + ) + connection.executemany("INSERT INTO iris VALUES (?, ?, ?, ?, ?)", data) connection.commit() connection.close() yield file.name diff --git a/tests/integration/sql/test_trino.py b/tests/integration/sql/test_trino.py index 932d0ed8ff..99efa8e990 100644 --- a/tests/integration/sql/test_trino.py +++ b/tests/integration/sql/test_trino.py @@ -9,6 +9,7 @@ @pytest.mark.integration() def test_trino_create_dataframe_ok(check_db_server_initialized) -> None: - df = daft.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL) - pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL) - assert df.to_pandas().equals(pd_df) + if check_db_server_initialized: + df = daft.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL) + pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL) + assert df.to_pandas().equals(pd_df) From f24fa9d8b51257dad88418739911e3e215970bfa Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Feb 2024 19:01:20 -0800 Subject: [PATCH 09/30] move retries out of fixture --- tests/integration/sql/conftest.py | 21 ++++++++++++------- tests/integration/sql/test_databases.py | 15 +++++++++++++ .../{test_sql_lite.py => test_operations.py} | 2 +- tests/integration/sql/test_trino.py | 15 ------------- 4 files changed, 29 insertions(+), 24 deletions(-) create mode 100644 tests/integration/sql/test_databases.py rename tests/integration/sql/{test_sql_lite.py => test_operations.py} (99%) delete mode 100644 tests/integration/sql/test_trino.py diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index 82ca03e8d7..f20e7bb2f6 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -1,13 +1,14 @@ from __future__ import annotations import random +from typing import Generator import numpy as np import pytest import sqlalchemy import tenacity -TRINO_URL = "trino://user@localhost:8080/tpch" +URLS = {"trino": "trino://user@localhost:8080/tpch"} NUM_TEST_ROWS = 200 @@ -36,10 +37,14 @@ def check_database_connection(url) -> None: @pytest.fixture(scope="session") -@pytest.mark.parametrize("url", [TRINO_URL]) -def check_db_server_initialized(url) -> bool: - try: - check_database_connection(url) - return True - except Exception as e: - pytest.fail(f"Failed to connect to {url}: {e}") +def db_url() -> Generator[str, None, None]: + for url in URLS.values(): + try: + check_database_connection(url) + except Exception as e: + pytest.fail(f"Failed to connect to {url}: {e}") + + def db_url(db): + return URLS[db] + + yield db_url diff --git a/tests/integration/sql/test_databases.py b/tests/integration/sql/test_databases.py new file mode 100644 index 0000000000..e98f11d72e --- /dev/null +++ b/tests/integration/sql/test_databases.py @@ -0,0 +1,15 @@ +from __future__ import annotations 
+ +import pandas as pd +import pytest + +import daft + + +@pytest.mark.integration() +def test_trino_create_dataframe_ok(db_url) -> None: + url = db_url("trino") + df = daft.read_sql("SELECT * FROM tpch.sf1.nation", url) + pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", url) + + assert df.equals(pd_df) diff --git a/tests/integration/sql/test_sql_lite.py b/tests/integration/sql/test_operations.py similarity index 99% rename from tests/integration/sql/test_sql_lite.py rename to tests/integration/sql/test_operations.py index 990ce87603..c62df97214 100644 --- a/tests/integration/sql/test_sql_lite.py +++ b/tests/integration/sql/test_operations.py @@ -36,7 +36,7 @@ def temp_sqllite_db(test_items): def test_sqllite_create_dataframe_ok(temp_sqllite_db) -> None: df = daft.read_sql( "SELECT * FROM iris", f"sqlite://{temp_sqllite_db}" - ) # path here only has 2 slashes instead of 3 because connectorx is used + ) # path here only has 2 slashes instead of 3 because connectorx uses 2 slashes pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") assert df.to_pandas().equals(pd_df) diff --git a/tests/integration/sql/test_trino.py b/tests/integration/sql/test_trino.py deleted file mode 100644 index 99efa8e990..0000000000 --- a/tests/integration/sql/test_trino.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -import pandas as pd -import pytest - -import daft -from tests.integration.sql.conftest import TRINO_URL - - -@pytest.mark.integration() -def test_trino_create_dataframe_ok(check_db_server_initialized) -> None: - if check_db_server_initialized: - df = daft.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL) - pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL) - assert df.to_pandas().equals(pd_df) From 1237098238b67bebc30c4c503cf1a9473b4404c3 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Feb 2024 19:02:29 -0800 Subject: [PATCH 10/30] add text to query --- tests/integration/sql/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index f20e7bb2f6..cadde8c2c8 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -33,7 +33,7 @@ def test_items(): ) def check_database_connection(url) -> None: with sqlalchemy.create_engine(url).connect() as conn: - conn.execute("SELECT 1") + conn.execute(sqlalchemy.text("SELECT 1")) @pytest.fixture(scope="session") From f05ce21c60664ec9cbecad839801de715eb8acd8 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Feb 2024 21:47:04 -0800 Subject: [PATCH 11/30] add text to query --- tests/integration/sql/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index cadde8c2c8..de254ddd77 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -33,7 +33,7 @@ def test_items(): ) def check_database_connection(url) -> None: with sqlalchemy.create_engine(url).connect() as conn: - conn.execute(sqlalchemy.text("SELECT 1")) + conn.execute(sqlalchemy.text("SELECT * FROM tpch.sf1.nation")) @pytest.fixture(scope="session") From 2cb90faa3dc619b173b9bae0e88056334d2bea63 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Sat, 24 Feb 2024 10:17:17 -0800 Subject: [PATCH 12/30] fix assertion --- tests/integration/sql/test_databases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/sql/test_databases.py b/tests/integration/sql/test_databases.py 
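For context, the two "add text to query" commits above wrap the probe query in sqlalchemy.text(...) because SQLAlchemy 2.x no longer accepts bare SQL strings in Connection.execute. A minimal sketch of that readiness-check pattern, separate from the patch itself (the function name here is illustrative):

import sqlalchemy
import tenacity

@tenacity.retry(stop=tenacity.stop_after_delay(60), wait=tenacity.wait_fixed(5), reraise=True)
def wait_for_database(url: str) -> None:
    # Retry until the database accepts a trivial query, then return.
    with sqlalchemy.create_engine(url).connect() as conn:
        conn.execute(sqlalchemy.text("SELECT 1"))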
index e98f11d72e..28faa79321 100644 --- a/tests/integration/sql/test_databases.py +++ b/tests/integration/sql/test_databases.py @@ -12,4 +12,4 @@ def test_trino_create_dataframe_ok(db_url) -> None: df = daft.read_sql("SELECT * FROM tpch.sf1.nation", url) pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", url) - assert df.equals(pd_df) + assert df.to_pandas().equals(pd_df) From f1836f1bd608c52a91ed317d8835a741f9497e84 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Sat, 24 Feb 2024 10:33:52 -0800 Subject: [PATCH 13/30] yay micropartitions always 1 --- .github/workflows/python-package.yml | 2 -- daft/io/_sql.py | 2 +- src/daft-plan/src/planner.rs | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 3b8f98dafe..45db70102d 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -503,7 +503,6 @@ jobs: matrix: python-version: ['3.8'] # can't use 3.7 due to requiring anon mode for adlfs daft-runner: [py, ray] - micropartitions: [1] steps: - uses: actions/checkout@v4 with: @@ -541,7 +540,6 @@ jobs: pytest tests/integration/sql -m 'integration' --durations=50 env: DAFT_RUNNER: ${{ matrix.daft-runner }} - DAFT_MICROPARTITIONS: ${{ matrix.micropartitions }} - name: Send Slack notification on failure uses: slackapi/slack-github-action@v1.24.0 if: ${{ failure() && (github.ref == 'refs/heads/main') }} diff --git a/daft/io/_sql.py b/daft/io/_sql.py index 3d2c786c8b..a3fee03913 100644 --- a/daft/io/_sql.py +++ b/daft/io/_sql.py @@ -32,5 +32,5 @@ def read_sql( sql_operator = SQLScanOperator(sql, url, storage_config=storage_config) handle = ScanOperatorHandle.from_python_scan_operator(sql_operator) - builder = LogicalPlanBuilder.from_tabular_scan_with_scan_operator(scan_operator=handle) + builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle) return DataFrame(builder) diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 4e2778757d..5776998d39 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -6,7 +6,7 @@ use std::{ }; use common_daft_config::DaftExecutionConfig; -use common_error::{DaftError, DaftResult}; +use common_error::DaftResult; use daft_core::count_mode::CountMode; use daft_core::DataType; use daft_dsl::Expr; From dfc0fe9117ee2afd65ff30d5d70df5e2dbfc8d16 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 26 Feb 2024 08:50:52 -0800 Subject: [PATCH 14/30] add more integration tests + refactor --- daft/daft.pyi | 5 +- daft/runners/partitioning.py | 6 +- daft/sql/sql_reader.py | 99 ++++++++++++++ daft/sql/sql_scan.py | 50 +++---- daft/table/table_io.py | 41 +++--- daft/utils.py | 43 ------ requirements-dev.txt | 3 +- src/daft-dsl/src/expr.rs | 80 ++++++++++++ src/daft-micropartition/src/micropartition.rs | 18 ++- src/daft-micropartition/src/python.rs | 41 +++--- src/daft-scan/src/file_format.rs | 27 +++- tests/integration/sql/conftest.py | 93 ++++++++----- .../sql/docker-compose/docker-compose.yml | 30 ++++- tests/integration/sql/test_databases.py | 15 --- tests/integration/sql/test_operations.py | 122 ------------------ tests/integration/sql/test_sql.py | 121 +++++++++++++++++ 16 files changed, 508 insertions(+), 286 deletions(-) create mode 100644 daft/sql/sql_reader.py delete mode 100644 tests/integration/sql/test_databases.py delete mode 100644 tests/integration/sql/test_operations.py create mode 100644 tests/integration/sql/test_sql.py diff --git a/daft/daft.pyi b/daft/daft.pyi index 
55e2f81320..bfc7bfd2b6 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -225,8 +225,11 @@ class DatabaseSourceConfig: sql: str limit: int | None offset: int | None + limit_before_offset: bool | None - def __init__(self, sql: str, limit: int | None = None, offset: int | None = None): ... + def __init__( + self, sql: str, limit: int | None = None, offset: int | None = None, limit_before_offset: bool | None = None + ): ... class FileFormatConfig: """ diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index fbd90a655a..fc41d59354 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -73,7 +73,7 @@ class TableParseParquetOptions: @dataclass(frozen=True) -class TableParseSQLOptions: +class TableReadSQLOptions: """Options for parsing SQL tables Args: @@ -83,6 +83,10 @@ class TableParseSQLOptions: limit: int | None = None offset: int | None = None + limit_before_offset: bool | None = None + + predicate_sql: str | None = None + predicate_expression: Expression | None = None @dataclass(frozen=True) diff --git a/daft/sql/sql_reader.py b/daft/sql/sql_reader.py new file mode 100644 index 0000000000..6f11c0bb93 --- /dev/null +++ b/daft/sql/sql_reader.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import logging + +import pyarrow as pa + +logger = logging.getLogger(__name__) + + +class SQLReader: + def __init__( + self, + sql: str, + url: str, + limit: int | None = None, + offset: int | None = None, + limit_before_offset: bool | None = None, + projection: list[str] | None = None, + predicate: str | None = None, + ) -> None: + if limit is not None and offset is not None and limit_before_offset is None: + raise ValueError("limit_before_offset must be specified when limit and offset are both specified") + + self.sql = sql + self.url = url + self.limit = limit + self.offset = offset + self.limit_before_offset = limit_before_offset + self.projection = projection + self.predicate = predicate + + def get_num_rows(self) -> int: + sql = f"SELECT COUNT(*) FROM ({self.sql}) AS subquery" + pa_table = self._execute_sql_query(sql) + return pa_table.column(0)[0].as_py() + + def read(self) -> pa.Table: + sql = self._construct_sql_query() + return self._execute_sql_query(sql) + + def _construct_sql_query(self) -> str: + if self.projection is not None: + columns = ", ".join(self.projection) + else: + columns = "*" + + sql = f"SELECT {columns} FROM ({self.sql}) AS subquery" + + if self.predicate is not None: + sql += f" WHERE {self.predicate}" + + if self.limit is not None and self.offset is not None: + if self.limit_before_offset: + sql += f" LIMIT {self.limit} OFFSET {self.offset}" + else: + sql += f" OFFSET {self.offset} LIMIT {self.limit}" + elif self.limit is not None: + sql += f" LIMIT {self.limit}" + elif self.offset is not None: + sql += f" OFFSET {self.offset}" + + return sql + + def _execute_sql_query(self, sql: str) -> pa.Table: + # Supported DBs extracted from here https://github.com/sfu-db/connector-x/tree/7b3147436b7e20b96691348143d605e2249d6119?tab=readme-ov-file#sources + supported_dbs = {"postgres", "mysql", "mssql", "oracle", "bigquery", "sqlite", "clickhouse", "redshift"} + + # Extract the database type from the URL + db_type = self.url.split("://")[0] + + # Check if the database type is supported + if db_type in supported_dbs: + return self._execute_sql_query_with_connectorx(sql) + else: + return self._execute_sql_query_with_sqlalchemy(sql) + + def _execute_sql_query_with_connectorx(self, sql: str) -> pa.Table: + import connectorx as cx + + 
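To illustrate the query construction above (a sketch, not part of the patch; the connection URL is a placeholder and _construct_sql_query is an internal helper), a reader configured with a projection, a predicate, and a limit/offset pair emits a single wrapped subquery:

from daft.sql.sql_reader import SQLReader

reader = SQLReader(
    "SELECT * FROM iris",
    "postgresql://user:password@localhost:5432/db",
    limit=50,
    offset=100,
    limit_before_offset=True,
    projection=["sepal_length", "variety"],
    predicate="sepal_length > 5.0",
)
print(reader._construct_sql_query())
# SELECT sepal_length, variety FROM (SELECT * FROM iris) AS subquery WHERE sepal_length > 5.0 LIMIT 50 OFFSET 100
# With limit_before_offset=False the tail becomes "OFFSET 100 LIMIT 50" instead,
# for engines that only accept OFFSET ahead of LIMIT.

Whether the resulting query runs through connectorx or SQLAlchemy is decided purely by the URL scheme, per the supported_dbs check above.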
logger.info(f"Using connectorx to execute sql: {sql}") + try: + table = cx.read_sql(conn=self.url, query=sql, return_type="arrow") + return table + except Exception as e: + raise RuntimeError(f"Failed to execute sql: {sql} with url: {self.url}, error: {e}") from e + + def _execute_sql_query_with_sqlalchemy(self, sql: str) -> pa.Table: + import pandas as pd + from sqlalchemy import create_engine, text + + logger.info(f"Using sqlalchemy to execute sql: {sql}") + try: + with create_engine(self.url).connect() as connection: + result = connection.execute(text(sql)) + df = pd.DataFrame(result.fetchall(), columns=result.keys()) + table = pa.Table.from_pandas(df) + return table + except Exception as e: + raise RuntimeError(f"Failed to execute sql: {sql} with url: {self.url}, error: {e}") from e diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index fa2c35edf7..e4089401ee 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -12,11 +12,11 @@ ) from daft.io.scan import PartitionField, ScanOperator from daft.logical.schema import Schema -from daft.utils import execute_sql_query_to_pyarrow +from daft.sql.sql_reader import SQLReader class SQLScanOperator(ScanOperator): - MIN_ROWS_PER_SCAN_TASK = 50 # Would be better to have a memory limit instead of a row limit + MIN_ROWS_PER_SCAN_TASK = 50 # TODO: Would be better to have a memory limit instead of a row limit def __init__( self, @@ -28,34 +28,37 @@ def __init__( self.sql = sql self.url = url self.storage_config = storage_config - self._limit_supported = self._check_limit_supported() + self._limit_and_offset_supported, self._limit_before_offset = self._check_limit_and_offset_supported() self._schema = self._get_schema() - def _check_limit_supported(self) -> bool: + def _check_limit_and_offset_supported(self) -> tuple[bool, bool]: try: - execute_sql_query_to_pyarrow(f"SELECT * FROM ({self.sql}) AS subquery LIMIT 1 OFFSET 0", self.url) - return True + # Try to read 1 row with limit before offset + SQLReader(self.sql, self.url, limit=1, offset=0, limit_before_offset=True).read() + return (True, True) except Exception: - return False + try: + # Try to read 1 row with limit after offset + SQLReader(self.sql, self.url, offset=0, limit=1, limit_before_offset=False).read() + return (True, False) + except Exception: + # If both fail, then limit and offset are not supported + return (False, False) def _get_schema(self) -> Schema: - sql = f"SELECT * FROM ({self.sql}) AS subquery" - if self._limit_supported: - sql += " LIMIT 1 OFFSET 0" - - pa_table = execute_sql_query_to_pyarrow(sql, self.url) + if self._limit_and_offset_supported: + pa_table = SQLReader( + self.sql, self.url, limit=1, offset=0, limit_before_offset=self._limit_before_offset + ).read() + else: + pa_table = SQLReader(self.sql, self.url).read() return Schema.from_pyarrow_schema(pa_table.schema) - def _get_num_rows(self) -> int: - sql = f"SELECT COUNT(*) FROM ({self.sql}) AS subquery" - pa_table = execute_sql_query_to_pyarrow(sql, self.url) - return pa_table.column(0)[0].as_py() - def schema(self) -> Schema: return self._schema def display_name(self) -> str: - return f"SQLScanOperator({self.sql})" + return f"SQLScanOperator(sql={self.sql}, url={self.url})" def partitioning_keys(self) -> list[PartitionField]: return [] @@ -67,7 +70,8 @@ def multiline_display(self) -> list[str]: ] def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: - if not self._limit_supported: + if not self._limit_and_offset_supported: + # If limit and offset are not supported, then we can't 
parallelize the scan, so we just return a single scan task file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) return iter( [ @@ -81,7 +85,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: ] ) - total_rows = self._get_num_rows() + total_rows = SQLReader(self.sql, self.url).get_num_rows() num_scan_tasks = math.ceil(total_rows / self.MIN_ROWS_PER_SCAN_TASK) num_rows_per_scan_task = total_rows // num_scan_tasks @@ -90,7 +94,9 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: for _ in range(num_scan_tasks): limit = max(num_rows_per_scan_task, total_rows - offset) file_format_config = FileFormatConfig.from_database_config( - DatabaseSourceConfig(self.sql, limit=limit, offset=offset) + DatabaseSourceConfig( + self.sql, limit=limit, offset=offset, limit_before_offset=self._limit_before_offset + ) ) scan_tasks.append( ScanTask.sql_scan_task( @@ -112,4 +118,4 @@ def can_absorb_limit(self) -> bool: return False def can_absorb_select(self) -> bool: - return True + return False diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 19683a1487..5f2207bc33 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -29,7 +29,6 @@ ) from daft.datatype import DataType from daft.expressions import ExpressionsProjection -from daft.expressions.expressions import Expression from daft.filesystem import ( _resolve_paths_and_filesystem, canonicalize_protocol, @@ -40,10 +39,11 @@ TableParseCSVOptions, TableParseParquetOptions, TableReadOptions, + TableReadSQLOptions, ) from daft.series import Series +from daft.sql.sql_reader import SQLReader from daft.table import MicroPartition -from daft.utils import execute_sql_query_to_pyarrow FileInput = Union[pathlib.Path, str, IO[bytes]] @@ -217,10 +217,8 @@ def read_sql( sql: str, url: str, schema: Schema, - limit: int | None = None, - offset: int | None = None, + sql_options: TableReadSQLOptions = TableReadSQLOptions(), read_options: TableReadOptions = TableReadOptions(), - predicate: Expression | None = None, ) -> MicroPartition: """Reads a MicroPartition from a SQL query @@ -232,25 +230,22 @@ def read_sql( Returns: MicroPartition: MicroPartition from SQL query """ - columns = read_options.column_names - if columns is not None: - sql = f"SELECT {', '.join(columns)} FROM ({sql})" - else: - sql = f"SELECT * FROM ({sql})" - - if limit is not None: - sql = f"{sql} LIMIT {limit}" - - if offset is not None: - sql = f"{sql} OFFSET {offset}" + pa_table = SQLReader( + sql, + url, + limit=sql_options.limit, + offset=sql_options.offset, + limit_before_offset=sql_options.limit_before_offset, + projection=read_options.column_names, + predicate=sql_options.predicate_sql, + ).read() + mp = MicroPartition.from_arrow(pa_table) + + if sql_options.predicate_sql is None and sql_options.predicate_expression is not None: + mp = mp.filter(ExpressionsProjection([sql_options.predicate_expression])) - mp = MicroPartition.from_arrow(execute_sql_query_to_pyarrow(sql, url)) - if predicate is not None: - mp = mp.filter(ExpressionsProjection([predicate])) - - num_rows = read_options.num_rows - if num_rows is not None: - mp = mp.head(num_rows) + if read_options.num_rows is not None: + mp = mp.head(read_options.num_rows) return _cast_table_to_schema(mp, read_options=read_options, schema=schema) diff --git a/daft/utils.py b/daft/utils.py index 640eb07655..8580323f42 100644 --- a/daft/utils.py +++ b/daft/utils.py @@ -114,46 +114,3 @@ def pyarrow_supports_fixed_shape_tensor() -> bool: from daft.context 
import get_context return hasattr(pa, "fixed_shape_tensor") and (not get_context().is_ray_runner or ARROW_VERSION >= (13, 0, 0)) - - -def execute_sql_query_to_pyarrow_with_connectorx(sql: str, url: str) -> pa.Table: - import connectorx as cx - - logger.info(f"Using connectorx to execute sql: {sql}") - try: - table = cx.read_sql(conn=url, query=sql, return_type="arrow") - return table - except Exception as e: - raise RuntimeError(f"Failed to execute sql: {sql} with url: {url}, error: {e}") from e - - -def execute_sql_query_to_pyarrow_with_sqlalchemy(sql: str, url: str) -> pa.Table: - import pandas as pd - from sqlalchemy import create_engine, text - - logger.info(f"Using sqlalchemy to execute sql: {sql}") - try: - with create_engine(url).connect() as connection: - result = connection.execute(text(sql)) - df = pd.DataFrame(result.fetchall(), columns=result.keys()) - table = pa.Table.from_pandas(df) - return table - except Exception as e: - raise RuntimeError(f"Failed to execute sql: {sql} with url: {url}, error: {e}") from e - - -def execute_sql_query_to_pyarrow(sql: str, url: str) -> pa.Table: - # Supported DBs extracted from here https://github.com/sfu-db/connector-x/tree/7b3147436b7e20b96691348143d605e2249d6119?tab=readme-ov-file#sources - if ( - url.startswith("postgres") - or url.startswith("mysql") - or url.startswith("mssql") - or url.startswith("oracle") - or url.startswith("bigquery") - or url.startswith("sqlite") - or url.startswith("clickhouse") - or url.startswith("redshift") - ): - return execute_sql_query_to_pyarrow_with_connectorx(sql, url) - else: - return execute_sql_query_to_pyarrow_with_sqlalchemy(sql, url) diff --git a/requirements-dev.txt b/requirements-dev.txt index 3d01d45eca..a2f8a44b77 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -52,9 +52,10 @@ databricks-sdk==0.12.0 #SQL sqlalchemy==2.0.25; python_version >= '3.8' -sqlalchemy==1.4.51; python_version < '3.8' connectorx==0.3.2; python_version >= '3.8' trino[sqlalchemy]==0.328.0; python_version >= '3.8' +PyMySQL==1.1.0; python_version >= '3.8' +psycopg2==2.9.9; python_version >= '3.8' # AWS s3fs==2023.1.0; python_version < '3.8' diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index d66ce7df55..352cb6f535 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -568,6 +568,86 @@ impl Expr { _ => None, } } + + pub fn to_sql(&self) -> Option { + match self { + Expr::Column(name) => Some(name.to_string()), + Expr::Literal(lit) => match lit { + lit::LiteralValue::Series(series) => match series.data_type() { + DataType::Utf8 => { + let s = lit.to_string(); + let trimmed = s.trim_matches(|c| c == '[' || c == ']').to_string(); + let quoted = trimmed + .split(", ") + .map(|s| format!("'{}'", s)) + .collect::>(); + Some(quoted.join(", ")) + } + _ => Some( + lit.to_string() + .trim_matches(|c| c == '[' || c == ']') + .to_string(), + ), + }, + _ => Some(lit.to_string().replace('\"', "'")), + }, + Expr::Alias(inner, ..) 
=> { + let inner_sql = inner.to_sql()?; + Some(inner_sql) + } + Expr::BinaryOp { op, left, right } => { + let left_sql = left.to_sql()?; + let right_sql = right.to_sql()?; + let op = match op { + Operator::Eq => "=".to_string(), + Operator::NotEq => "!=".to_string(), + Operator::Lt => "<".to_string(), + Operator::LtEq => "<=".to_string(), + Operator::Gt => ">".to_string(), + Operator::GtEq => ">=".to_string(), + Operator::And => "AND".to_string(), + Operator::Or => "OR".to_string(), + _ => return None, + }; + Some(format!("{} {} {}", left_sql, op, right_sql)) + } + Expr::Not(inner) => { + let inner_sql = inner.to_sql()?; + Some(format!("NOT {}", inner_sql)) + } + Expr::IsNull(inner) => { + let inner_sql = inner.to_sql()?; + Some(format!("{} IS NULL", inner_sql)) + } + Expr::NotNull(inner) => { + let inner_sql = inner.to_sql()?; + Some(format!("{} IS NOT NULL", inner_sql)) + } + // TODO: Implement SQL translations for these expressions + Expr::IfElse { + if_true, + if_false, + predicate, + } => { + let if_true_sql = if_true.to_sql()?; + let if_false_sql = if_false.to_sql()?; + let predicate_sql = predicate.to_sql()?; + Some(format!( + "CASE WHEN {} THEN {} ELSE {} END", + predicate_sql, if_true_sql, if_false_sql + )) + } + Expr::IsIn(inner, items) => { + let inner_sql = inner.to_sql()?; + let items_sql = items.to_sql()?; + Some(format!("{} IN ({})", inner_sql, items_sql)) + } + // TODO: Implement SQL translations for these expressions if possible + Expr::Agg(..) => None, + Expr::Cast(..) => None, + Expr::Function { .. } => None, + } + } } impl Display for Expr { diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 78d6d6c0d7..3a7a9bd502 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -308,8 +308,18 @@ fn materialize_scan_task( }) .collect::>>() })?, - FileFormatConfig::Database(DatabaseSourceConfig { sql, limit, offset }) => { - let py_expr = scan_task + FileFormatConfig::Database(DatabaseSourceConfig { + sql, + limit, + offset, + limit_before_offset, + }) => { + let predicate_sql = scan_task + .pushdowns + .filters + .as_ref() + .and_then(|p| p.to_sql()); + let predicate_expr = scan_task .pushdowns .filters .as_ref() @@ -322,6 +332,9 @@ fn materialize_scan_task( url, *limit, *offset, + *limit_before_offset, + predicate_sql.clone(), + predicate_expr.clone(), scan_task.schema.clone().into(), scan_task .pushdowns @@ -329,7 +342,6 @@ fn materialize_scan_task( .as_ref() .map(|cols| cols.as_ref().clone()), scan_task.pushdowns.limit, - py_expr.clone(), ) .map(|t| t.into()) .context(PyIOSnafu) diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 178ff701e4..b486b35af9 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -794,43 +794,44 @@ pub(crate) fn read_sql_into_py_table( url: &str, limit: Option, offset: Option, + limit_before_offset: Option, + predicate_sql: Option, + predicate_expr: Option, schema: PySchema, include_columns: Option>, num_rows: Option, - predicate: Option, ) -> PyResult { let py_schema = py .import(pyo3::intern!(py, "daft.logical.schema"))? .getattr(pyo3::intern!(py, "Schema"))? .getattr(pyo3::intern!(py, "_from_pyschema"))? .call1((schema,))?; - let py_predicate = match predicate { - Some(p) => { - let expressions_mod = py.import(pyo3::intern!(py, "daft.expressions.expressions"))?; - Some( - expressions_mod - .getattr(pyo3::intern!(py, "Expression"))? 
- .getattr(pyo3::intern!(py, "_from_pyexpr"))? - .call1((p,))?, - ) - } + let predicate_pyexpr = match predicate_expr { + Some(p) => Some( + py.import(pyo3::intern!(py, "daft.expressions.expressions"))? + .getattr(pyo3::intern!(py, "Expression"))? + .getattr(pyo3::intern!(py, "_from_pyexpr"))? + .call1((p,))?, + ), None => None, }; + let sql_options = py + .import(pyo3::intern!(py, "daft.runners.partitioning"))? + .getattr(pyo3::intern!(py, "TableReadSQLOptions"))? + .call1(( + limit, + offset, + limit_before_offset, + predicate_sql, + predicate_pyexpr, + ))?; let read_options = py .import(pyo3::intern!(py, "daft.runners.partitioning"))? .getattr(pyo3::intern!(py, "TableReadOptions"))? .call1((num_rows, include_columns))?; py.import(pyo3::intern!(py, "daft.table.table_io"))? .getattr(pyo3::intern!(py, "read_sql"))? - .call1(( - sql, - url, - py_schema, - limit, - offset, - read_options, - py_predicate, - ))? + .call1((sql, url, py_schema, sql_options, read_options))? .getattr(pyo3::intern!(py, "to_table"))? .call0()? .getattr(pyo3::intern!(py, "_table"))? diff --git a/src/daft-scan/src/file_format.rs b/src/daft-scan/src/file_format.rs index 1a6c314b84..440d8e55be 100644 --- a/src/daft-scan/src/file_format.rs +++ b/src/daft-scan/src/file_format.rs @@ -265,11 +265,22 @@ pub struct DatabaseSourceConfig { pub sql: String, pub limit: Option, pub offset: Option, + pub limit_before_offset: Option, } impl DatabaseSourceConfig { - pub fn new_internal(sql: String, limit: Option, offset: Option) -> Self { - Self { sql, limit, offset } + pub fn new_internal( + sql: String, + limit: Option, + offset: Option, + limit_before_offset: Option, + ) -> Self { + Self { + sql, + limit, + offset, + limit_before_offset, + } } pub fn multiline_display(&self) -> Vec { @@ -281,6 +292,9 @@ impl DatabaseSourceConfig { if let Some(offset) = self.offset { res.push(format!("Offset = {}", offset)); } + if let Some(limit_before_offset) = self.limit_before_offset { + res.push(format!("Limit before offset = {}", limit_before_offset)); + } res } } @@ -290,8 +304,13 @@ impl DatabaseSourceConfig { impl DatabaseSourceConfig { /// Create a config for a Database data source. 
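Putting the pieces together (a sketch rather than anything in the patch), each partition produced by SQLScanOperator is described by one of these configs, carrying the window of rows to fetch and the LIMIT/OFFSET ordering detected for the database:

from daft.daft import DatabaseSourceConfig, FileFormatConfig

# One scan task covering 50 rows starting at offset 100 of the wrapped subquery.
config = DatabaseSourceConfig(
    "SELECT * FROM iris",
    limit=50,
    offset=100,
    limit_before_offset=True,
)
file_format = FileFormatConfig.from_database_config(config)

The same limit/offset pair later resurfaces in TableReadSQLOptions when the scan task is materialized by read_sql_into_py_table.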
#[new] - fn new(sql: &str, limit: Option, offset: Option) -> Self { - Self::new_internal(sql.to_string(), limit, offset) + fn new( + sql: &str, + limit: Option, + offset: Option, + limit_before_offset: Option, + ) -> Self { + Self::new_internal(sql.to_string(), limit, offset, limit_before_offset) } } diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index de254ddd77..f0d5b8029a 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -1,50 +1,83 @@ from __future__ import annotations -import random +import tempfile from typing import Generator import numpy as np +import pandas as pd import pytest -import sqlalchemy import tenacity +from sqlalchemy import ( + Column, + Engine, + Float, + Integer, + MetaData, + String, + Table, + create_engine, + text, +) -URLS = {"trino": "trino://user@localhost:8080/tpch"} - +URLS = [ + "trino://user@localhost:8080/memory/default", + "postgresql://username:password@localhost:5432/postgres", + "mysql+pymysql://username:password@localhost:3306/mysql", + "sqlite:///", +] NUM_TEST_ROWS = 200 +NUM_ROWS_PER_PARTITION = 50 +TEST_TABLE_NAME = "example" @pytest.fixture(scope="session") -def test_items(): - np.random.seed(42) +def generated_data() -> pd.DataFrame: data = { - "sepal_length": np.round(np.random.uniform(4.3, 7.9, NUM_TEST_ROWS), 1), - "sepal_width": np.round(np.random.uniform(2.0, 4.4, NUM_TEST_ROWS), 1), - "petal_length": np.round(np.random.uniform(1.0, 6.9, NUM_TEST_ROWS), 1), - "petal_width": np.round(np.random.uniform(0.1, 2.5, NUM_TEST_ROWS), 1), - "variety": [random.choice(["Setosa", "Versicolor", "Virginica"]) for _ in range(NUM_TEST_ROWS)], + "id": np.arange(NUM_TEST_ROWS), + "sepal_length": np.arange(NUM_TEST_ROWS, dtype=float), + "sepal_width": np.arange(NUM_TEST_ROWS, dtype=float), + "petal_length": np.arange(NUM_TEST_ROWS, dtype=float), + "petal_width": np.arange(NUM_TEST_ROWS, dtype=float), + "variety": ["setosa"] * 50 + ["versicolor"] * 50 + ["virginica"] * 50 + [None] * 50, } - return data + return pd.DataFrame(data) -@tenacity.retry( - stop=tenacity.stop_after_delay(60), - wait=tenacity.wait_fixed(5), - reraise=True, -) -def check_database_connection(url) -> None: - with sqlalchemy.create_engine(url).connect() as conn: - conn.execute(sqlalchemy.text("SELECT * FROM tpch.sf1.nation")) +@pytest.fixture(scope="session", params=URLS) +def test_db(request: pytest.FixtureRequest, generated_data: pd.DataFrame) -> Generator[str, None, None]: + db_url = request.param + if db_url.startswith("sqlite"): + with tempfile.NamedTemporaryFile(suffix=".db") as file: + db_url += file.name + setup_database(db_url, generated_data) + yield db_url + else: + setup_database(db_url, generated_data) + yield db_url -@pytest.fixture(scope="session") -def db_url() -> Generator[str, None, None]: - for url in URLS.values(): - try: - check_database_connection(url) - except Exception as e: - pytest.fail(f"Failed to connect to {url}: {e}") +@tenacity.retry(stop=tenacity.stop_after_delay(10), wait=tenacity.wait_fixed(5), reraise=True) +def setup_database(db_url: str, data: pd.DataFrame) -> None: + engine = create_engine(db_url) + create_and_populate(engine, data) + + # Ensure the table is created and populated + with engine.connect() as conn: + result = conn.execute(text(f"SELECT COUNT(*) FROM {TEST_TABLE_NAME}")).fetchone()[0] + assert result == NUM_TEST_ROWS - def db_url(db): - return URLS[db] - yield db_url +def create_and_populate(engine: Engine, data: pd.DataFrame) -> None: + metadata = MetaData() + 
table = Table( + TEST_TABLE_NAME, + metadata, + Column("id", Integer), + Column("sepal_length", Float), + Column("sepal_width", Float), + Column("petal_length", Float), + Column("petal_width", Float), + Column("variety", String(50)), + ) + metadata.create_all(engine) + data.to_sql(table.name, con=engine, if_exists="replace", index=False) diff --git a/tests/integration/sql/docker-compose/docker-compose.yml b/tests/integration/sql/docker-compose/docker-compose.yml index e463e51eeb..11c391b0d3 100644 --- a/tests/integration/sql/docker-compose/docker-compose.yml +++ b/tests/integration/sql/docker-compose/docker-compose.yml @@ -5,4 +5,32 @@ services: container_name: trino ports: - 8080:8080 - restart: unless-stopped + + postgres: + image: postgres:latest + container_name: postgres + environment: + POSTGRES_DB: postgres + POSTGRES_USER: username + POSTGRES_PASSWORD: password + ports: + - 5432:5432 + volumes: + - postgres_data:/var/lib/postgresql/data + + mysql: + image: mysql:latest + container_name: mysql + environment: + MYSQL_DATABASE: mysql + MYSQL_USER: username + MYSQL_PASSWORD: password + MYSQL_ROOT_PASSWORD: rootpassword + ports: + - 3306:3306 + volumes: + - mysql_data:/var/lib/mysql + +volumes: + postgres_data: + mysql_data: diff --git a/tests/integration/sql/test_databases.py b/tests/integration/sql/test_databases.py deleted file mode 100644 index 28faa79321..0000000000 --- a/tests/integration/sql/test_databases.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -import pandas as pd -import pytest - -import daft - - -@pytest.mark.integration() -def test_trino_create_dataframe_ok(db_url) -> None: - url = db_url("trino") - df = daft.read_sql("SELECT * FROM tpch.sf1.nation", url) - pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", url) - - assert df.to_pandas().equals(pd_df) diff --git a/tests/integration/sql/test_operations.py b/tests/integration/sql/test_operations.py deleted file mode 100644 index c62df97214..0000000000 --- a/tests/integration/sql/test_operations.py +++ /dev/null @@ -1,122 +0,0 @@ -from __future__ import annotations - -import sqlite3 -import tempfile - -import pandas as pd -import pytest - -import daft - - -# Fixture for temporary SQLite database -@pytest.fixture(scope="module") -def temp_sqllite_db(test_items): - data = list( - zip( - test_items["sepal_length"], - test_items["sepal_width"], - test_items["petal_length"], - test_items["petal_width"], - test_items["variety"], - ) - ) - with tempfile.NamedTemporaryFile(suffix=".db") as file: - connection = sqlite3.connect(file.name) - connection.execute( - "CREATE TABLE iris (sepal_length REAL, sepal_width REAL, petal_length REAL, petal_width REAL, variety TEXT)" - ) - connection.executemany("INSERT INTO iris VALUES (?, ?, ?, ?, ?)", data) - connection.commit() - connection.close() - yield file.name - - -@pytest.mark.integration() -def test_sqllite_create_dataframe_ok(temp_sqllite_db) -> None: - df = daft.read_sql( - "SELECT * FROM iris", f"sqlite://{temp_sqllite_db}" - ) # path here only has 2 slashes instead of 3 because connectorx uses 2 slashes - pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") - - assert df.to_pandas().equals(pd_df) - - -@pytest.mark.integration() -@pytest.mark.parametrize("num_partitions", [1, 2, 3]) -def test_sqllite_partitioned_read(temp_sqllite_db, num_partitions) -> None: - df = daft.read_sql(f"SELECT * FROM iris LIMIT {50 * num_partitions}", f"sqlite://{temp_sqllite_db}") - assert df.num_partitions() == num_partitions - df = df.collect() - assert 
len(df) == 50 * num_partitions - - # test with a number of rows that is not a multiple of 50 - df = daft.read_sql(f"SELECT * FROM iris LIMIT {50 * num_partitions + 1}", f"sqlite://{temp_sqllite_db}") - assert df.num_partitions() == num_partitions + 1 - df = df.collect() - assert len(df) == 50 * num_partitions + 1 - - -@pytest.mark.integration() -def test_sqllite_read_with_filter_pushdowns(temp_sqllite_db) -> None: - df = daft.read_sql("SELECT * FROM iris", f"sqlite://{temp_sqllite_db}") - df = df.where(df["sepal_length"] > 5.0) - df = df.where(df["sepal_width"] > 3.0) - - pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") - pd_df = pd_df[pd_df["sepal_length"] > 5.0] - pd_df = pd_df[pd_df["sepal_width"] > 3.0] - - df = df.to_pandas().sort_values("sepal_length", ascending=False).reset_index(drop=True) - pd_df = pd_df.sort_values("sepal_length", ascending=False).reset_index(drop=True) - assert df.equals(pd_df) - - -@pytest.mark.integration() -def test_sqllite_read_with_limit_pushdown(temp_sqllite_db) -> None: - df = daft.read_sql("SELECT * FROM iris", f"sqlite://{temp_sqllite_db}") - df = df.limit(100) - - pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") - pd_df = pd_df.head(100) - - df = df.to_pandas() - pd_df = pd_df.reset_index(drop=True) - assert df.equals(pd_df) - - -@pytest.mark.integration() -def test_sqllite_read_with_projection_pushdown(temp_sqllite_db) -> None: - df = daft.read_sql("SELECT * FROM iris", f"sqlite://{temp_sqllite_db}") - df = df.select(df["sepal_length"], df["variety"]) - - pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") - pd_df = pd_df[["sepal_length", "variety"]] - - df = df.to_pandas() - assert df.equals(pd_df) - - -@pytest.mark.integration() -def test_sqllite_read_with_all_pushdowns(temp_sqllite_db) -> None: - df = daft.read_sql("SELECT * FROM iris", f"sqlite://{temp_sqllite_db}") - df = df.where(df["sepal_length"] > 5.0) - df = df.where(df["sepal_width"] > 3.0) - df = df.limit(100) - df = df.select(df["sepal_length"]) - - pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") - pd_df = pd_df[pd_df["sepal_length"] > 5.0] - pd_df = pd_df[pd_df["sepal_width"] > 3.0] - pd_df = pd_df.head(100) - pd_df = pd_df[["sepal_length"]] - - df = df.to_pandas().sort_values("sepal_length", ascending=False).reset_index(drop=True) - pd_df = pd_df.sort_values("sepal_length", ascending=False).reset_index(drop=True) - assert df.equals(pd_df) - - -@pytest.mark.integration() -def test_sqllite_bad_url() -> None: - with pytest.raises(RuntimeError, match="Failed to execute sql"): - daft.read_sql("SELECT * FROM iris", "sqlite://") diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py new file mode 100644 index 0000000000..9ea6cda088 --- /dev/null +++ b/tests/integration/sql/test_sql.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import pytest + +import daft +from tests.conftest import assert_df_equals +from tests.integration.sql.conftest import ( + NUM_ROWS_PER_PARTITION, + NUM_TEST_ROWS, + TEST_TABLE_NAME, +) + + +@pytest.mark.integration() +def test_sql_create_dataframe_ok(test_db, generated_data) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + + assert_df_equals(df.to_pandas(), generated_data, sort_key="id") + + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2, 3]) +def test_sql_partitioned_read(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME} LIMIT 
{NUM_ROWS_PER_PARTITION * num_partitions}", test_db) + assert df.num_partitions() == num_partitions + df = df.collect() + assert len(df) == NUM_ROWS_PER_PARTITION * num_partitions + + # test with a number of rows that is not a multiple of NUM_ROWS_PER_PARTITION + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME} LIMIT {NUM_ROWS_PER_PARTITION * num_partitions + 1}", test_db) + assert df.num_partitions() == num_partitions + 1 + df = df.collect() + assert len(df) == NUM_ROWS_PER_PARTITION * num_partitions + 1 + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "column,operator,value,expected_length", + [ + ("id", ">", 100, 99), + ("id", "<=", 100, 101), + ("sepal_length", "<", 100.0, 100), + ("sepal_length", ">=", 100.0, 100), + ("variety", "=", "setosa", 50), + ("variety", "!=", "setosa", 100), + ("variety", "is_null", None, 50), + ("variety", "not_null", None, 150), + ("variety", "is_in", ["setosa", "versicolor"], 100), + ("id", "is_in", [1, 2, 3], 3), + ], +) +def test_sql_read_with_filter_pushdowns(test_db, column, operator, value, expected_length) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + + if operator == ">": + df = df.where(df[column] > value) + elif operator == "<": + df = df.where(df[column] < value) + elif operator == "=": + df = df.where(df[column] == value) + elif operator == "!=": + df = df.where(df[column] != value) + elif operator == ">=": + df = df.where(df[column] >= value) + elif operator == "<=": + df = df.where(df[column] <= value) + elif operator == "is_null": + df = df.where(df[column].is_null()) + elif operator == "not_null": + df = df.where(df[column].not_null()) + elif operator == "is_in": + df = df.where(df[column].is_in(value)) + + df = df.collect() + assert len(df) == expected_length + + +@pytest.mark.integration() +def test_sql_read_with_if_else_filter_pushdown(test_db) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + df = df.where(df["variety"] == (df["sepal_length"] >= 125).if_else("virginica", "setosa")) + + df = df.collect() + assert len(df) == 75 + + +@pytest.mark.integration() +def test_sql_read_with_all_pushdowns(test_db) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + df = df.where(df["sepal_length"] > 100) + df = df.select(df["sepal_length"], df["variety"]) + df = df.limit(50) + + df = df.collect() + assert df.column_names == ["sepal_length", "variety"] + assert len(df) == 50 + + +@pytest.mark.integration() +def test_sql_read_with_limit_pushdown(test_db) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + df = df.limit(100) + + df = df.collect() + assert len(df) == 100 + + +@pytest.mark.integration() +def test_sql_read_with_projection_pushdown(test_db) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + df = df.select(df["sepal_length"], df["variety"]) + + df = df.collect() + assert df.column_names == ["sepal_length", "variety"] + assert len(df) == NUM_TEST_ROWS + + +@pytest.mark.integration() +def test_sql_bad_url() -> None: + with pytest.raises(RuntimeError, match="Failed to execute sql"): + daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", "bad_url://") From 2d518647b71f65c9bcf2f01ce765d41d065da756 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 26 Feb 2024 09:35:39 -0800 Subject: [PATCH 15/30] cleanup --- .github/workflows/python-package.yml | 2 +- daft/iceberg/iceberg_scan.py | 1 - daft/utils.py | 3 -- src/daft-dsl/src/expr.rs | 1 - src/daft-micropartition/src/micropartition.rs | 2 +- 
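As a worked example of the expected_length values in the parametrized filter test (a sketch using the postgres URL from the conftest): the seeded table has id running from 0 to 199 and 50 rows per variety, so a filter such as id > 100 keeps ids 101 through 199.

import daft

df = daft.read_sql("SELECT * FROM example", "postgresql://username:password@localhost:5432/postgres")
# The planner may push this comparison into the generated WHERE clause via
# Expr::to_sql; predicates it cannot translate are instead applied to the
# MicroPartition after the rows are fetched.
df = df.where(df["id"] > 100)
df = df.collect()
assert len(df) == 99  # ids 101..199 out of 0..199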
src/daft-plan/src/planner.rs | 4 ++- src/daft-scan/src/lib.rs | 36 ++++--------------- src/daft-scan/src/python.rs | 2 +- src/daft-scan/src/scan_task_iters.rs | 8 +---- tests/integration/sql/conftest.py | 1 + tests/integration/sql/test_sql.py | 7 ++-- 11 files changed, 18 insertions(+), 49 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 45db70102d..2ea4c1aaf5 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -551,7 +551,7 @@ jobs: "type": "section", "text": { "type": "mrkdwn", - "text": ":rotating_light: [CI] Iceberg Integration Tests <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|workflow> *FAILED on main* :rotating_light:" + "text": ":rotating_light: [CI] SQL Integration Tests <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|workflow> *FAILED on main* :rotating_light:" } } ] diff --git a/daft/iceberg/iceberg_scan.py b/daft/iceberg/iceberg_scan.py index 54eb6f7da0..4b583bcce0 100644 --- a/daft/iceberg/iceberg_scan.py +++ b/daft/iceberg/iceberg_scan.py @@ -170,7 +170,6 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: continue rows_left -= record_count scan_tasks.append(st) - return iter(scan_tasks) def can_absorb_filter(self) -> bool: diff --git a/daft/utils.py b/daft/utils.py index 8580323f42..a8efc9bf2e 100644 --- a/daft/utils.py +++ b/daft/utils.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import pickle import random import statistics @@ -8,8 +7,6 @@ import pyarrow as pa -logger = logging.getLogger(__name__) - ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 352cb6f535..1e83a8fca9 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -623,7 +623,6 @@ impl Expr { let inner_sql = inner.to_sql()?; Some(format!("{} IS NOT NULL", inner_sql)) } - // TODO: Implement SQL translations for these expressions Expr::IfElse { if_true, if_false, diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 3a7a9bd502..b4b008cc9f 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -229,7 +229,7 @@ fn materialize_scan_task( } FileFormatConfig::Database(_) => { return Err(common_error::DaftError::TypeError( - "Native reads for Database file format not yet implemented".to_string(), + "Native reads for Database file format not implemented".to_string(), )) .context(DaftCoreComputeSnafu); } diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 5776998d39..9c7b667bd9 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -715,7 +715,9 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftRe input_physical.into(), ))) } - _ => unimplemented!(), + FileFormat::Database => Err(common_error::DaftError::ValueError( + "Database sink not yet implemented".to_string(), + )), } } } diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index d275179f3f..1f7023a2b8 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -135,46 +135,32 @@ pub enum DataFileSource { partition_spec: PartitionSpec, statistics: Option, }, - DatabaseDataFile { - path: String, - chunk_spec: Option, - size_bytes: Option, - metadata: Option, - partition_spec: Option, - statistics: Option, - }, } impl 
DataFileSource { pub fn get_path(&self) -> &str { match self { - Self::AnonymousDataFile { path, .. } - | Self::CatalogDataFile { path, .. } - | Self::DatabaseDataFile { path, .. } => path, + Self::AnonymousDataFile { path, .. } | Self::CatalogDataFile { path, .. } => path, } } pub fn get_chunk_spec(&self) -> Option<&ChunkSpec> { match self { Self::AnonymousDataFile { chunk_spec, .. } - | Self::CatalogDataFile { chunk_spec, .. } - | Self::DatabaseDataFile { chunk_spec, .. } => chunk_spec.as_ref(), + | Self::CatalogDataFile { chunk_spec, .. } => chunk_spec.as_ref(), } } pub fn get_size_bytes(&self) -> Option { match self { Self::AnonymousDataFile { size_bytes, .. } - | Self::CatalogDataFile { size_bytes, .. } - | Self::DatabaseDataFile { size_bytes, .. } => *size_bytes, + | Self::CatalogDataFile { size_bytes, .. } => *size_bytes, } } pub fn get_metadata(&self) -> Option<&TableMetadata> { match self { - Self::AnonymousDataFile { metadata, .. } | Self::DatabaseDataFile { metadata, .. } => { - metadata.as_ref() - } + Self::AnonymousDataFile { metadata, .. } => metadata.as_ref(), Self::CatalogDataFile { metadata, .. } => Some(metadata), } } @@ -182,15 +168,13 @@ impl DataFileSource { pub fn get_statistics(&self) -> Option<&TableStatistics> { match self { Self::AnonymousDataFile { statistics, .. } - | Self::CatalogDataFile { statistics, .. } - | Self::DatabaseDataFile { statistics, .. } => statistics.as_ref(), + | Self::CatalogDataFile { statistics, .. } => statistics.as_ref(), } } pub fn get_partition_spec(&self) -> Option<&PartitionSpec> { match self { - Self::AnonymousDataFile { partition_spec, .. } - | Self::DatabaseDataFile { partition_spec, .. } => partition_spec.as_ref(), + Self::AnonymousDataFile { partition_spec, .. } => partition_spec.as_ref(), Self::CatalogDataFile { partition_spec, .. } => Some(partition_spec), } } @@ -205,14 +189,6 @@ impl DataFileSource { metadata, partition_spec, statistics, - } - | Self::DatabaseDataFile { - path, - chunk_spec, - size_bytes, - metadata, - partition_spec, - statistics, } => { res.push(format!("Path = {}", path)); if let Some(chunk_spec) = chunk_spec { diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index 5edb252719..a7ea077573 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -327,7 +327,7 @@ pub mod pylib { storage_config: PyStorageConfig, pushdowns: Option, ) -> PyResult { - let data_source = DataFileSource::DatabaseDataFile { + let data_source = DataFileSource::AnonymousDataFile { path: url, chunk_spec: None, size_bytes: None, diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs index a627d89105..fe3d470a2f 100644 --- a/src/daft-scan/src/scan_task_iters.rs +++ b/src/daft-scan/src/scan_task_iters.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use common_error::{DaftError, DaftResult}; +use common_error::DaftResult; use daft_io::IOStatsContext; use daft_parquet::read::read_parquet_metadata; @@ -191,12 +191,6 @@ pub fn split_by_row_groups( curr_row_groups = Vec::new(); curr_size_bytes = 0; } - DataFileSource::DatabaseDataFile { .. 
} => { - return Err(DaftError::ValueError( - "Cannot split by row groups for database sources" - .to_string(), - )); - } }; new_tasks.push(Ok(ScanTask::new( diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index f0d5b8029a..a2864fac39 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -47,6 +47,7 @@ def generated_data() -> pd.DataFrame: def test_db(request: pytest.FixtureRequest, generated_data: pd.DataFrame) -> Generator[str, None, None]: db_url = request.param if db_url.startswith("sqlite"): + # No docker container for sqlite, so we need to create a temporary file with tempfile.NamedTemporaryFile(suffix=".db") as file: db_url += file.name setup_database(db_url, generated_data) diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index 9ea6cda088..84d814c9dd 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -97,12 +97,13 @@ def test_sql_read_with_all_pushdowns(test_db) -> None: @pytest.mark.integration() -def test_sql_read_with_limit_pushdown(test_db) -> None: +@pytest.mark.parametrize("limit", [0, 1, 10, 100, 200]) +def test_sql_read_with_limit_pushdown(test_db, limit) -> None: df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) - df = df.limit(100) + df = df.limit(limit) df = df.collect() - assert len(df) == 100 + assert len(df) == limit @pytest.mark.integration() From f53bbb2847b02705f3e603c65c5d42131afa25a7 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 28 Feb 2024 14:14:40 -0800 Subject: [PATCH 16/30] everything except limit 0 --- daft/context.py | 3 ++ daft/daft.pyi | 6 +++ daft/io/_sql.py | 19 +++++--- daft/logical/schema.py | 3 ++ daft/sql/sql_reader.py | 56 ++++++++++++---------- daft/sql/sql_scan.py | 69 +++++++++++++++++----------- src/common/daft-config/src/lib.rs | 2 + src/common/daft-config/src/python.rs | 9 ++++ src/daft-core/src/python/schema.rs | 4 ++ src/daft-dsl/src/expr.rs | 48 ++++++++++--------- src/daft-scan/src/lib.rs | 61 +++++++++++++++++++++--- src/daft-scan/src/python.rs | 10 ++-- src/daft-scan/src/scan_task_iters.rs | 2 +- tests/integration/sql/conftest.py | 26 ++++++----- tests/integration/sql/test_sql.py | 49 +++++++++++--------- 15 files changed, 243 insertions(+), 124 deletions(-) diff --git a/daft/context.py b/daft/context.py index 48ef44df94..0a8ea18f25 100644 --- a/daft/context.py +++ b/daft/context.py @@ -254,6 +254,7 @@ def set_execution_config( csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, + read_sql_partition_size_bytes: int | None = None, ) -> DaftContext: """Globally sets various configuration parameters which control various aspects of Daft execution. These configuration values are used when a Dataframe is executed (e.g. calls to `.write_*`, `.collect()` or `.show()`) @@ -283,6 +284,7 @@ def set_execution_config( csv_target_filesize: Target File Size when writing out CSV Files. Defaults to 512MB csv_inflation_factor: Inflation Factor of CSV files (In-Memory-Size / File-Size) ratio. Defaults to 0.5 shuffle_aggregation_default_partitions: Minimum number of partitions to create when performing aggregations. Defaults to 200, unless the number of input partitions is less than 200. + read_sql_partition_size_bytes: Target size of partition when reading from SQL databases. 
Defaults to 512MB """ # Replace values in the DaftExecutionConfig with user-specified overrides ctx = get_context() @@ -302,6 +304,7 @@ def set_execution_config( csv_target_filesize=csv_target_filesize, csv_inflation_factor=csv_inflation_factor, shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions, + read_sql_partition_size_bytes=read_sql_partition_size_bytes, ) ctx._daft_execution_config = new_daft_execution_config diff --git a/daft/daft.pyi b/daft/daft.pyi index bfc7bfd2b6..443cda3911 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -608,7 +608,9 @@ class ScanTask: url: str, file_format: FileFormatConfig, schema: PySchema, + num_rows: int, storage_config: StorageConfig, + size_bytes: int | None, pushdowns: Pushdowns | None, ) -> ScanTask: """ @@ -858,6 +860,7 @@ class PySchema: def names(self) -> list[str]: ... def union(self, other: PySchema) -> PySchema: ... def eq(self, other: PySchema) -> bool: ... + def estimate_row_size_bytes(self) -> float: ... @staticmethod def from_field_name_and_types(names_and_types: list[tuple[str, PyDataType]]) -> PySchema: ... @staticmethod @@ -1250,6 +1253,7 @@ class PyDaftExecutionConfig: csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, + read_sql_partition_size_bytes: int | None = None, ) -> PyDaftExecutionConfig: ... @property def scan_tasks_min_size_bytes(self) -> int: ... @@ -1275,6 +1279,8 @@ class PyDaftExecutionConfig: def csv_inflation_factor(self) -> float: ... @property def shuffle_aggregation_default_partitions(self) -> int: ... + @property + def read_sql_partition_size_bytes(self) -> int: ... class PyDaftPlanningConfig: def with_config_values( diff --git a/daft/io/_sql.py b/daft/io/_sql.py index a3fee03913..cf20f677e2 100644 --- a/daft/io/_sql.py +++ b/daft/io/_sql.py @@ -1,6 +1,8 @@ # isort: dont-add-import: from __future__ import annotations +from typing import Optional + from daft import context from daft.api_annotations import PublicAPI from daft.daft import PythonStorageConfig, ScanOperatorHandle, StorageConfig @@ -10,10 +12,7 @@ @PublicAPI -def read_sql( - sql: str, - url: str, -) -> DataFrame: +def read_sql(sql: str, url: str, num_partitions: Optional[int] = None) -> DataFrame: """Creates a DataFrame from a SQL query Example: @@ -22,15 +21,21 @@ def read_sql( Args: sql (str): SQL query to execute url (str): URL to the database + num_partitions (Optional[int]): Number of partitions to read the data into, + defaults to None, which will lets Daft determine the number of partitions. 
returns: - DataFrame: parsed DataFrame + DataFrame: Dataframe containing the results of the query """ io_config = context.get_context().daft_planning_config.default_io_config storage_config = StorageConfig.python(PythonStorageConfig(io_config)) - sql_operator = SQLScanOperator(sql, url, storage_config=storage_config) + sql_operator = SQLScanOperator(sql, url, storage_config=storage_config, num_partitions=num_partitions) handle = ScanOperatorHandle.from_python_scan_operator(sql_operator) builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle) - return DataFrame(builder) + + if num_partitions is not None and num_partitions > 1 and not sql_operator._limit_and_offset_supported: + return DataFrame(builder).into_partitions(num_partitions) + else: + return DataFrame(builder) diff --git a/daft/logical/schema.py b/daft/logical/schema.py index 2655885bb9..b9334acce8 100644 --- a/daft/logical/schema.py +++ b/daft/logical/schema.py @@ -116,6 +116,9 @@ def __len__(self) -> int: def column_names(self) -> list[str]: return list(self._schema.names()) + def estimate_row_size_bytes(self) -> float: + return self._schema.estimate_row_size_bytes() + def __iter__(self) -> Iterator[Field]: col_names = self.column_names() yield from (self[name] for name in col_names) diff --git a/daft/sql/sql_reader.py b/daft/sql/sql_reader.py index 6f11c0bb93..8bb8b9b3ad 100644 --- a/daft/sql/sql_reader.py +++ b/daft/sql/sql_reader.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from urllib.parse import urlparse import pyarrow as pa @@ -29,44 +30,50 @@ def __init__( self.projection = projection self.predicate = predicate - def get_num_rows(self) -> int: - sql = f"SELECT COUNT(*) FROM ({self.sql}) AS subquery" - pa_table = self._execute_sql_query(sql) - return pa_table.column(0)[0].as_py() - def read(self) -> pa.Table: sql = self._construct_sql_query() return self._execute_sql_query(sql) def _construct_sql_query(self) -> str: + clauses = ["SELECT"] if self.projection is not None: - columns = ", ".join(self.projection) + clauses.append(", ".join(self.projection)) else: - columns = "*" + clauses.append("*") - sql = f"SELECT {columns} FROM ({self.sql}) AS subquery" + clauses.append(f"FROM ({self.sql}) AS subquery") if self.predicate is not None: - sql += f" WHERE {self.predicate}" + clauses.append(f" WHERE {self.predicate}") if self.limit is not None and self.offset is not None: - if self.limit_before_offset: - sql += f" LIMIT {self.limit} OFFSET {self.offset}" + if self.limit_before_offset is True: + clauses.append(f" LIMIT {self.limit} OFFSET {self.offset}") else: - sql += f" OFFSET {self.offset} LIMIT {self.limit}" + clauses.append(f" OFFSET {self.offset} LIMIT {self.limit}") elif self.limit is not None: - sql += f" LIMIT {self.limit}" + clauses.append(f" LIMIT {self.limit}") elif self.offset is not None: - sql += f" OFFSET {self.offset}" + clauses.append(f" OFFSET {self.offset}") - return sql + return "\n".join(clauses) def _execute_sql_query(self, sql: str) -> pa.Table: # Supported DBs extracted from here https://github.com/sfu-db/connector-x/tree/7b3147436b7e20b96691348143d605e2249d6119?tab=readme-ov-file#sources - supported_dbs = {"postgres", "mysql", "mssql", "oracle", "bigquery", "sqlite", "clickhouse", "redshift"} - - # Extract the database type from the URL - db_type = self.url.split("://")[0] + supported_dbs = { + "postgres", + "postgresql", + "mysql", + "mssql", + "oracle", + "bigquery", + "sqlite", + "clickhouse", + "redshift", + } + + db_type = urlparse(self.url).scheme + db_type = 
db_type.strip().lower() # Check if the database type is supported if db_type in supported_dbs: @@ -85,15 +92,18 @@ def _execute_sql_query_with_connectorx(self, sql: str) -> pa.Table: raise RuntimeError(f"Failed to execute sql: {sql} with url: {self.url}, error: {e}") from e def _execute_sql_query_with_sqlalchemy(self, sql: str) -> pa.Table: - import pandas as pd from sqlalchemy import create_engine, text logger.info(f"Using sqlalchemy to execute sql: {sql}") try: with create_engine(self.url).connect() as connection: result = connection.execute(text(sql)) - df = pd.DataFrame(result.fetchall(), columns=result.keys()) - table = pa.Table.from_pandas(df) - return table + cursor = result.cursor + + columns = [column_description[0] for column_description in cursor.description] + rows = result.fetchall() + pydict = {column: [row[i] for row in rows] for i, column in enumerate(columns)} + + return pa.Table.from_pydict(pydict) except Exception as e: raise RuntimeError(f"Failed to execute sql: {sql} with url: {self.url}, error: {e}") from e diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index e4089401ee..48b3e74d1a 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -3,6 +3,7 @@ import math from collections.abc import Iterator +from daft.context import get_context from daft.daft import ( DatabaseSourceConfig, FileFormatConfig, @@ -16,43 +17,46 @@ class SQLScanOperator(ScanOperator): - MIN_ROWS_PER_SCAN_TASK = 50 # TODO: Would be better to have a memory limit instead of a row limit - def __init__( self, sql: str, url: str, storage_config: StorageConfig, + num_partitions: int | None = None, ) -> None: super().__init__() self.sql = sql self.url = url self.storage_config = storage_config - self._limit_and_offset_supported, self._limit_before_offset = self._check_limit_and_offset_supported() - self._schema = self._get_schema() - - def _check_limit_and_offset_supported(self) -> tuple[bool, bool]: - try: - # Try to read 1 row with limit before offset - SQLReader(self.sql, self.url, limit=1, offset=0, limit_before_offset=True).read() - return (True, True) - except Exception: + self._num_partitions = num_partitions + self._initialize_schema_and_features() + + def _initialize_schema_and_features(self) -> None: + self._schema, self._limit_and_offset_supported, self._limit_before_offset = self._attempt_schema_read() + + def _attempt_schema_read(self) -> tuple[Schema, bool, bool]: + for limit_before_offset in [True, False]: try: - # Try to read 1 row with limit after offset - SQLReader(self.sql, self.url, offset=0, limit=1, limit_before_offset=False).read() - return (True, False) + pa_table = SQLReader( + self.sql, self.url, limit=1, offset=0, limit_before_offset=limit_before_offset + ).read() + schema = Schema.from_pyarrow_schema(pa_table.schema) + return schema, True, limit_before_offset except Exception: - # If both fail, then limit and offset are not supported - return (False, False) - - def _get_schema(self) -> Schema: - if self._limit_and_offset_supported: - pa_table = SQLReader( - self.sql, self.url, limit=1, offset=0, limit_before_offset=self._limit_before_offset - ).read() - else: - pa_table = SQLReader(self.sql, self.url).read() - return Schema.from_pyarrow_schema(pa_table.schema) + continue + + # If both attempts fail, read without limit and offset + pa_table = SQLReader(self.sql, self.url).read() + schema = Schema.from_pyarrow_schema(pa_table.schema) + return schema, False, False + + def _get_num_rows(self) -> int: + pa_table = SQLReader( + self.sql, + self.url, + 
projection=["COUNT(*)"], + ).read() + return pa_table.column(0)[0].as_py() def schema(self) -> Schema: return self._schema @@ -70,6 +74,10 @@ def multiline_display(self) -> list[str]: ] def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: + total_rows = self._get_num_rows() + estimate_row_size_bytes = math.ceil(self.schema().estimate_row_size_bytes()) + total_size = total_rows * estimate_row_size_bytes + if not self._limit_and_offset_supported: # If limit and offset are not supported, then we can't parallelize the scan, so we just return a single scan task file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) @@ -79,14 +87,19 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: url=self.url, file_format=file_format_config, schema=self._schema._schema, + num_rows=total_rows, storage_config=self.storage_config, + size_bytes=total_size, pushdowns=pushdowns, ) ] ) - total_rows = SQLReader(self.sql, self.url).get_num_rows() - num_scan_tasks = math.ceil(total_rows / self.MIN_ROWS_PER_SCAN_TASK) + num_scan_tasks = ( + math.ceil(total_size / get_context().daft_execution_config.read_sql_partition_size_bytes) + if self._num_partitions is None + else self._num_partitions + ) num_rows_per_scan_task = total_rows // num_scan_tasks scan_tasks = [] @@ -103,7 +116,9 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: url=self.url, file_format=file_format_config, schema=self._schema._schema, + num_rows=limit, storage_config=self.storage_config, + size_bytes=limit * estimate_row_size_bytes, pushdowns=pushdowns, ) ) diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index a3dd531d3a..ac53ad30a5 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -35,6 +35,7 @@ pub struct DaftExecutionConfig { pub csv_target_filesize: usize, pub csv_inflation_factor: f64, pub shuffle_aggregation_default_partitions: usize, + pub read_sql_partition_size_bytes: usize, } impl Default for DaftExecutionConfig { @@ -53,6 +54,7 @@ impl Default for DaftExecutionConfig { csv_target_filesize: 512 * 1024 * 1024, // 512MB csv_inflation_factor: 0.5, shuffle_aggregation_default_partitions: 200, + read_sql_partition_size_bytes: 512 * 1024 * 1024, // 512MB } } } diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 8d2124c7ba..9071c878d1 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -90,6 +90,7 @@ impl PyDaftExecutionConfig { csv_target_filesize: Option, csv_inflation_factor: Option, shuffle_aggregation_default_partitions: Option, + read_sql_partition_size_bytes: Option, ) -> PyResult { let mut config = self.config.as_ref().clone(); @@ -136,6 +137,9 @@ impl PyDaftExecutionConfig { { config.shuffle_aggregation_default_partitions = shuffle_aggregation_default_partitions; } + if let Some(read_sql_partition_size_bytes) = read_sql_partition_size_bytes { + config.read_sql_partition_size_bytes = read_sql_partition_size_bytes; + } Ok(PyDaftExecutionConfig { config: Arc::new(config), @@ -202,6 +206,11 @@ impl PyDaftExecutionConfig { Ok(self.config.shuffle_aggregation_default_partitions) } + #[getter] + fn get_read_sql_partition_size_bytes(&self) -> PyResult { + Ok(self.config.read_sql_partition_size_bytes) + } + fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (Vec,))> { let bin_data = bincode::serialize(self.config.as_ref()) .expect("DaftExecutionConfig should be serializable to bytes"); diff 
--git a/src/daft-core/src/python/schema.rs b/src/daft-core/src/python/schema.rs index 975aeffc5c..74c9dc1e81 100644 --- a/src/daft-core/src/python/schema.rs +++ b/src/daft-core/src/python/schema.rs @@ -37,6 +37,10 @@ impl PySchema { Ok(self.schema.fields.eq(&other.schema.fields)) } + pub fn estimate_row_size_bytes(&self) -> PyResult { + Ok(self.schema.estimate_row_size_bytes()) + } + #[staticmethod] pub fn from_field_name_and_types( names_and_types: Vec<(String, PyDataType)>, diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 1e83a8fca9..69bc8acbb6 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -570,6 +570,10 @@ impl Expr { } pub fn to_sql(&self) -> Option { + self.to_sql_inner() + } + + fn to_sql_inner(&self) -> Option { match self { Expr::Column(name) => Some(name.to_string()), Expr::Literal(lit) => match lit { @@ -591,13 +595,10 @@ impl Expr { }, _ => Some(lit.to_string().replace('\"', "'")), }, - Expr::Alias(inner, ..) => { - let inner_sql = inner.to_sql()?; - Some(inner_sql) - } + Expr::Alias(inner, ..) => inner.to_sql_inner(), Expr::BinaryOp { op, left, right } => { - let left_sql = left.to_sql()?; - let right_sql = right.to_sql()?; + let left_sql = left.to_sql_inner()?; + let right_sql = right.to_sql_inner()?; let op = match op { Operator::Eq => "=".to_string(), Operator::NotEq => "!=".to_string(), @@ -609,36 +610,33 @@ impl Expr { Operator::Or => "OR".to_string(), _ => return None, }; - Some(format!("{} {} {}", left_sql, op, right_sql)) - } - Expr::Not(inner) => { - let inner_sql = inner.to_sql()?; - Some(format!("NOT {}", inner_sql)) - } - Expr::IsNull(inner) => { - let inner_sql = inner.to_sql()?; - Some(format!("{} IS NULL", inner_sql)) - } - Expr::NotNull(inner) => { - let inner_sql = inner.to_sql()?; - Some(format!("{} IS NOT NULL", inner_sql)) - } + Some(format!("({}) {} ({})", left_sql, op, right_sql)) + } + Expr::Not(inner) => inner + .to_sql_inner() + .map(|inner_sql| format!("NOT ({})", inner_sql)), + Expr::IsNull(inner) => inner + .to_sql_inner() + .map(|inner_sql| format!("({}) IS NULL", inner_sql)), + Expr::NotNull(inner) => inner + .to_sql_inner() + .map(|inner_sql| format!("({}) IS NOT NULL", inner_sql)), Expr::IfElse { if_true, if_false, predicate, } => { - let if_true_sql = if_true.to_sql()?; - let if_false_sql = if_false.to_sql()?; - let predicate_sql = predicate.to_sql()?; + let if_true_sql = if_true.to_sql_inner()?; + let if_false_sql = if_false.to_sql_inner()?; + let predicate_sql = predicate.to_sql_inner()?; Some(format!( "CASE WHEN {} THEN {} ELSE {} END", predicate_sql, if_true_sql, if_false_sql )) } Expr::IsIn(inner, items) => { - let inner_sql = inner.to_sql()?; - let items_sql = items.to_sql()?; + let inner_sql = inner.to_sql_inner()?; + let items_sql = items.to_sql_inner()?; Some(format!("{} IN ({})", inner_sql, items_sql)) } // TODO: Implement SQL translations for these expressions if possible diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index 1f7023a2b8..d730528dc3 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -135,46 +135,62 @@ pub enum DataFileSource { partition_spec: PartitionSpec, statistics: Option, }, + DatabaseDataSource { + path: String, + chunk_spec: Option, + size_bytes: Option, + metadata: TableMetadata, + partition_spec: Option, + statistics: Option, + }, } impl DataFileSource { pub fn get_path(&self) -> &str { match self { - Self::AnonymousDataFile { path, .. } | Self::CatalogDataFile { path, .. } => path, + Self::AnonymousDataFile { path, .. 
} + | Self::CatalogDataFile { path, .. } + | Self::DatabaseDataSource { path, .. } => path, } } pub fn get_chunk_spec(&self) -> Option<&ChunkSpec> { match self { Self::AnonymousDataFile { chunk_spec, .. } - | Self::CatalogDataFile { chunk_spec, .. } => chunk_spec.as_ref(), + | Self::CatalogDataFile { chunk_spec, .. } + | Self::DatabaseDataSource { chunk_spec, .. } => chunk_spec.as_ref(), } } pub fn get_size_bytes(&self) -> Option { match self { Self::AnonymousDataFile { size_bytes, .. } - | Self::CatalogDataFile { size_bytes, .. } => *size_bytes, + | Self::CatalogDataFile { size_bytes, .. } + | Self::DatabaseDataSource { size_bytes, .. } => *size_bytes, } } pub fn get_metadata(&self) -> Option<&TableMetadata> { match self { Self::AnonymousDataFile { metadata, .. } => metadata.as_ref(), - Self::CatalogDataFile { metadata, .. } => Some(metadata), + Self::CatalogDataFile { metadata, .. } | Self::DatabaseDataSource { metadata, .. } => { + Some(metadata) + } } } pub fn get_statistics(&self) -> Option<&TableStatistics> { match self { Self::AnonymousDataFile { statistics, .. } - | Self::CatalogDataFile { statistics, .. } => statistics.as_ref(), + | Self::CatalogDataFile { statistics, .. } + | Self::DatabaseDataSource { statistics, .. } => statistics.as_ref(), } } pub fn get_partition_spec(&self) -> Option<&PartitionSpec> { match self { - Self::AnonymousDataFile { partition_spec, .. } => partition_spec.as_ref(), + Self::AnonymousDataFile { partition_spec, .. } + | Self::DatabaseDataSource { partition_spec, .. } => partition_spec.as_ref(), Self::CatalogDataFile { partition_spec, .. } => Some(partition_spec), } } @@ -246,6 +262,38 @@ impl DataFileSource { res.push(format!("Statistics = {}", statistics)); } } + Self::DatabaseDataSource { + path, + chunk_spec, + size_bytes, + metadata, + partition_spec, + statistics, + } => { + res.push(format!("Path = {}", path)); + if let Some(chunk_spec) = chunk_spec { + res.push(format!( + "Chunk spec = {{ {} }}", + chunk_spec.multiline_display().join(", ") + )); + } + if let Some(size_bytes) = size_bytes { + res.push(format!("Size bytes = {}", size_bytes)); + } + res.push(format!( + "Metadata = {}", + metadata.multiline_display().join(", ") + )); + if let Some(partition_spec) = partition_spec { + res.push(format!( + "Partition spec = {}", + partition_spec.multiline_display().join(", ") + )); + } + if let Some(statistics) = statistics { + res.push(format!("Statistics = {}", statistics)); + } + } } res } @@ -428,6 +476,7 @@ impl ScanTask { FileFormatConfig::Csv(_) | FileFormatConfig::Json(_) => { config.csv_inflation_factor } + FileFormatConfig::Database(_) => 0.0, }; // estimate number of rows from read schema diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index a7ea077573..548d671b5d 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -324,14 +324,18 @@ pub mod pylib { url: String, file_format: PyFileFormatConfig, schema: PySchema, + num_rows: i64, storage_config: PyStorageConfig, + size_bytes: Option, pushdowns: Option, ) -> PyResult { - let data_source = DataFileSource::AnonymousDataFile { + let data_source = DataFileSource::DatabaseDataSource { path: url, chunk_spec: None, - size_bytes: None, - metadata: None, + size_bytes, + metadata: TableMetadata { + length: num_rows as usize, + }, partition_spec: None, statistics: None, }; diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs index fe3d470a2f..b8662219a9 100644 --- a/src/daft-scan/src/scan_task_iters.rs +++ 
b/src/daft-scan/src/scan_task_iters.rs @@ -184,7 +184,7 @@ pub fn split_by_row_groups( chunk_spec, size_bytes, .. - } => { + } | DataFileSource::DatabaseDataSource { chunk_spec, size_bytes, .. } => { *chunk_spec = Some(ChunkSpec::Parquet(curr_row_groups)); *size_bytes = Some(curr_size_bytes as u64); diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index a2864fac39..e9a6537490 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -25,20 +25,24 @@ "mysql+pymysql://username:password@localhost:3306/mysql", "sqlite:///", ] -NUM_TEST_ROWS = 200 -NUM_ROWS_PER_PARTITION = 50 TEST_TABLE_NAME = "example" -@pytest.fixture(scope="session") -def generated_data() -> pd.DataFrame: +@pytest.fixture(scope="session", params=[{"num_rows": 200}]) +def generated_data(request: pytest.FixtureRequest) -> pd.DataFrame: + num_rows = request.param["num_rows"] + num_rows_per_variety = num_rows // 4 + variety_arr = ( + ["setosa"] * num_rows_per_variety + ["versicolor"] * num_rows_per_variety + ["virginica"] * num_rows_per_variety + ) + data = { - "id": np.arange(NUM_TEST_ROWS), - "sepal_length": np.arange(NUM_TEST_ROWS, dtype=float), - "sepal_width": np.arange(NUM_TEST_ROWS, dtype=float), - "petal_length": np.arange(NUM_TEST_ROWS, dtype=float), - "petal_width": np.arange(NUM_TEST_ROWS, dtype=float), - "variety": ["setosa"] * 50 + ["versicolor"] * 50 + ["virginica"] * 50 + [None] * 50, + "id": np.arange(num_rows), + "sepal_length": np.arange(num_rows, dtype=float), + "sepal_width": np.arange(num_rows, dtype=float), + "petal_length": np.arange(num_rows, dtype=float), + "petal_width": np.arange(num_rows, dtype=float), + "variety": variety_arr + [None] * (num_rows - len(variety_arr)), } return pd.DataFrame(data) @@ -65,7 +69,7 @@ def setup_database(db_url: str, data: pd.DataFrame) -> None: # Ensure the table is created and populated with engine.connect() as conn: result = conn.execute(text(f"SELECT COUNT(*) FROM {TEST_TABLE_NAME}")).fetchone()[0] - assert result == NUM_TEST_ROWS + assert result == len(data) def create_and_populate(engine: Engine, data: pd.DataFrame) -> None: diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index 84d814c9dd..0fe41738d3 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -1,14 +1,13 @@ from __future__ import annotations +import math + import pytest import daft +from daft.context import set_execution_config from tests.conftest import assert_df_equals -from tests.integration.sql.conftest import ( - NUM_ROWS_PER_PARTITION, - NUM_TEST_ROWS, - TEST_TABLE_NAME, -) +from tests.integration.sql.conftest import TEST_TABLE_NAME @pytest.mark.integration() @@ -19,18 +18,26 @@ def test_sql_create_dataframe_ok(test_db, generated_data) -> None: @pytest.mark.integration() -@pytest.mark.parametrize("num_partitions", [1, 2, 3]) -def test_sql_partitioned_read(test_db, num_partitions) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME} LIMIT {NUM_ROWS_PER_PARTITION * num_partitions}", test_db) +@pytest.mark.parametrize("num_partitions", [2, 3, 4]) +def test_sql_partitioned_read(test_db, num_partitions, generated_data) -> None: + row_size_bytes = daft.from_pandas(generated_data).schema().estimate_row_size_bytes() + num_rows_per_partition = len(generated_data) // num_partitions + limit = num_rows_per_partition * num_partitions + set_execution_config(read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition)) + + df = daft.read_sql(f"SELECT 
* FROM {TEST_TABLE_NAME} LIMIT {limit}", test_db) assert df.num_partitions() == num_partitions df = df.collect() - assert len(df) == NUM_ROWS_PER_PARTITION * num_partitions + assert len(df) == limit - # test with a number of rows that is not a multiple of NUM_ROWS_PER_PARTITION - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME} LIMIT {NUM_ROWS_PER_PARTITION * num_partitions + 1}", test_db) - assert df.num_partitions() == num_partitions + 1 + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [2, 3, 4]) +def test_sql_partitioned_read_with_custom_num_partitions(test_db, num_partitions, generated_data) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, num_partitions=num_partitions) + assert df.num_partitions() == num_partitions df = df.collect() - assert len(df) == NUM_ROWS_PER_PARTITION * num_partitions + 1 + assert len(df) == len(generated_data) @pytest.mark.integration() @@ -76,24 +83,24 @@ def test_sql_read_with_filter_pushdowns(test_db, column, operator, value, expect @pytest.mark.integration() -def test_sql_read_with_if_else_filter_pushdown(test_db) -> None: +def test_sql_read_with_if_else_filter_pushdown(test_db, generated_data) -> None: df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) - df = df.where(df["variety"] == (df["sepal_length"] >= 125).if_else("virginica", "setosa")) + df = df.where(df["variety"] == (df["id"] == -1).if_else("virginica", "setosa")) df = df.collect() - assert len(df) == 75 + assert len(df) == len(generated_data) // 4 @pytest.mark.integration() def test_sql_read_with_all_pushdowns(test_db) -> None: df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) - df = df.where(df["sepal_length"] > 100) + df = df.where(~(df["sepal_length"] > 1)) df = df.select(df["sepal_length"], df["variety"]) - df = df.limit(50) + df = df.limit(1) df = df.collect() assert df.column_names == ["sepal_length", "variety"] - assert len(df) == 50 + assert len(df) == 1 @pytest.mark.integration() @@ -107,13 +114,13 @@ def test_sql_read_with_limit_pushdown(test_db, limit) -> None: @pytest.mark.integration() -def test_sql_read_with_projection_pushdown(test_db) -> None: +def test_sql_read_with_projection_pushdown(test_db, generated_data) -> None: df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) df = df.select(df["sepal_length"], df["variety"]) df = df.collect() assert df.column_names == ["sepal_length", "variety"] - assert len(df) == NUM_TEST_ROWS + assert len(df) == len(generated_data) @pytest.mark.integration() From eab61b0a65343f17761548d9495dd76571dacdc6 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 29 Feb 2024 10:23:10 -0800 Subject: [PATCH 17/30] fix math --- daft/sql/sql_scan.py | 13 ++++++++----- tests/integration/sql/test_sql.py | 7 +++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 48b3e74d1a..3cb60f80c3 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -75,7 +75,7 @@ def multiline_display(self) -> list[str]: def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: total_rows = self._get_num_rows() - estimate_row_size_bytes = math.ceil(self.schema().estimate_row_size_bytes()) + estimate_row_size_bytes = self.schema().estimate_row_size_bytes() total_size = total_rows * estimate_row_size_bytes if not self._limit_and_offset_supported: @@ -89,7 +89,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: schema=self._schema._schema, num_rows=total_rows, storage_config=self.storage_config, 
- size_bytes=total_size, + size_bytes=math.ceil(total_size), pushdowns=pushdowns, ) ] @@ -101,11 +101,14 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: else self._num_partitions ) num_rows_per_scan_task = total_rows // num_scan_tasks + num_scan_tasks_with_extra_row = total_rows % num_scan_tasks scan_tasks = [] offset = 0 - for _ in range(num_scan_tasks): - limit = max(num_rows_per_scan_task, total_rows - offset) + for i in range(num_scan_tasks): + limit = num_rows_per_scan_task + if i < num_scan_tasks_with_extra_row: + limit += 1 file_format_config = FileFormatConfig.from_database_config( DatabaseSourceConfig( self.sql, limit=limit, offset=offset, limit_before_offset=self._limit_before_offset @@ -118,7 +121,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: schema=self._schema._schema, num_rows=limit, storage_config=self.storage_config, - size_bytes=limit * estimate_row_size_bytes, + size_bytes=math.ceil(limit * estimate_row_size_bytes), pushdowns=pushdowns, ) ) diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index 0fe41738d3..f1c130a8c6 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -21,14 +21,13 @@ def test_sql_create_dataframe_ok(test_db, generated_data) -> None: @pytest.mark.parametrize("num_partitions", [2, 3, 4]) def test_sql_partitioned_read(test_db, num_partitions, generated_data) -> None: row_size_bytes = daft.from_pandas(generated_data).schema().estimate_row_size_bytes() - num_rows_per_partition = len(generated_data) // num_partitions - limit = num_rows_per_partition * num_partitions + num_rows_per_partition = len(generated_data) / num_partitions set_execution_config(read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition)) - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME} LIMIT {limit}", test_db) + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) assert df.num_partitions() == num_partitions df = df.collect() - assert len(df) == limit + assert len(df) == len(generated_data) @pytest.mark.integration() From 7e86e43d4d4606b6f9c76fe0ff4c3815998ffeff Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 29 Feb 2024 16:28:07 -0800 Subject: [PATCH 18/30] to_sql_inner --- daft/io/_sql.py | 2 +- daft/sql/sql_reader.py | 5 +- daft/sql/sql_scan.py | 3 + src/daft-dsl/src/expr.rs | 159 +++++++++++++++++------------- tests/integration/sql/test_sql.py | 3 +- 5 files changed, 95 insertions(+), 77 deletions(-) diff --git a/daft/io/_sql.py b/daft/io/_sql.py index cf20f677e2..84c74202e8 100644 --- a/daft/io/_sql.py +++ b/daft/io/_sql.py @@ -35,7 +35,7 @@ def read_sql(sql: str, url: str, num_partitions: Optional[int] = None) -> DataFr handle = ScanOperatorHandle.from_python_scan_operator(sql_operator) builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle) - if num_partitions is not None and num_partitions > 1 and not sql_operator._limit_and_offset_supported: + if num_partitions is not None and num_partitions > 1 and not sql_operator.can_partition_read(): return DataFrame(builder).into_partitions(num_partitions) else: return DataFrame(builder) diff --git a/daft/sql/sql_reader.py b/daft/sql/sql_reader.py index 8bb8b9b3ad..08e376fb9d 100644 --- a/daft/sql/sql_reader.py +++ b/daft/sql/sql_reader.py @@ -98,11 +98,8 @@ def _execute_sql_query_with_sqlalchemy(self, sql: str) -> pa.Table: try: with create_engine(self.url).connect() as connection: result = connection.execute(text(sql)) - cursor = result.cursor - - columns = 
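
To make the corrected limit/offset math above concrete, here is a small sketch (names follow the patch, but this helper is illustrative): the `total_rows % num_scan_tasks` remainder rows are spread over the first tasks so every row is read exactly once:

    def plan_limits_and_offsets(total_rows: int, num_scan_tasks: int) -> list[tuple[int, int]]:
        base = total_rows // num_scan_tasks
        extra = total_rows % num_scan_tasks
        plans = []
        offset = 0
        for i in range(num_scan_tasks):
            # The first `extra` tasks each take one additional row.
            limit = base + 1 if i < extra else base
            plans.append((limit, offset))
            offset += limit
        return plans

    # 10 rows over 3 tasks -> [(4, 0), (3, 4), (3, 7)], covering all 10 rows with no overlap.
    assert plan_limits_and_offsets(10, 3) == [(4, 0), (3, 4), (3, 7)]
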
[column_description[0] for column_description in cursor.description] rows = result.fetchall() - pydict = {column: [row[i] for row in rows] for i, column in enumerate(columns)} + pydict = {column_name: [row[i] for row in rows] for i, column_name in enumerate(result.keys())} return pa.Table.from_pydict(pydict) except Exception as e: diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 3cb60f80c3..f290245423 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -137,3 +137,6 @@ def can_absorb_limit(self) -> bool: def can_absorb_select(self) -> bool: return False + + def can_partition_read(self) -> bool: + return self._limit_and_offset_supported diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 69bc8acbb6..cbabf61ca8 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -17,6 +17,7 @@ use common_error::{DaftError, DaftResult}; use serde::{Deserialize, Serialize}; use std::{ fmt::{Debug, Display, Formatter, Result}, + io::{self, Write}, sync::Arc, }; @@ -570,80 +571,96 @@ impl Expr { } pub fn to_sql(&self) -> Option { - self.to_sql_inner() - } - - fn to_sql_inner(&self) -> Option { - match self { - Expr::Column(name) => Some(name.to_string()), - Expr::Literal(lit) => match lit { - lit::LiteralValue::Series(series) => match series.data_type() { - DataType::Utf8 => { - let s = lit.to_string(); - let trimmed = s.trim_matches(|c| c == '[' || c == ']').to_string(); - let quoted = trimmed - .split(", ") - .map(|s| format!("'{}'", s)) - .collect::>(); - Some(quoted.join(", ")) - } - _ => Some( - lit.to_string() - .trim_matches(|c| c == '[' || c == ']') - .to_string(), - ), + fn to_sql_inner(expr: &Expr, buffer: &mut W) -> io::Result<()> { + match expr { + Expr::Column(name) => write!(buffer, "{}", name), + Expr::Literal(lit) => match lit { + lit::LiteralValue::Series(series) => match series.data_type() { + DataType::Utf8 => { + let trimmed_and_quoted = lit + .to_string() + .trim_matches(|c| c == '[' || c == ']') + .split(", ") + .map(|s| format!("'{}'", s)) + .collect::>(); + write!(buffer, "{}", trimmed_and_quoted.join(", ")) + } + _ => write!( + buffer, + "{}", + lit.to_string().trim_matches(|c| c == '[' || c == ']') + ), + }, + _ => write!(buffer, "{}", lit.to_string().replace('\"', "'")), }, - _ => Some(lit.to_string().replace('\"', "'")), - }, - Expr::Alias(inner, ..) 
=> inner.to_sql_inner(), - Expr::BinaryOp { op, left, right } => { - let left_sql = left.to_sql_inner()?; - let right_sql = right.to_sql_inner()?; - let op = match op { - Operator::Eq => "=".to_string(), - Operator::NotEq => "!=".to_string(), - Operator::Lt => "<".to_string(), - Operator::LtEq => "<=".to_string(), - Operator::Gt => ">".to_string(), - Operator::GtEq => ">=".to_string(), - Operator::And => "AND".to_string(), - Operator::Or => "OR".to_string(), - _ => return None, - }; - Some(format!("({}) {} ({})", left_sql, op, right_sql)) - } - Expr::Not(inner) => inner - .to_sql_inner() - .map(|inner_sql| format!("NOT ({})", inner_sql)), - Expr::IsNull(inner) => inner - .to_sql_inner() - .map(|inner_sql| format!("({}) IS NULL", inner_sql)), - Expr::NotNull(inner) => inner - .to_sql_inner() - .map(|inner_sql| format!("({}) IS NOT NULL", inner_sql)), - Expr::IfElse { - if_true, - if_false, - predicate, - } => { - let if_true_sql = if_true.to_sql_inner()?; - let if_false_sql = if_false.to_sql_inner()?; - let predicate_sql = predicate.to_sql_inner()?; - Some(format!( - "CASE WHEN {} THEN {} ELSE {} END", - predicate_sql, if_true_sql, if_false_sql - )) - } - Expr::IsIn(inner, items) => { - let inner_sql = inner.to_sql_inner()?; - let items_sql = items.to_sql_inner()?; - Some(format!("{} IN ({})", inner_sql, items_sql)) + Expr::Alias(inner, ..) => to_sql_inner(inner, buffer), + Expr::BinaryOp { op, left, right } => { + to_sql_inner(left, buffer)?; + let op = match op { + Operator::Eq => "=".to_string(), + Operator::NotEq => "!=".to_string(), + Operator::Lt => "<".to_string(), + Operator::LtEq => "<=".to_string(), + Operator::Gt => ">".to_string(), + Operator::GtEq => ">=".to_string(), + Operator::And => "AND".to_string(), + Operator::Or => "OR".to_string(), + _ => { + return Err(io::Error::new( + io::ErrorKind::Other, + "Unsupported operator for SQL translation", + )) + } + }; + write!(buffer, " {} ", op)?; + to_sql_inner(right, buffer) + } + Expr::Not(inner) => { + write!(buffer, "NOT (")?; + to_sql_inner(inner, buffer)?; + write!(buffer, ")") + } + Expr::IsNull(inner) => { + write!(buffer, "(")?; + to_sql_inner(inner, buffer)?; + write!(buffer, ") IS NULL") + } + Expr::NotNull(inner) => { + write!(buffer, "(")?; + to_sql_inner(inner, buffer)?; + write!(buffer, ") IS NOT NULL") + } + Expr::IfElse { + if_true, + if_false, + predicate, + } => { + write!(buffer, "CASE WHEN ")?; + to_sql_inner(predicate, buffer)?; + write!(buffer, " THEN ")?; + to_sql_inner(if_true, buffer)?; + write!(buffer, " ELSE ")?; + to_sql_inner(if_false, buffer)?; + write!(buffer, " END") + } + Expr::IsIn(inner, items) => { + to_sql_inner(inner, buffer)?; + write!(buffer, " IN (")?; + to_sql_inner(items, buffer)?; + write!(buffer, ")") + } + // TODO: Implement SQL translations for these expressions if possible + Expr::Agg(..) | Expr::Cast(..) | Expr::Function { .. } => Err(io::Error::new( + io::ErrorKind::Other, + "Unsupported expression for SQL translation", + )), } - // TODO: Implement SQL translations for these expressions if possible - Expr::Agg(..) => None, - Expr::Cast(..) => None, - Expr::Function { .. 
} => None, } + + let mut buffer = Vec::new(); + to_sql_inner(self, &mut buffer) + .ok() + .and_then(|_| String::from_utf8(buffer).ok()) } } diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index f1c130a8c6..7ac9f3b8c6 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -93,7 +93,8 @@ def test_sql_read_with_if_else_filter_pushdown(test_db, generated_data) -> None: @pytest.mark.integration() def test_sql_read_with_all_pushdowns(test_db) -> None: df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) - df = df.where(~(df["sepal_length"] > 1)) + df = df.where(~(df["sepal_length"] < 1)) + df = df.where(df["variety"].is_in(["setosa", "versicolor"])) df = df.select(df["sepal_length"], df["variety"]) df = df.limit(1) From 2f48c55414e094a343cea30847f76690ee242fd9 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 1 Mar 2024 10:11:26 -0800 Subject: [PATCH 19/30] rename to apply_limit_before_offset --- daft/daft.pyi | 8 ++++++-- daft/runners/partitioning.py | 6 +++++- daft/sql/sql_reader.py | 10 +++++----- daft/sql/sql_scan.py | 10 +++++----- daft/table/table_io.py | 2 +- src/daft-micropartition/src/micropartition.rs | 4 ++-- src/daft-micropartition/src/python.rs | 4 ++-- src/daft-scan/src/file_format.rs | 17 ++++++++++------- 8 files changed, 36 insertions(+), 25 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index 443cda3911..25beb22673 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -225,10 +225,14 @@ class DatabaseSourceConfig: sql: str limit: int | None offset: int | None - limit_before_offset: bool | None + apply_limit_before_offset: bool | None def __init__( - self, sql: str, limit: int | None = None, offset: int | None = None, limit_before_offset: bool | None = None + self, + sql: str, + limit: int | None = None, + offset: int | None = None, + apply_limit_before_offset: bool | None = None, ): ... 
class FileFormatConfig: diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index fc41d59354..ebae740e6f 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -79,11 +79,15 @@ class TableReadSQLOptions: Args: limit: Number of rows to read, or None to read all rows offset: Number of rows to skip before reading + apply_limit_before_offset: Whether to apply the limit before the offset + + predicate_sql: SQL predicate to apply to the table + predicate_expression: Expression predicate to apply to the table """ limit: int | None = None offset: int | None = None - limit_before_offset: bool | None = None + apply_limit_before_offset: bool | None = None predicate_sql: str | None = None predicate_expression: Expression | None = None diff --git a/daft/sql/sql_reader.py b/daft/sql/sql_reader.py index 08e376fb9d..6f1f65a5a1 100644 --- a/daft/sql/sql_reader.py +++ b/daft/sql/sql_reader.py @@ -15,18 +15,18 @@ def __init__( url: str, limit: int | None = None, offset: int | None = None, - limit_before_offset: bool | None = None, + apply_limit_before_offset: bool | None = None, projection: list[str] | None = None, predicate: str | None = None, ) -> None: - if limit is not None and offset is not None and limit_before_offset is None: - raise ValueError("limit_before_offset must be specified when limit and offset are both specified") + if limit is not None and offset is not None and apply_limit_before_offset is None: + raise ValueError("apply_limit_before_offset must be specified when limit and offset are both specified") self.sql = sql self.url = url self.limit = limit self.offset = offset - self.limit_before_offset = limit_before_offset + self.apply_limit_before_offset = apply_limit_before_offset self.projection = projection self.predicate = predicate @@ -47,7 +47,7 @@ def _construct_sql_query(self) -> str: clauses.append(f" WHERE {self.predicate}") if self.limit is not None and self.offset is not None: - if self.limit_before_offset is True: + if self.apply_limit_before_offset is True: clauses.append(f" LIMIT {self.limit} OFFSET {self.offset}") else: clauses.append(f" OFFSET {self.offset} LIMIT {self.limit}") diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index f290245423..d3b707de7f 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -32,16 +32,16 @@ def __init__( self._initialize_schema_and_features() def _initialize_schema_and_features(self) -> None: - self._schema, self._limit_and_offset_supported, self._limit_before_offset = self._attempt_schema_read() + self._schema, self._limit_and_offset_supported, self._apply_limit_before_offset = self._attempt_schema_read() def _attempt_schema_read(self) -> tuple[Schema, bool, bool]: - for limit_before_offset in [True, False]: + for apply_limit_before_offset in [True, False]: try: pa_table = SQLReader( - self.sql, self.url, limit=1, offset=0, limit_before_offset=limit_before_offset + self.sql, self.url, limit=1, offset=0, apply_limit_before_offset=apply_limit_before_offset ).read() schema = Schema.from_pyarrow_schema(pa_table.schema) - return schema, True, limit_before_offset + return schema, True, apply_limit_before_offset except Exception: continue @@ -111,7 +111,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: limit += 1 file_format_config = FileFormatConfig.from_database_config( DatabaseSourceConfig( - self.sql, limit=limit, offset=offset, limit_before_offset=self._limit_before_offset + self.sql, limit=limit, offset=offset, 
apply_limit_before_offset=self._apply_limit_before_offset ) ) scan_tasks.append( diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 5f2207bc33..6a44165231 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -235,7 +235,7 @@ def read_sql( url, limit=sql_options.limit, offset=sql_options.offset, - limit_before_offset=sql_options.limit_before_offset, + apply_limit_before_offset=sql_options.apply_limit_before_offset, projection=read_options.column_names, predicate=sql_options.predicate_sql, ).read() diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index b4b008cc9f..b87b32bffb 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -312,7 +312,7 @@ fn materialize_scan_task( sql, limit, offset, - limit_before_offset, + apply_limit_before_offset, }) => { let predicate_sql = scan_task .pushdowns @@ -332,7 +332,7 @@ fn materialize_scan_task( url, *limit, *offset, - *limit_before_offset, + *apply_limit_before_offset, predicate_sql.clone(), predicate_expr.clone(), scan_task.schema.clone().into(), diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index b486b35af9..8fa458aa50 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -794,7 +794,7 @@ pub(crate) fn read_sql_into_py_table( url: &str, limit: Option, offset: Option, - limit_before_offset: Option, + apply_limit_before_offset: Option, predicate_sql: Option, predicate_expr: Option, schema: PySchema, @@ -821,7 +821,7 @@ pub(crate) fn read_sql_into_py_table( .call1(( limit, offset, - limit_before_offset, + apply_limit_before_offset, predicate_sql, predicate_pyexpr, ))?; diff --git a/src/daft-scan/src/file_format.rs b/src/daft-scan/src/file_format.rs index 440d8e55be..88e9b0a891 100644 --- a/src/daft-scan/src/file_format.rs +++ b/src/daft-scan/src/file_format.rs @@ -265,7 +265,7 @@ pub struct DatabaseSourceConfig { pub sql: String, pub limit: Option, pub offset: Option, - pub limit_before_offset: Option, + pub apply_limit_before_offset: Option, } impl DatabaseSourceConfig { @@ -273,13 +273,13 @@ impl DatabaseSourceConfig { sql: String, limit: Option, offset: Option, - limit_before_offset: Option, + apply_limit_before_offset: Option, ) -> Self { Self { sql, limit, offset, - limit_before_offset, + apply_limit_before_offset, } } @@ -292,8 +292,11 @@ impl DatabaseSourceConfig { if let Some(offset) = self.offset { res.push(format!("Offset = {}", offset)); } - if let Some(limit_before_offset) = self.limit_before_offset { - res.push(format!("Limit before offset = {}", limit_before_offset)); + if let Some(apply_limit_before_offset) = self.apply_limit_before_offset { + res.push(format!( + "Limit before offset = {}", + apply_limit_before_offset + )); } res } @@ -308,9 +311,9 @@ impl DatabaseSourceConfig { sql: &str, limit: Option, offset: Option, - limit_before_offset: Option, + apply_limit_before_offset: Option, ) -> Self { - Self::new_internal(sql.to_string(), limit, offset, limit_before_offset) + Self::new_internal(sql.to_string(), limit, offset, apply_limit_before_offset) } } From ff389d334124835dc51105cce8bd325c1338fb0b Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 1 Mar 2024 10:56:41 -0800 Subject: [PATCH 20/30] docs --- docs/source/api_docs/creation.rst | 52 +++++++++++++++++-------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/docs/source/api_docs/creation.rst b/docs/source/api_docs/creation.rst 
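
For context on why `apply_limit_before_offset` exists: engines disagree on clause order (Postgres, MySQL, and SQLite accept `LIMIT ... OFFSET ...`, while Trino expects `OFFSET ... LIMIT ...`), and the scan operator above probes both orderings with a one-row query and remembers which one succeeded. A hedged sketch of the two query shapes, with illustrative values:

    def paginate(subquery: str, limit: int, offset: int, apply_limit_before_offset: bool) -> str:
        # Mirrors only the LIMIT/OFFSET tail of SQLReader._construct_sql_query.
        base = f"SELECT * FROM ({subquery}) AS subquery"
        if apply_limit_before_offset:
            return f"{base} LIMIT {limit} OFFSET {offset}"
        return f"{base} OFFSET {offset} LIMIT {limit}"

    paginate("SELECT * FROM example", 50, 100, True)   # ... LIMIT 50 OFFSET 100
    paginate("SELECT * FROM example", 50, 100, False)  # ... OFFSET 100 LIMIT 50
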
index d83fa84fcd..ad5b76ec96 100644 --- a/docs/source/api_docs/creation.rst +++ b/docs/source/api_docs/creation.rst @@ -20,6 +20,24 @@ Python Objects from_pylist from_pydict +Arrow +~~~~~ + +.. autosummary:: + :nosignatures: + :toctree: doc_gen/io_functions + + from_arrow + +Pandas +~~~~~~ + +.. autosummary:: + :nosignatures: + :toctree: doc_gen/io_functions + + from_pandas + Files ----- @@ -54,6 +72,15 @@ JSON read_json +File Paths +~~~~~~~~~~ + +.. autosummary:: + :nosignatures: + :toctree: doc_gen/io_functions + + from_glob_path + Data Catalogs ------------- @@ -85,30 +112,7 @@ Arrow :nosignatures: :toctree: doc_gen/io_functions - -.. autosummary:: - :nosignatures: - :toctree: doc_gen/io_functions - - from_arrow - -Pandas -~~~~~~ - -.. autosummary:: - :nosignatures: - :toctree: doc_gen/io_functions - - from_pandas - -File Paths -~~~~~~~~~~ - -.. autosummary:: - :nosignatures: - :toctree: doc_gen/io_functions - - from_glob_path + read_sql Integrations ------------ From ccf1c4b7f1b27b53a0ebd5f531fcda464116ac42 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 4 Mar 2024 13:22:39 -0800 Subject: [PATCH 21/30] improve literal to_sql, use more equality tests, and add todos --- daft/sql/sql_reader.py | 8 +- src/daft-dsl/src/expr.rs | 45 +++---- src/daft-dsl/src/lit.rs | 23 ++++ src/daft-micropartition/src/micropartition.rs | 12 +- tests/integration/sql/conftest.py | 46 ++++---- tests/integration/sql/test_sql.py | 111 +++++++++++------- 6 files changed, 140 insertions(+), 105 deletions(-) diff --git a/daft/sql/sql_reader.py b/daft/sql/sql_reader.py index 6f1f65a5a1..b485cffb42 100644 --- a/daft/sql/sql_reader.py +++ b/daft/sql/sql_reader.py @@ -48,13 +48,13 @@ def _construct_sql_query(self) -> str: if self.limit is not None and self.offset is not None: if self.apply_limit_before_offset is True: - clauses.append(f" LIMIT {self.limit} OFFSET {self.offset}") + clauses.append(f"ORDER BY 1 LIMIT {self.limit} OFFSET {self.offset}") else: - clauses.append(f" OFFSET {self.offset} LIMIT {self.limit}") + clauses.append(f"ORDER BY 1 OFFSET {self.offset} LIMIT {self.limit}") elif self.limit is not None: - clauses.append(f" LIMIT {self.limit}") + clauses.append(f"ORDER BY 1 LIMIT {self.limit}") elif self.offset is not None: - clauses.append(f" OFFSET {self.offset}") + clauses.append(f"ORDER BY 1 OFFSET {self.offset}") return "\n".join(clauses) diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index cbabf61ca8..2ebdb02424 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -574,25 +574,16 @@ impl Expr { fn to_sql_inner(expr: &Expr, buffer: &mut W) -> io::Result<()> { match expr { Expr::Column(name) => write!(buffer, "{}", name), - Expr::Literal(lit) => match lit { - lit::LiteralValue::Series(series) => match series.data_type() { - DataType::Utf8 => { - let trimmed_and_quoted = lit - .to_string() - .trim_matches(|c| c == '[' || c == ']') - .split(", ") - .map(|s| format!("'{}'", s)) - .collect::>(); - write!(buffer, "{}", trimmed_and_quoted.join(", ")) - } - _ => write!( - buffer, - "{}", - lit.to_string().trim_matches(|c| c == '[' || c == ']') - ), - }, - _ => write!(buffer, "{}", lit.to_string().replace('\"', "'")), - }, + Expr::Literal(lit) => { + if let Some(s) = lit.to_sql() { + write!(buffer, "{}", s) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "Unsupported literal for SQL translation", + )) + } + } Expr::Alias(inner, ..) 
=> to_sql_inner(inner, buffer), Expr::BinaryOp { op, left, right } => { to_sql_inner(left, buffer)?; @@ -643,17 +634,13 @@ impl Expr { to_sql_inner(if_false, buffer)?; write!(buffer, " END") } - Expr::IsIn(inner, items) => { - to_sql_inner(inner, buffer)?; - write!(buffer, " IN (")?; - to_sql_inner(items, buffer)?; - write!(buffer, ")") - } // TODO: Implement SQL translations for these expressions if possible - Expr::Agg(..) | Expr::Cast(..) | Expr::Function { .. } => Err(io::Error::new( - io::ErrorKind::Other, - "Unsupported expression for SQL translation", - )), + Expr::Agg(..) | Expr::Cast(..) | Expr::IsIn(..) | Expr::Function { .. } => { + Err(io::Error::new( + io::ErrorKind::Other, + "Unsupported expression for SQL translation", + )) + } } } diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 93be85dacb..25ffbf3e15 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -212,6 +212,29 @@ impl LiteralValue { }; result } + + pub fn to_sql(&self) -> Option { + use LiteralValue::*; + match self { + Null | Boolean(..) | Int32(..) | UInt32(..) | Int64(..) | UInt64(..) | Float64(..) => { + self.to_string().into() + } + Utf8(val) => format!("'{}'", val).into(), + Binary(val) => format!("x'{}'", val.len()).into(), + Date(val) => format!("DATE '{}'", display_date32(*val)).into(), + // TODO(Colin): Reading time from Postgres is parsed as Time(Nanoseconds), while from MySQL it is parsed as Duration(Microseconds) + // Need to fix our time comparison code to handle this. + Time(..) => None, + Timestamp(val, tu, tz) => format!( + "TIMESTAMP '{}'", + display_timestamp(*val, tu, tz).replace('T', " ") + ) + .into(), + Series(..) => None, + #[cfg(feature = "python")] + Python(..) => None, + } + } } pub trait Literal { diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index b87b32bffb..5aae75acef 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -314,11 +314,13 @@ fn materialize_scan_task( offset, apply_limit_before_offset, }) => { - let predicate_sql = scan_task - .pushdowns - .filters - .as_ref() - .and_then(|p| p.to_sql()); + // TODO(Colin): Do more rigorous testing for predicate pushdowns into sql + // let predicate_sql = scan_task + // .pushdowns + // .filters + // .as_ref() + // .and_then(|p| p.to_sql()); + let predicate_sql = None; let predicate_expr = scan_task .pushdowns .filters diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index e9a6537490..af7018e224 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -1,6 +1,6 @@ from __future__ import annotations -import tempfile +from datetime import date, datetime, time, timedelta from typing import Generator import numpy as np @@ -8,7 +8,9 @@ import pytest import tenacity from sqlalchemy import ( + Boolean, Column, + Date, Engine, Float, Integer, @@ -23,7 +25,6 @@ "trino://user@localhost:8080/memory/default", "postgresql://username:password@localhost:5432/postgres", "mysql+pymysql://username:password@localhost:3306/mysql", - "sqlite:///", ] TEST_TABLE_NAME = "example" @@ -31,18 +32,20 @@ @pytest.fixture(scope="session", params=[{"num_rows": 200}]) def generated_data(request: pytest.FixtureRequest) -> pd.DataFrame: num_rows = request.param["num_rows"] - num_rows_per_variety = num_rows // 4 - variety_arr = ( - ["setosa"] * num_rows_per_variety + ["versicolor"] * num_rows_per_variety + ["virginica"] * num_rows_per_variety - ) data = { 
"id": np.arange(num_rows), - "sepal_length": np.arange(num_rows, dtype=float), - "sepal_width": np.arange(num_rows, dtype=float), - "petal_length": np.arange(num_rows, dtype=float), - "petal_width": np.arange(num_rows, dtype=float), - "variety": variety_arr + [None] * (num_rows - len(variety_arr)), + "float_col": np.arange(num_rows, dtype=float), + "string_col": [f"row_{i}" for i in range(num_rows)], + "bool_col": [True for _ in range(num_rows // 2)] + [False for _ in range(num_rows // 2)], + "date_col": [date(2021, 1, 1) + timedelta(days=i) for i in range(num_rows)], + # TODO(Colin): ConnectorX parses datetime as pyarrow date64 type, which we currently cast to Python, causing our assertions to fail. + # One possible solution is to cast date64 into Timestamp("ms") in our from_arrow code. + # "date_time_col": [datetime(2020, 1, 1, 10, 0, 0) + timedelta(hours=i) for i in range(num_rows)], + "time_col": [ + (datetime.combine(datetime.today(), time(0, 0)) + timedelta(minutes=x)).time() for x in range(200) + ], + "null_col": [None if i % 2 == 1 else f"not_null_{i}" for i in range(num_rows)], } return pd.DataFrame(data) @@ -50,15 +53,8 @@ def generated_data(request: pytest.FixtureRequest) -> pd.DataFrame: @pytest.fixture(scope="session", params=URLS) def test_db(request: pytest.FixtureRequest, generated_data: pd.DataFrame) -> Generator[str, None, None]: db_url = request.param - if db_url.startswith("sqlite"): - # No docker container for sqlite, so we need to create a temporary file - with tempfile.NamedTemporaryFile(suffix=".db") as file: - db_url += file.name - setup_database(db_url, generated_data) - yield db_url - else: - setup_database(db_url, generated_data) - yield db_url + setup_database(db_url, generated_data) + yield db_url @tenacity.retry(stop=tenacity.stop_after_delay(10), wait=tenacity.wait_fixed(5), reraise=True) @@ -78,11 +74,11 @@ def create_and_populate(engine: Engine, data: pd.DataFrame) -> None: TEST_TABLE_NAME, metadata, Column("id", Integer), - Column("sepal_length", Float), - Column("sepal_width", Float), - Column("petal_length", Float), - Column("petal_width", Float), - Column("variety", String(50)), + Column("float_col", Float), + Column("string_col", String(50)), + Column("bool_col", Boolean), + Column("date_col", Date), + Column("null_col", String(50)), ) metadata.create_all(engine) data.to_sql(table.name, con=engine, if_exists="replace", index=False) diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index 7ac9f3b8c6..13759f63e6 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -1,7 +1,9 @@ from __future__ import annotations +import datetime import math +import pandas as pd import pytest import daft @@ -11,95 +13,120 @@ @pytest.mark.integration() -def test_sql_create_dataframe_ok(test_db, generated_data) -> None: +def test_sql_create_dataframe_ok(test_db) -> None: df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) - assert_df_equals(df.to_pandas(), generated_data, sort_key="id") + assert_df_equals(df.to_pandas(), pdf, sort_key="id") @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [2, 3, 4]) -def test_sql_partitioned_read(test_db, num_partitions, generated_data) -> None: - row_size_bytes = daft.from_pandas(generated_data).schema().estimate_row_size_bytes() - num_rows_per_partition = len(generated_data) / num_partitions +def test_sql_partitioned_read(test_db, num_partitions) -> None: + pdf = 
pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + + row_size_bytes = daft.from_pandas(pdf).schema().estimate_row_size_bytes() + num_rows_per_partition = len(pdf) / num_partitions set_execution_config(read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition)) - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME} ORDER BY id", test_db) assert df.num_partitions() == num_partitions - df = df.collect() - assert len(df) == len(generated_data) + assert_df_equals(df.to_pandas(), pdf, sort_key="id") @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [2, 3, 4]) -def test_sql_partitioned_read_with_custom_num_partitions(test_db, num_partitions, generated_data) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, num_partitions=num_partitions) +def test_sql_partitioned_read_with_custom_num_partitions(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME} ORDER BY id", test_db, num_partitions=num_partitions) assert df.num_partitions() == num_partitions - df = df.collect() - assert len(df) == len(generated_data) + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + assert_df_equals(df.to_pandas(), pdf, sort_key="id") @pytest.mark.integration() @pytest.mark.parametrize( - "column,operator,value,expected_length", + "operator", + ["<", ">", "=", "!=", ">=", "<="], +) +@pytest.mark.parametrize( + "column, value", [ - ("id", ">", 100, 99), - ("id", "<=", 100, 101), - ("sepal_length", "<", 100.0, 100), - ("sepal_length", ">=", 100.0, 100), - ("variety", "=", "setosa", 50), - ("variety", "!=", "setosa", 100), - ("variety", "is_null", None, 50), - ("variety", "not_null", None, 150), - ("variety", "is_in", ["setosa", "versicolor"], 100), - ("id", "is_in", [1, 2, 3], 3), + ("id", 100), + ("float_col", 100.0), + ("string_col", "row_100"), + ("bool_col", True), + ("date_col", datetime.date(2021, 1, 1)), + # TODO(Colin) - ConnectorX parses datetime as pyarrow date64 type, which we currently cast to Python, causing our assertions to fail. + # One possible solution is to cast date64 into Timestamp("ms") in our from_arrow code. + # ("date_time_col", datetime.datetime(2020, 1, 1, 10, 0, 0)), + # TODO(Colin) - Reading time from Postgres is parsed as Time(Nanoseconds), while from MySQL it is parsed as Duration(Microseconds) + # Need to fix our time comparison code to handle this. 
+ # ("time_col", datetime.time(10, 0, 0)), ], ) -def test_sql_read_with_filter_pushdowns(test_db, column, operator, value, expected_length) -> None: +def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value) -> None: df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) if operator == ">": df = df.where(df[column] > value) + pdf = pdf[pdf[column] > value] elif operator == "<": df = df.where(df[column] < value) + pdf = pdf[pdf[column] < value] elif operator == "=": df = df.where(df[column] == value) + pdf = pdf[pdf[column] == value] elif operator == "!=": df = df.where(df[column] != value) + pdf = pdf[pdf[column] != value] elif operator == ">=": df = df.where(df[column] >= value) + pdf = pdf[pdf[column] >= value] elif operator == "<=": df = df.where(df[column] <= value) - elif operator == "is_null": - df = df.where(df[column].is_null()) - elif operator == "not_null": - df = df.where(df[column].not_null()) - elif operator == "is_in": - df = df.where(df[column].is_in(value)) + pdf = pdf[pdf[column] <= value] - df = df.collect() - assert len(df) == expected_length + assert_df_equals(df.to_pandas(), pdf, sort_key="id") @pytest.mark.integration() -def test_sql_read_with_if_else_filter_pushdown(test_db, generated_data) -> None: +def test_sql_read_with_is_null_filter_pushdowns(test_db) -> None: df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) - df = df.where(df["variety"] == (df["id"] == -1).if_else("virginica", "setosa")) + df = df.where(df["null_col"].is_null()) - df = df.collect() - assert len(df) == len(generated_data) // 4 + pydict = df.to_pydict() + assert all(value is None for value in pydict["null_col"]) + + +@pytest.mark.integration() +def test_sql_read_with_not_null_filter_pushdowns(test_db) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + df = df.where(df["null_col"].not_null()) + + pydict = df.to_pydict() + assert all(value is not None for value in pydict["null_col"]) + + +@pytest.mark.integration() +def test_sql_read_with_if_else_filter_pushdown(test_db) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + df = df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50)) + + pydict = df.to_pydict() + assert all(value < 50 or value > 150 for value in pydict["float_col"]) @pytest.mark.integration() def test_sql_read_with_all_pushdowns(test_db) -> None: df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) - df = df.where(~(df["sepal_length"] < 1)) - df = df.where(df["variety"].is_in(["setosa", "versicolor"])) - df = df.select(df["sepal_length"], df["variety"]) + df = df.where(~(df["id"] < 1)) + df = df.where(df["string_col"].is_in([f"row_{i}" for i in range(10)])) + df = df.select(df["id"], df["string_col"]) df = df.limit(1) df = df.collect() - assert df.column_names == ["sepal_length", "variety"] + assert df.column_names == ["id", "string_col"] assert len(df) == 1 @@ -116,10 +143,10 @@ def test_sql_read_with_limit_pushdown(test_db, limit) -> None: @pytest.mark.integration() def test_sql_read_with_projection_pushdown(test_db, generated_data) -> None: df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) - df = df.select(df["sepal_length"], df["variety"]) + df = df.select(df["id"], df["string_col"]) df = df.collect() - assert df.column_names == ["sepal_length", "variety"] + assert df.column_names == ["id", "string_col"] assert len(df) == len(generated_data) From 
ff41a7833e4d247617a02cf3b4ef035102985117 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 4 Mar 2024 13:34:00 -0800 Subject: [PATCH 22/30] fix stuff from merge conflict --- docs/source/api_docs/creation.rst | 20 +++++++++----------- src/daft-dsl/src/lit.rs | 6 ++---- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/docs/source/api_docs/creation.rst b/docs/source/api_docs/creation.rst index ad5b76ec96..b02b3ee3a0 100644 --- a/docs/source/api_docs/creation.rst +++ b/docs/source/api_docs/creation.rst @@ -102,17 +102,6 @@ Delta Lake read_delta_lake -In-memory ---------- - -Arrow -~~~~~ - -.. autosummary:: - :nosignatures: - :toctree: doc_gen/io_functions - - read_sql Integrations ------------ @@ -136,3 +125,12 @@ Dask :toctree: doc_gen/io_functions from_dask_dataframe + +Databases +~~~~~~~~~ + +.. autosummary:: + :nosignatures: + :toctree: doc_gen/io_functions + + read_sql \ No newline at end of file diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 25ffbf3e15..983029e80c 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -222,15 +222,13 @@ impl LiteralValue { Utf8(val) => format!("'{}'", val).into(), Binary(val) => format!("x'{}'", val.len()).into(), Date(val) => format!("DATE '{}'", display_date32(*val)).into(), - // TODO(Colin): Reading time from Postgres is parsed as Time(Nanoseconds), while from MySQL it is parsed as Duration(Microseconds) - // Need to fix our time comparison code to handle this. - Time(..) => None, Timestamp(val, tu, tz) => format!( "TIMESTAMP '{}'", display_timestamp(*val, tu, tz).replace('T', " ") ) .into(), - Series(..) => None, + // TODO(Colin): Implement the rest of the types in future work for SQl pushdowns. + Decimal(..) | Series(..) | Time(..) => None, #[cfg(feature = "python")] Python(..) 
=> None, } From b077bc58309a2033ea6377706a5b5af010a633c0 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 4 Mar 2024 13:59:51 -0800 Subject: [PATCH 23/30] disable pushdowns in sql reader --- daft/table/table_io.py | 5 +++-- src/daft-micropartition/src/micropartition.rs | 12 +++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 6a44165231..fd74407560 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -236,8 +236,9 @@ def read_sql( limit=sql_options.limit, offset=sql_options.offset, apply_limit_before_offset=sql_options.apply_limit_before_offset, - projection=read_options.column_names, - predicate=sql_options.predicate_sql, + # TODO(Colin): Enable pushdowns + projection=None, + predicate=None, ).read() mp = MicroPartition.from_arrow(pa_table) diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 5aae75acef..b87b32bffb 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -314,13 +314,11 @@ fn materialize_scan_task( offset, apply_limit_before_offset, }) => { - // TODO(Colin): Do more rigorous testing for predicate pushdowns into sql - // let predicate_sql = scan_task - // .pushdowns - // .filters - // .as_ref() - // .and_then(|p| p.to_sql()); - let predicate_sql = None; + let predicate_sql = scan_task + .pushdowns + .filters + .as_ref() + .and_then(|p| p.to_sql()); let predicate_expr = scan_task .pushdowns .filters From 3c49a8ae04b5c3376b388860341009069643d33c Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 4 Mar 2024 14:03:42 -0800 Subject: [PATCH 24/30] disable pushdowns in sql reader --- daft/table/table_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daft/table/table_io.py b/daft/table/table_io.py index fd74407560..291450a25a 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -242,7 +242,7 @@ def read_sql( ).read() mp = MicroPartition.from_arrow(pa_table) - if sql_options.predicate_sql is None and sql_options.predicate_expression is not None: + if sql_options.predicate_expression is not None: mp = mp.filter(ExpressionsProjection([sql_options.predicate_expression])) if read_options.num_rows is not None: From f7ec4c91d8454ffc68f18ce102a736b7591796e8 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Tue, 5 Mar 2024 17:48:09 -0800 Subject: [PATCH 25/30] revise partioning algo --- daft/daft.pyi | 14 +- daft/datatype.py | 3 + daft/expressions/expressions.py | 3 + daft/io/_sql.py | 30 +++- daft/runners/partitioning.py | 10 +- daft/sql/sql_reader.py | 41 +++-- daft/sql/sql_scan.py | 151 ++++++++++++------ daft/table/table_io.py | 5 +- src/daft-core/src/python/datatype.rs | 4 + src/daft-dsl/src/lit.rs | 2 +- src/daft-dsl/src/python.rs | 4 + src/daft-micropartition/src/micropartition.rs | 10 +- src/daft-micropartition/src/python.rs | 13 +- src/daft-scan/src/file_format.rs | 42 ++--- src/daft-scan/src/lib.rs | 49 ++---- src/daft-scan/src/python.rs | 6 +- tests/integration/sql/conftest.py | 4 +- tests/integration/sql/test_sql.py | 106 ++++++++---- 18 files changed, 288 insertions(+), 209 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index 25beb22673..adfd063f14 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -223,16 +223,14 @@ class DatabaseSourceConfig: """ sql: str - limit: int | None - offset: int | None - apply_limit_before_offset: bool | None + left_bound: str | None + right_bound: str | None def __init__( self, sql: str, - limit: int 
| None = None, - offset: int | None = None, - apply_limit_before_offset: bool | None = None, + left_bound: str | None = None, + right_bound: str | None = None, ): ... class FileFormatConfig: @@ -612,7 +610,7 @@ class ScanTask: url: str, file_format: FileFormatConfig, schema: PySchema, - num_rows: int, + num_rows: int | None, storage_config: StorageConfig, size_bytes: int | None, pushdowns: Pushdowns | None, @@ -838,6 +836,7 @@ class PyDataType: @staticmethod def python() -> PyDataType: ... def to_arrow(self, cast_tensor_type_for_ray: builtins.bool | None = None) -> pyarrow.DataType: ... + def is_numeric(self) -> builtins.bool: ... def is_image(self) -> builtins.bool: ... def is_fixed_shape_image(self) -> builtins.bool: ... def is_tensor(self) -> builtins.bool: ... @@ -913,6 +912,7 @@ class PyExpr: def not_null(self) -> PyExpr: ... def is_in(self, other: PyExpr) -> PyExpr: ... def name(self) -> str: ... + def to_sql(self) -> str | None: ... def to_field(self, schema: PySchema) -> PyField: ... def __repr__(self) -> str: ... def __hash__(self) -> int: ... diff --git a/daft/datatype.py b/daft/datatype.py index 6c0cf6c1d4..81114698c4 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -479,6 +479,9 @@ def _is_image_type(self) -> builtins.bool: def _is_fixed_shape_image_type(self) -> builtins.bool: return self._dtype.is_fixed_shape_image() + def _is_numeric_type(self) -> builtins.bool: + return self._dtype.is_numeric() + def _is_map(self) -> builtins.bool: return self._dtype.is_map() diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 89e4ca8728..09b3ad6fc6 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -473,6 +473,9 @@ def name(self) -> builtins.str: def __repr__(self) -> builtins.str: return repr(self._expr) + def _to_sql(self) -> builtins.str | None: + return self._expr.to_sql() + def _to_field(self, schema: Schema) -> Field: return Field._from_pyfield(self._expr.to_field(schema._schema)) diff --git a/daft/io/_sql.py b/daft/io/_sql.py index 84c74202e8..c813f5d9fe 100644 --- a/daft/io/_sql.py +++ b/daft/io/_sql.py @@ -12,30 +12,44 @@ @PublicAPI -def read_sql(sql: str, url: str, num_partitions: Optional[int] = None) -> DataFrame: - """Creates a DataFrame from a SQL query +def read_sql( + sql: str, url: str, partition_col: Optional[str] = None, num_partitions: Optional[int] = None +) -> DataFrame: + """Creates a DataFrame from a SQL query. Example: >>> df = daft.read_sql("SELECT * FROM my_table", "sqlite:///my_database.db") + .. NOTE:: + If partition_col is specified, this function will partition the query by the specified column. You may specify the number of partitions, or let Daft determine the number of partitions. + Daft will attempt to partition the query on the percentiles of the specified column, and will attempt to balance the number of rows in each partition. + If the database does not support the necessary SQL syntax to partition the query, Daft will partition the query via ranges between the min and max values of the specified column. + Args: sql (str): SQL query to execute url (str): URL to the database + partition_col (Optional[str]): Column to partition the data by, defaults to None num_partitions (Optional[int]): Number of partitions to read the data into, defaults to None, which will let Daft determine the number of partitions.
- returns: + Returns: DataFrame: Dataframe containing the results of the query """ + if num_partitions is not None and partition_col is None: + raise ValueError("Failed to execute sql: partition_col must be specified when num_partitions is specified") + io_config = context.get_context().daft_planning_config.default_io_config storage_config = StorageConfig.python(PythonStorageConfig(io_config)) - sql_operator = SQLScanOperator(sql, url, storage_config=storage_config, num_partitions=num_partitions) + sql_operator = SQLScanOperator( + sql, + url, + storage_config, + partition_col=partition_col, + num_partitions=num_partitions, + ) handle = ScanOperatorHandle.from_python_scan_operator(sql_operator) builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle) - if num_partitions is not None and num_partitions > 1 and not sql_operator.can_partition_read(): - return DataFrame(builder).into_partitions(num_partitions) - else: - return DataFrame(builder) + return DataFrame(builder) diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index ebae740e6f..87445fde89 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -77,17 +77,15 @@ class TableReadSQLOptions: """Options for parsing SQL tables Args: - limit: Number of rows to read, or None to read all rows - offset: Number of rows to skip before reading - apply_limit_before_offset: Whether to apply the limit before the offset + left_bound: Lower bound of the table to read + right_bound: Upper bound of the table to read predicate_sql: SQL predicate to apply to the table predicate_expression: Expression predicate to apply to the table """ - limit: int | None = None - offset: int | None = None - apply_limit_before_offset: bool | None = None + left_bound: str | None = None + right_bound: str | None = None predicate_sql: str | None = None predicate_expression: Expression | None = None diff --git a/daft/sql/sql_reader.py b/daft/sql/sql_reader.py index b485cffb42..565ac91e2e 100644 --- a/daft/sql/sql_reader.py +++ b/daft/sql/sql_reader.py @@ -13,20 +13,18 @@ def __init__( self, sql: str, url: str, + left_bound: str | None = None, + right_bound: str | None = None, limit: int | None = None, - offset: int | None = None, - apply_limit_before_offset: bool | None = None, projection: list[str] | None = None, predicate: str | None = None, ) -> None: - if limit is not None and offset is not None and apply_limit_before_offset is None: - raise ValueError("apply_limit_before_offset must be specified when limit and offset are both specified") self.sql = sql self.url = url + self.left_bound = left_bound + self.right_bound = right_bound self.limit = limit - self.offset = offset - self.apply_limit_before_offset = apply_limit_before_offset self.projection = projection self.predicate = predicate @@ -35,26 +33,27 @@ def read(self) -> pa.Table: return self._execute_sql_query(sql) def _construct_sql_query(self) -> str: - clauses = ["SELECT"] + base_query = f"SELECT * FROM ({self.sql}) AS subquery" + if self.left_bound is not None and self.right_bound is not None: + base_query = f"{base_query} WHERE {self.left_bound} AND {self.right_bound}" + elif self.left_bound is not None: + base_query = f"{base_query} WHERE {self.left_bound}" + elif self.right_bound is not None: + base_query = f"{base_query} WHERE {self.right_bound}" + + clauses = [] if self.projection is not None: - clauses.append(", ".join(self.projection)) + clauses.append(f"SELECT {', '.join(self.projection)}") else: - clauses.append("*") + clauses.append("SELECT *") - 
clauses.append(f"FROM ({self.sql}) AS subquery") + clauses.append(f"FROM ({base_query}) AS subquery") if self.predicate is not None: - clauses.append(f" WHERE {self.predicate}") - - if self.limit is not None and self.offset is not None: - if self.apply_limit_before_offset is True: - clauses.append(f"ORDER BY 1 LIMIT {self.limit} OFFSET {self.offset}") - else: - clauses.append(f"ORDER BY 1 OFFSET {self.offset} LIMIT {self.limit}") - elif self.limit is not None: - clauses.append(f"ORDER BY 1 LIMIT {self.limit}") - elif self.offset is not None: - clauses.append(f"ORDER BY 1 OFFSET {self.offset}") + clauses.append(f"WHERE {self.predicate}") + + if self.limit is not None: + clauses.append(f"LIMIT {self.limit}") return "\n".join(clauses) diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index d3b707de7f..f95e2c4488 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -1,7 +1,10 @@ from __future__ import annotations +import logging import math +import warnings from collections.abc import Iterator +from typing import Any from daft.context import get_context from daft.daft import ( @@ -11,10 +14,13 @@ ScanTask, StorageConfig, ) +from daft.expressions.expressions import lit from daft.io.scan import PartitionField, ScanOperator from daft.logical.schema import Schema from daft.sql.sql_reader import SQLReader +logger = logging.getLogger(__name__) + class SQLScanOperator(ScanOperator): def __init__( @@ -22,33 +28,27 @@ def __init__( sql: str, url: str, storage_config: StorageConfig, + partition_col: str | None = None, num_partitions: int | None = None, ) -> None: super().__init__() self.sql = sql self.url = url self.storage_config = storage_config + self._partition_col = partition_col self._num_partitions = num_partitions - self._initialize_schema_and_features() - - def _initialize_schema_and_features(self) -> None: - self._schema, self._limit_and_offset_supported, self._apply_limit_before_offset = self._attempt_schema_read() - - def _attempt_schema_read(self) -> tuple[Schema, bool, bool]: - for apply_limit_before_offset in [True, False]: - try: - pa_table = SQLReader( - self.sql, self.url, limit=1, offset=0, apply_limit_before_offset=apply_limit_before_offset - ).read() - schema = Schema.from_pyarrow_schema(pa_table.schema) - return schema, True, apply_limit_before_offset - except Exception: - continue - - # If both attempts fail, read without limit and offset - pa_table = SQLReader(self.sql, self.url).read() - schema = Schema.from_pyarrow_schema(pa_table.schema) - return schema, False, False + self._schema = self._attempt_schema_read() + + def _attempt_schema_read(self) -> Schema: + try: + pa_table = SQLReader(self.sql, self.url, limit=1).read() + schema = Schema.from_pyarrow_schema(pa_table.schema) + return schema + except Exception: + # If both attempts fail, read without limit and offset + pa_table = SQLReader(self.sql, self.url).read() + schema = Schema.from_pyarrow_schema(pa_table.schema) + return schema def _get_num_rows(self) -> int: pa_table = SQLReader( @@ -73,59 +73,111 @@ def multiline_display(self) -> list[str]: f"Schema = {self._schema}", ] + def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], str]: + if self._partition_col is None: + raise ValueError("Failed to get partition bounds: partition_col must be specified to partition the data.") + + if not ( + self._schema[self._partition_col].dtype._is_temporal_type() + or self._schema[self._partition_col].dtype._is_numeric_type() + ): + raise ValueError( + f"Failed to get partition bounds: 
{self._partition_col} is not a numeric or temporal type, and cannot be used for partitioning." + ) + + try: + # try to get percentiles using percentile_cont + percentiles = [i / num_scan_tasks for i in range(1, num_scan_tasks)] + pa_table = SQLReader( + self.sql, + self.url, + projection=[ + f"percentile_cont({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}" + for i, percentile in enumerate(percentiles) + ], + ).read() + bounds = [pa_table.column(i)[0].as_py() for i in range(num_scan_tasks - 1)] + return bounds, "percentile" + + except Exception as e: + # if the above fails, use the min and max of the partition column + logger.info("Failed to get percentiles using percentile_cont, falling back to min and max. Error: %s", e) + try: + pa_table = SQLReader( + self.sql, + self.url, + projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"], + ).read() + min_val = pa_table.column(0)[0].as_py() + max_val = pa_table.column(1)[0].as_py() + return [min_val + (max_val - min_val) * i / num_scan_tasks for i in range(1, num_scan_tasks)], "min_max" + + except Exception: + raise ValueError( + f"Failed to get partition bounds from {self._partition_col}. Please ensure that the column exists, and is numeric or temporal." + ) + + def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]: + file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) + return iter( + [ + ScanTask.sql_scan_task( + url=self.url, + file_format=file_format_config, + schema=self._schema._schema, + num_rows=total_rows, + storage_config=self.storage_config, + size_bytes=math.ceil(total_size), + pushdowns=pushdowns, + ) + ] + ) + def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: total_rows = self._get_num_rows() estimate_row_size_bytes = self.schema().estimate_row_size_bytes() total_size = total_rows * estimate_row_size_bytes - - if not self._limit_and_offset_supported: - # If limit and offset are not supported, then we can't parallelize the scan, so we just return a single scan task - file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) - return iter( - [ - ScanTask.sql_scan_task( - url=self.url, - file_format=file_format_config, - schema=self._schema._schema, - num_rows=total_rows, - storage_config=self.storage_config, - size_bytes=math.ceil(total_size), - pushdowns=pushdowns, - ) - ] - ) - num_scan_tasks = ( math.ceil(total_size / get_context().daft_execution_config.read_sql_partition_size_bytes) if self._num_partitions is None else self._num_partitions ) - num_rows_per_scan_task = total_rows // num_scan_tasks - num_scan_tasks_with_extra_row = total_rows % num_scan_tasks + if num_scan_tasks == 1 or self._partition_col is None: + return self._single_scan_task(pushdowns, total_rows, total_size) + + partition_bounds, strategy = self._get_partition_bounds_and_strategy(num_scan_tasks) + partition_bounds = [lit(bound)._to_sql() for bound in partition_bounds] + + if any(bound is None for bound in partition_bounds): + warnings.warn("Unable to partion the data using the specified column. 
Falling back to a single scan task.") + return self._single_scan_task(pushdowns, total_rows, total_size) + + size_bytes = None if strategy == "min_max" else math.ceil(total_size / num_scan_tasks) scan_tasks = [] - offset = 0 for i in range(num_scan_tasks): - limit = num_rows_per_scan_task - if i < num_scan_tasks_with_extra_row: - limit += 1 + left_bound = None if i == 0 else f"{self._partition_col} > {partition_bounds[i - 1]}" + right_bound = None if i == num_scan_tasks - 1 else f"{self._partition_col} <= {partition_bounds[i]}" + file_format_config = FileFormatConfig.from_database_config( DatabaseSourceConfig( - self.sql, limit=limit, offset=offset, apply_limit_before_offset=self._apply_limit_before_offset + self.sql, + left_bound=left_bound, + right_bound=right_bound, ) ) + scan_tasks.append( ScanTask.sql_scan_task( url=self.url, file_format=file_format_config, schema=self._schema._schema, - num_rows=limit, + num_rows=None, storage_config=self.storage_config, - size_bytes=math.ceil(limit * estimate_row_size_bytes), + size_bytes=size_bytes, pushdowns=pushdowns, ) ) - offset += limit return iter(scan_tasks) @@ -137,6 +189,3 @@ def can_absorb_limit(self) -> bool: def can_absorb_select(self) -> bool: return False - - def can_partition_read(self) -> bool: - return self._limit_and_offset_supported diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 291450a25a..91a63f8738 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -233,9 +233,8 @@ def read_sql( pa_table = SQLReader( sql, url, - limit=sql_options.limit, - offset=sql_options.offset, - apply_limit_before_offset=sql_options.apply_limit_before_offset, + left_bound=sql_options.left_bound, + right_bound=sql_options.right_bound, # TODO(Colin): Enable pushdowns projection=None, predicate=None, diff --git a/src/daft-core/src/python/datatype.rs b/src/daft-core/src/python/datatype.rs index 4e4fa3de87..598f4b1625 100644 --- a/src/daft-core/src/python/datatype.rs +++ b/src/daft-core/src/python/datatype.rs @@ -343,6 +343,10 @@ impl PyDataType { } } + pub fn is_numeric(&self) -> PyResult { + Ok(self.dtype.is_numeric()) + } + pub fn is_image(&self) -> PyResult { Ok(self.dtype.is_image()) } diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 983029e80c..b8a1149f5f 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -227,7 +227,7 @@ impl LiteralValue { display_timestamp(*val, tu, tz).replace('T', " ") ) .into(), - // TODO(Colin): Implement the rest of the types in future work for SQl pushdowns. + // TODO(Colin): Implement the rest of the types in future work for SQL pushdowns. Decimal(..) | Series(..) | Time(..) => None, #[cfg(feature = "python")] Python(..) => None, diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index a4e8c7091a..d23e3def34 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -324,6 +324,10 @@ impl PyExpr { Ok(self.expr.name()?) 
} + pub fn to_sql(&self) -> PyResult> { + Ok(self.expr.to_sql()) + } + pub fn to_field(&self, schema: &PySchema) -> PyResult { Ok(self.expr.to_field(&schema.schema)?.into()) } diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index b87b32bffb..24d30d8c2e 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -310,9 +310,8 @@ fn materialize_scan_task( })?, FileFormatConfig::Database(DatabaseSourceConfig { sql, - limit, - offset, - apply_limit_before_offset, + left_bound, + right_bound, }) => { let predicate_sql = scan_task .pushdowns @@ -330,9 +329,8 @@ fn materialize_scan_task( py, sql, url, - *limit, - *offset, - *apply_limit_before_offset, + left_bound.as_deref(), + right_bound.as_deref(), predicate_sql.clone(), predicate_expr.clone(), scan_task.schema.clone().into(), diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 8fa458aa50..bc7325e9b4 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -792,9 +792,8 @@ pub(crate) fn read_sql_into_py_table( py: Python, sql: &str, url: &str, - limit: Option, - offset: Option, - apply_limit_before_offset: Option, + left_bound: Option<&str>, + right_bound: Option<&str>, predicate_sql: Option, predicate_expr: Option, schema: PySchema, @@ -818,13 +817,7 @@ pub(crate) fn read_sql_into_py_table( let sql_options = py .import(pyo3::intern!(py, "daft.runners.partitioning"))? .getattr(pyo3::intern!(py, "TableReadSQLOptions"))? - .call1(( - limit, - offset, - apply_limit_before_offset, - predicate_sql, - predicate_pyexpr, - ))?; + .call1((left_bound, right_bound, predicate_sql, predicate_pyexpr))?; let read_options = py .import(pyo3::intern!(py, "daft.runners.partitioning"))? .getattr(pyo3::intern!(py, "TableReadOptions"))? diff --git a/src/daft-scan/src/file_format.rs b/src/daft-scan/src/file_format.rs index 88e9b0a891..8ee1ca158c 100644 --- a/src/daft-scan/src/file_format.rs +++ b/src/daft-scan/src/file_format.rs @@ -263,40 +263,31 @@ impl_bincode_py_state_serialization!(JsonSourceConfig); #[cfg_attr(feature = "python", pyclass(module = "daft.daft", get_all))] pub struct DatabaseSourceConfig { pub sql: String, - pub limit: Option, - pub offset: Option, - pub apply_limit_before_offset: Option, + pub left_bound: Option, + pub right_bound: Option, } impl DatabaseSourceConfig { pub fn new_internal( sql: String, - limit: Option, - offset: Option, - apply_limit_before_offset: Option, + left_bound: Option, + right_bound: Option, ) -> Self { Self { sql, - limit, - offset, - apply_limit_before_offset, + left_bound, + right_bound, } } pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("SQL = {}", self.sql)); - if let Some(limit) = self.limit { - res.push(format!("Limit = {}", limit)); + if let Some(left_bound) = &self.left_bound { + res.push(format!("Left bound = {}", left_bound)); } - if let Some(offset) = self.offset { - res.push(format!("Offset = {}", offset)); - } - if let Some(apply_limit_before_offset) = self.apply_limit_before_offset { - res.push(format!( - "Limit before offset = {}", - apply_limit_before_offset - )); + if let Some(right_bound) = &self.right_bound { + res.push(format!("Right bound = {}", right_bound)); } res } @@ -307,13 +298,12 @@ impl DatabaseSourceConfig { impl DatabaseSourceConfig { /// Create a config for a Database data source. 
#[new] - fn new( - sql: &str, - limit: Option, - offset: Option, - apply_limit_before_offset: Option, - ) -> Self { - Self::new_internal(sql.to_string(), limit, offset, apply_limit_before_offset) + fn new(sql: &str, left_bound: Option<&str>, right_bound: Option<&str>) -> Self { + Self::new_internal( + sql.to_string(), + left_bound.map(str::to_string), + right_bound.map(str::to_string), + ) } } diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index d730528dc3..3353813f95 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -139,7 +139,7 @@ pub enum DataFileSource { path: String, chunk_spec: Option, size_bytes: Option, - metadata: TableMetadata, + metadata: Option, partition_spec: Option, statistics: Option, }, @@ -172,10 +172,9 @@ impl DataFileSource { pub fn get_metadata(&self) -> Option<&TableMetadata> { match self { - Self::AnonymousDataFile { metadata, .. } => metadata.as_ref(), - Self::CatalogDataFile { metadata, .. } | Self::DatabaseDataSource { metadata, .. } => { - Some(metadata) - } + Self::AnonymousDataFile { metadata, .. } + | Self::DatabaseDataSource { metadata, .. } => metadata.as_ref(), + Self::CatalogDataFile { metadata, .. } => Some(metadata), } } @@ -205,6 +204,14 @@ impl DataFileSource { metadata, partition_spec, statistics, + } + | Self::DatabaseDataSource { + path, + chunk_spec, + size_bytes, + metadata, + partition_spec, + statistics, } => { res.push(format!("Path = {}", path)); if let Some(chunk_spec) = chunk_spec { @@ -262,38 +269,6 @@ impl DataFileSource { res.push(format!("Statistics = {}", statistics)); } } - Self::DatabaseDataSource { - path, - chunk_spec, - size_bytes, - metadata, - partition_spec, - statistics, - } => { - res.push(format!("Path = {}", path)); - if let Some(chunk_spec) = chunk_spec { - res.push(format!( - "Chunk spec = {{ {} }}", - chunk_spec.multiline_display().join(", ") - )); - } - if let Some(size_bytes) = size_bytes { - res.push(format!("Size bytes = {}", size_bytes)); - } - res.push(format!( - "Metadata = {}", - metadata.multiline_display().join(", ") - )); - if let Some(partition_spec) = partition_spec { - res.push(format!( - "Partition spec = {}", - partition_spec.multiline_display().join(", ") - )); - } - if let Some(statistics) = statistics { - res.push(format!("Statistics = {}", statistics)); - } - } } res } diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index 548d671b5d..8d68616386 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -324,8 +324,8 @@ pub mod pylib { url: String, file_format: PyFileFormatConfig, schema: PySchema, - num_rows: i64, storage_config: PyStorageConfig, + num_rows: Option, size_bytes: Option, pushdowns: Option, ) -> PyResult { @@ -333,9 +333,7 @@ pub mod pylib { path: url, chunk_spec: None, size_bytes, - metadata: TableMetadata { - length: num_rows as usize, - }, + metadata: num_rows.map(|n| TableMetadata { length: n as usize }), partition_spec: None, statistics: None, }; diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index af7018e224..c6285faae6 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -45,7 +45,8 @@ def generated_data(request: pytest.FixtureRequest) -> pd.DataFrame: "time_col": [ (datetime.combine(datetime.today(), time(0, 0)) + timedelta(minutes=x)).time() for x in range(200) ], - "null_col": [None if i % 2 == 1 else f"not_null_{i}" for i in range(num_rows)], + "null_col": [None if i % 2 == 1 else f"not_null" for i in range(num_rows)], + 
"non_uniformly_distributed_col": [1 for _ in range(num_rows)], } return pd.DataFrame(data) @@ -79,6 +80,7 @@ def create_and_populate(engine: Engine, data: pd.DataFrame) -> None: Column("bool_col", Boolean), Column("date_col", Date), Column("null_col", String(50)), + Column("non_uniformly_distributed_col", Integer), ) metadata.create_all(engine) data.to_sql(table.name, con=engine, if_exists="replace", index=False) diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index 13759f63e6..2b09c66366 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -29,20 +29,53 @@ def test_sql_partitioned_read(test_db, num_partitions) -> None: num_rows_per_partition = len(pdf) / num_partitions set_execution_config(read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition)) - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME} ORDER BY id", test_db) + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id") assert df.num_partitions() == num_partitions assert_df_equals(df.to_pandas(), pdf, sort_key="id") @pytest.mark.integration() -@pytest.mark.parametrize("num_partitions", [2, 3, 4]) -def test_sql_partitioned_read_with_custom_num_partitions(test_db, num_partitions) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME} ORDER BY id", test_db, num_partitions=num_partitions) +@pytest.mark.parametrize("num_partitions", [1, 2, 3, 4]) +@pytest.mark.parametrize("partition_col", ["id", "float_col", "date_col"]) +def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col( + test_db, num_partitions, partition_col +) -> None: + df = daft.read_sql( + f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col=partition_col, num_partitions=num_partitions + ) assert df.num_partitions() == num_partitions pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) assert_df_equals(df.to_pandas(), pdf, sort_key="id") +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2, 3, 4]) +def test_sql_partitioned_read_with_non_uniformly_distributed_column(test_db, num_partitions) -> None: + df = daft.read_sql( + f"SELECT * FROM {TEST_TABLE_NAME}", + test_db, + partition_col="non_uniformly_distributed_col", + num_partitions=num_partitions, + ) + assert df.num_partitions() == num_partitions + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + +@pytest.mark.integration() +@pytest.mark.parametrize("partition_col", ["string_col", "time_col", "null_col"]) +def test_sql_partitioned_read_with_non_partionable_column(test_db, partition_col) -> None: + with pytest.raises(ValueError, match="Failed to get partition bounds"): + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col=partition_col, num_partitions=2) + df = df.collect() + + +@pytest.mark.integration() +def test_sql_read_with_partition_num_without_partition_col(test_db) -> None: + with pytest.raises(ValueError, match="Failed to execute sql"): + daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, num_partitions=2) + + @pytest.mark.integration() @pytest.mark.parametrize( "operator", @@ -64,8 +97,9 @@ def test_sql_partitioned_read_with_custom_num_partitions(test_db, num_partitions # ("time_col", datetime.time(10, 0, 0)), ], ) -def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) 
+@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) if operator == ">": @@ -91,49 +125,64 @@ def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value) @pytest.mark.integration() -def test_sql_read_with_is_null_filter_pushdowns(test_db) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_is_null_filter_pushdowns(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.where(df["null_col"].is_null()) - pydict = df.to_pydict() - assert all(value is None for value in pydict["null_col"]) + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + pdf = pdf[pdf["null_col"].isnull()] + + assert_df_equals(df.to_pandas(), pdf, sort_key="id") @pytest.mark.integration() -def test_sql_read_with_not_null_filter_pushdowns(test_db) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_not_null_filter_pushdowns(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.where(df["null_col"].not_null()) - pydict = df.to_pydict() - assert all(value is not None for value in pydict["null_col"]) + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + pdf = pdf[pdf["null_col"].notnull()] + + assert_df_equals(df.to_pandas(), pdf, sort_key="id") @pytest.mark.integration() -def test_sql_read_with_if_else_filter_pushdown(test_db) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_if_else_filter_pushdown(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50)) - pydict = df.to_pydict() - assert all(value < 50 or value > 150 for value in pydict["float_col"]) + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + pdf = pdf[(pdf["id"] > 100) & (pdf["float_col"] > 150) | (pdf["float_col"] < 50)] + + assert_df_equals(df.to_pandas(), pdf, sort_key="id") @pytest.mark.integration() -def test_sql_read_with_all_pushdowns(test_db) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_all_pushdowns(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.where(~(df["id"] < 1)) df = df.where(df["string_col"].is_in([f"row_{i}" for i in range(10)])) - df = df.select(df["id"], df["string_col"]) - df = df.limit(1) + df = df.select(df["id"], df["float_col"], df["string_col"]) + df = df.limit(5) df = df.collect() - assert df.column_names == ["id", "string_col"] - assert len(df) == 1 + assert len(df) == 5 + assert df.column_names == ["id", "float_col", "string_col"] + + pydict = df.to_pydict() + assert all(i >= 1 for i 
in pydict["id"]) + assert all(s in [f"row_{i}" for i in range(10)] for s in pydict["string_col"]) @pytest.mark.integration() @pytest.mark.parametrize("limit", [0, 1, 10, 100, 200]) -def test_sql_read_with_limit_pushdown(test_db, limit) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_limit_pushdown(test_db, limit, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.limit(limit) df = df.collect() @@ -141,8 +190,9 @@ def test_sql_read_with_limit_pushdown(test_db, limit) -> None: @pytest.mark.integration() -def test_sql_read_with_projection_pushdown(test_db, generated_data) -> None: - df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_projection_pushdown(test_db, generated_data, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) df = df.select(df["id"], df["string_col"]) df = df.collect() From 2ca97fda3ef014618657e7a1ee51c242e68ea782 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Tue, 5 Mar 2024 21:51:24 -0800 Subject: [PATCH 26/30] refactor --- daft/daft.pyi | 11 +++-- daft/expressions/expressions.py | 3 -- daft/io/_sql.py | 4 +- daft/runners/partitioning.py | 2 + daft/sql/sql_reader.py | 15 +++--- daft/sql/sql_scan.py | 48 ++++++++----------- daft/table/table_io.py | 1 + src/daft-dsl/src/python.rs | 4 -- src/daft-micropartition/src/micropartition.rs | 8 +++- src/daft-micropartition/src/python.rs | 13 +++-- src/daft-scan/src/file_format.rs | 30 ++++++++---- 11 files changed, 78 insertions(+), 61 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index adfd063f14..14c0dbddbf 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -223,14 +223,16 @@ class DatabaseSourceConfig: """ sql: str - left_bound: str | None - right_bound: str | None + partition_col: str | None + left_bound: PyExpr | None + right_bound: PyExpr | None def __init__( self, sql: str, - left_bound: str | None = None, - right_bound: str | None = None, + partition_col: str | None = None, + left_bound: PyExpr | None = None, + right_bound: PyExpr | None = None, ): ... class FileFormatConfig: @@ -912,7 +914,6 @@ class PyExpr: def not_null(self) -> PyExpr: ... def is_in(self, other: PyExpr) -> PyExpr: ... def name(self) -> str: ... - def to_sql(self) -> str | None: ... def to_field(self, schema: PySchema) -> PyField: ... def __repr__(self) -> str: ... def __hash__(self) -> int: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 09b3ad6fc6..89e4ca8728 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -473,9 +473,6 @@ def name(self) -> builtins.str: def __repr__(self) -> builtins.str: return repr(self._expr) - def _to_sql(self) -> builtins.str | None: - return self._expr.to_sql() - def _to_field(self, schema: Schema) -> Field: return Field._from_pyfield(self._expr.to_field(schema._schema)) diff --git a/daft/io/_sql.py b/daft/io/_sql.py index c813f5d9fe..b140e42d51 100644 --- a/daft/io/_sql.py +++ b/daft/io/_sql.py @@ -22,8 +22,8 @@ def read_sql( .. NOTE:: If partition_col is specified, this function will partition the query by the specified column. You may specify the number of partitions, or let Daft determine the number of partitions. 
- Daft will attempt to partition the query on the percentiles of the specified column, and will attempt to balance the number of rows in each partition. - If the database does not support the necessary SQL syntax to partition the query, Daft will partition the query via ranges between the min and max values of the specified column. + Daft will first calculate percentiles of the specified column. For example if num_partitions is 3, Daft will calculate the 33rd and 66th percentiles of the specified column, and use these values to partition the query. + If the database does not support the necessary SQL syntax to calculate percentiles, Daft will calculate the min and max of the specified column and partition the query into equal ranges. Args: sql (str): SQL query to execute diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index 87445fde89..7333564b9b 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -77,6 +77,7 @@ class TableReadSQLOptions: """Options for parsing SQL tables Args: + partition_col: Column to use for partitioning the table left_bound: Lower bound of the table to read right_bound: Upper bound of the table to read @@ -84,6 +85,7 @@ class TableReadSQLOptions: predicate_expression: Expression predicate to apply to the table """ + partition_col: str | None = None left_bound: str | None = None right_bound: str | None = None diff --git a/daft/sql/sql_reader.py b/daft/sql/sql_reader.py index 565ac91e2e..79fff805f7 100644 --- a/daft/sql/sql_reader.py +++ b/daft/sql/sql_reader.py @@ -15,6 +15,7 @@ def __init__( url: str, left_bound: str | None = None, right_bound: str | None = None, + partition_col: str | None = None, limit: int | None = None, projection: list[str] | None = None, predicate: str | None = None, @@ -24,6 +25,7 @@ def __init__( self.url = url self.left_bound = left_bound self.right_bound = right_bound + self.partition_col = partition_col self.limit = limit self.projection = projection self.predicate = predicate @@ -34,12 +36,13 @@ def read(self) -> pa.Table: def _construct_sql_query(self) -> str: base_query = f"SELECT * FROM ({self.sql}) AS subquery" - if self.left_bound is not None and self.right_bound is not None: - base_query = f"{base_query} WHERE {self.left_bound} AND {self.right_bound}" - elif self.left_bound is not None: - base_query = f"{base_query} WHERE {self.left_bound}" - elif self.right_bound is not None: - base_query = f"{base_query} WHERE {self.right_bound}" + if self.partition_col is not None: + if self.left_bound is not None and self.right_bound is not None: + base_query = f"{base_query} WHERE {self.partition_col} > {self.left_bound} AND {self.partition_col} <= {self.right_bound}" + elif self.left_bound is not None: + base_query = f"{base_query} WHERE {self.partition_col} > {self.left_bound}" + elif self.right_bound is not None: + base_query = f"{base_query} WHERE {self.partition_col} <= {self.right_bound}" clauses = [] if self.projection is not None: diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index f95e2c4488..a40a64e34c 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -2,7 +2,6 @@ import logging import math -import warnings from collections.abc import Iterator from typing import Any @@ -45,7 +44,7 @@ def _attempt_schema_read(self) -> Schema: schema = Schema.from_pyarrow_schema(pa_table.schema) return schema except Exception: - # If both attempts fail, read without limit and offset + # If limit fails, try to read the entire table pa_table = SQLReader(self.sql, self.url).read() 
schema = Schema.from_pyarrow_schema(pa_table.schema) return schema @@ -110,29 +109,14 @@ def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[ ).read() min_val = pa_table.column(0)[0].as_py() max_val = pa_table.column(1)[0].as_py() - return [min_val + (max_val - min_val) * i / num_scan_tasks for i in range(1, num_scan_tasks)], "min_max" + range_size = (max_val - min_val) / num_scan_tasks + return [min_val + range_size * i for i in range(1, num_scan_tasks)], "min_max" except Exception: raise ValueError( f"Failed to get partition bounds from {self._partition_col}. Please ensure that the column exists, and is numeric or temporal." ) - def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]: - file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) - return iter( - [ - ScanTask.sql_scan_task( - url=self.url, - file_format=file_format_config, - schema=self._schema._schema, - num_rows=total_rows, - storage_config=self.storage_config, - size_bytes=math.ceil(total_size), - pushdowns=pushdowns, - ) - ] - ) - def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: total_rows = self._get_num_rows() estimate_row_size_bytes = self.schema().estimate_row_size_bytes() @@ -144,24 +128,32 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: ) if num_scan_tasks == 1 or self._partition_col is None: - return self._single_scan_task(pushdowns, total_rows, total_size) + file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) + return iter( + [ + ScanTask.sql_scan_task( + url=self.url, + file_format=file_format_config, + schema=self._schema._schema, + num_rows=total_rows, + storage_config=self.storage_config, + size_bytes=math.ceil(total_size), + pushdowns=pushdowns, + ) + ] + ) partition_bounds, strategy = self._get_partition_bounds_and_strategy(num_scan_tasks) - partition_bounds = [lit(bound)._to_sql() for bound in partition_bounds] - - if any(bound is None for bound in partition_bounds): - warnings.warn("Unable to partion the data using the specified column. Falling back to a single scan task.") - return self._single_scan_task(pushdowns, total_rows, total_size) - size_bytes = None if strategy == "min_max" else math.ceil(total_size / num_scan_tasks) scan_tasks = [] for i in range(num_scan_tasks): - left_bound = None if i == 0 else f"{self._partition_col} > {partition_bounds[i - 1]}" - right_bound = None if i == num_scan_tasks - 1 else f"{self._partition_col} <= {partition_bounds[i]}" + left_bound = None if i == 0 else lit(partition_bounds[i - 1])._expr + right_bound = None if i == num_scan_tasks - 1 else lit(partition_bounds[i])._expr file_format_config = FileFormatConfig.from_database_config( DatabaseSourceConfig( self.sql, + partition_col=self._partition_col, left_bound=left_bound, right_bound=right_bound, ) diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 91a63f8738..16f552b4ad 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -233,6 +233,7 @@ def read_sql( pa_table = SQLReader( sql, url, + partition_col=sql_options.partition_col, left_bound=sql_options.left_bound, right_bound=sql_options.right_bound, # TODO(Colin): Enable pushdowns diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index d23e3def34..a4e8c7091a 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -324,10 +324,6 @@ impl PyExpr { Ok(self.expr.name()?) 
} - pub fn to_sql(&self) -> PyResult> { - Ok(self.expr.to_sql()) - } - pub fn to_field(&self, schema: &PySchema) -> PyResult { Ok(self.expr.to_field(&schema.schema)?.into()) } diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 24d30d8c2e..f86c977e5e 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -310,6 +310,7 @@ fn materialize_scan_task( })?, FileFormatConfig::Database(DatabaseSourceConfig { sql, + partition_col, left_bound, right_bound, }) => { @@ -323,14 +324,17 @@ fn materialize_scan_task( .filters .as_ref() .map(|p| (*p.as_ref()).clone().into()); + let left_bound = left_bound.as_ref().and_then(|lb| lb.to_sql()); + let right_bound = right_bound.as_ref().and_then(|rb| rb.to_sql()); Python::with_gil(|py| { urls.map(|url| { crate::python::read_sql_into_py_table( py, sql, url, - left_bound.as_deref(), - right_bound.as_deref(), + partition_col.clone(), + left_bound.clone(), + right_bound.clone(), predicate_sql.clone(), predicate_expr.clone(), scan_task.schema.clone().into(), diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index bc7325e9b4..17181d73ca 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -792,8 +792,9 @@ pub(crate) fn read_sql_into_py_table( py: Python, sql: &str, url: &str, - left_bound: Option<&str>, - right_bound: Option<&str>, + partition_col: Option, + left_bound: Option, + right_bound: Option, predicate_sql: Option, predicate_expr: Option, schema: PySchema, @@ -817,7 +818,13 @@ pub(crate) fn read_sql_into_py_table( let sql_options = py .import(pyo3::intern!(py, "daft.runners.partitioning"))? .getattr(pyo3::intern!(py, "TableReadSQLOptions"))? - .call1((left_bound, right_bound, predicate_sql, predicate_pyexpr))?; + .call1(( + partition_col, + left_bound, + right_bound, + predicate_sql, + predicate_pyexpr, + ))?; let read_options = py .import(pyo3::intern!(py, "daft.runners.partitioning"))? .getattr(pyo3::intern!(py, "TableReadOptions"))? diff --git a/src/daft-scan/src/file_format.rs b/src/daft-scan/src/file_format.rs index 8ee1ca158c..1fe15ab972 100644 --- a/src/daft-scan/src/file_format.rs +++ b/src/daft-scan/src/file_format.rs @@ -1,11 +1,13 @@ use common_error::{DaftError, DaftResult}; use daft_core::{datatypes::TimeUnit, impl_bincode_py_state_serialization}; +use daft_dsl::ExprRef; use serde::{Deserialize, Serialize}; use std::{str::FromStr, sync::Arc}; #[cfg(feature = "python")] use { daft_core::python::datatype::PyTimeUnit, + daft_dsl::python::PyExpr, pyo3::{ pyclass, pyclass::CompareOp, pymethods, types::PyBytes, IntoPy, PyObject, PyResult, PyTypeInfo, Python, ToPyObject, @@ -260,21 +262,24 @@ impl_bincode_py_state_serialization!(JsonSourceConfig); /// Configuration for a Database data source. 
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] -#[cfg_attr(feature = "python", pyclass(module = "daft.daft", get_all))] +#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] pub struct DatabaseSourceConfig { pub sql: String, - pub left_bound: Option, - pub right_bound: Option, + pub partition_col: Option, + pub left_bound: Option, + pub right_bound: Option, } impl DatabaseSourceConfig { pub fn new_internal( sql: String, - left_bound: Option, - right_bound: Option, + partition_col: Option, + left_bound: Option, + right_bound: Option, ) -> Self { Self { sql, + partition_col, left_bound, right_bound, } @@ -283,6 +288,9 @@ impl DatabaseSourceConfig { pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("SQL = {}", self.sql)); + if let Some(partition_col) = &self.partition_col { + res.push(format!("Partition column = {}", partition_col)); + } if let Some(left_bound) = &self.left_bound { res.push(format!("Left bound = {}", left_bound)); } @@ -298,11 +306,17 @@ impl DatabaseSourceConfig { impl DatabaseSourceConfig { /// Create a config for a Database data source. #[new] - fn new(sql: &str, left_bound: Option<&str>, right_bound: Option<&str>) -> Self { + fn new( + sql: &str, + partition_col: Option<&str>, + left_bound: Option, + right_bound: Option, + ) -> Self { Self::new_internal( sql.to_string(), - left_bound.map(str::to_string), - right_bound.map(str::to_string), + partition_col.map(|s| s.to_string()), + left_bound.map(|e| e.expr.into()), + right_bound.map(|e| e.expr.into()), ) } } From 7ac413174baf8ed8c1b036105ca3a9fc270ce9d8 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Tue, 5 Mar 2024 21:59:22 -0800 Subject: [PATCH 27/30] refactor some string args --- src/daft-micropartition/src/micropartition.rs | 8 ++++---- src/daft-micropartition/src/python.rs | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index f86c977e5e..5b30f003ad 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -332,10 +332,10 @@ fn materialize_scan_task( py, sql, url, - partition_col.clone(), - left_bound.clone(), - right_bound.clone(), - predicate_sql.clone(), + partition_col.as_deref(), + left_bound.as_deref(), + right_bound.as_deref(), + predicate_sql.as_deref(), predicate_expr.clone(), scan_task.schema.clone().into(), scan_task diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 17181d73ca..9222b469fa 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -792,10 +792,10 @@ pub(crate) fn read_sql_into_py_table( py: Python, sql: &str, url: &str, - partition_col: Option, - left_bound: Option, - right_bound: Option, - predicate_sql: Option, + partition_col: Option<&str>, + left_bound: Option<&str>, + right_bound: Option<&str>, + predicate_sql: Option<&str>, predicate_expr: Option, schema: PySchema, include_columns: Option>, From 5be0cd28ebc8d5791809bc77d58b2b045675fcbb Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 6 Mar 2024 14:23:50 -0800 Subject: [PATCH 28/30] add datetime support --- daft/daft.pyi | 12 +--- daft/datatype.py | 2 + daft/expressions/expressions.py | 3 + daft/runners/partitioning.py | 8 --- daft/sql/sql_reader.py | 17 +----- daft/sql/sql_scan.py | 57 +++++++++++-------- daft/table/table_io.py | 3 - src/daft-dsl/src/lit.rs | 19 ++++--- src/daft-dsl/src/python.rs | 4 ++ 
src/daft-micropartition/src/micropartition.rs | 12 +--- src/daft-micropartition/src/python.rs | 11 +--- src/daft-scan/src/file_format.rs | 42 ++------------ tests/integration/sql/conftest.py | 6 +- tests/integration/sql/test_sql.py | 20 ++++--- tests/io/test_csv_roundtrip.py | 6 ++ tests/io/test_parquet_roundtrip.py | 1 + tests/table/test_from_py.py | 2 + 17 files changed, 87 insertions(+), 138 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index 14c0dbddbf..ff20ef1e2b 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -223,17 +223,8 @@ class DatabaseSourceConfig: """ sql: str - partition_col: str | None - left_bound: PyExpr | None - right_bound: PyExpr | None - def __init__( - self, - sql: str, - partition_col: str | None = None, - left_bound: PyExpr | None = None, - right_bound: PyExpr | None = None, - ): ... + def __init__(self, sql: str): ... class FileFormatConfig: """ @@ -915,6 +906,7 @@ class PyExpr: def is_in(self, other: PyExpr) -> PyExpr: ... def name(self) -> str: ... def to_field(self, schema: PySchema) -> PyField: ... + def to_sql(self) -> str | None: ... def __repr__(self) -> str: ... def __hash__(self) -> int: ... def __reduce__(self) -> tuple: ... diff --git a/daft/datatype.py b/daft/datatype.py index 81114698c4..6d7e25a6ce 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -375,6 +375,8 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType: return cls.decimal128(arrow_type.precision, arrow_type.scale) elif pa.types.is_date32(arrow_type): return cls.date() + elif pa.types.is_date64(arrow_type): + return cls.timestamp(TimeUnit.ms()) elif pa.types.is_time64(arrow_type): timeunit = TimeUnit.from_str(pa.type_for_alias(str(arrow_type)).unit) return cls.time(timeunit) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 89e4ca8728..09b3ad6fc6 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -473,6 +473,9 @@ def name(self) -> builtins.str: def __repr__(self) -> builtins.str: return repr(self._expr) + def _to_sql(self) -> builtins.str | None: + return self._expr.to_sql() + def _to_field(self, schema: Schema) -> Field: return Field._from_pyfield(self._expr.to_field(schema._schema)) diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index 7333564b9b..2c42dae7ca 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -77,18 +77,10 @@ class TableReadSQLOptions: """Options for parsing SQL tables Args: - partition_col: Column to use for partitioning the table - left_bound: Lower bound of the table to read - right_bound: Upper bound of the table to read - predicate_sql: SQL predicate to apply to the table predicate_expression: Expression predicate to apply to the table """ - partition_col: str | None = None - left_bound: str | None = None - right_bound: str | None = None - predicate_sql: str | None = None predicate_expression: Expression | None = None diff --git a/daft/sql/sql_reader.py b/daft/sql/sql_reader.py index 79fff805f7..111af64622 100644 --- a/daft/sql/sql_reader.py +++ b/daft/sql/sql_reader.py @@ -13,9 +13,6 @@ def __init__( self, sql: str, url: str, - left_bound: str | None = None, - right_bound: str | None = None, - partition_col: str | None = None, limit: int | None = None, projection: list[str] | None = None, predicate: str | None = None, @@ -23,9 +20,6 @@ def __init__( self.sql = sql self.url = url - self.left_bound = left_bound - self.right_bound = right_bound - self.partition_col = partition_col self.limit = limit self.projection = 
projection self.predicate = predicate @@ -35,22 +29,13 @@ def read(self) -> pa.Table: return self._execute_sql_query(sql) def _construct_sql_query(self) -> str: - base_query = f"SELECT * FROM ({self.sql}) AS subquery" - if self.partition_col is not None: - if self.left_bound is not None and self.right_bound is not None: - base_query = f"{base_query} WHERE {self.partition_col} > {self.left_bound} AND {self.partition_col} <= {self.right_bound}" - elif self.left_bound is not None: - base_query = f"{base_query} WHERE {self.partition_col} > {self.left_bound}" - elif self.right_bound is not None: - base_query = f"{base_query} WHERE {self.partition_col} <= {self.right_bound}" - clauses = [] if self.projection is not None: clauses.append(f"SELECT {', '.join(self.projection)}") else: clauses.append("SELECT *") - clauses.append(f"FROM ({base_query}) AS subquery") + clauses.append(f"FROM ({self.sql}) AS subquery") if self.predicate is not None: clauses.append(f"WHERE {self.predicate}") diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index a40a64e34c..7ac4502419 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -2,6 +2,7 @@ import logging import math +import warnings from collections.abc import Iterator from typing import Any @@ -117,6 +118,22 @@ def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[ f"Failed to get partition bounds from {self._partition_col}. Please ensure that the column exists, and is numeric or temporal." ) + def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]: + file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) + return iter( + [ + ScanTask.sql_scan_task( + url=self.url, + file_format=file_format_config, + schema=self._schema._schema, + num_rows=total_rows, + storage_config=self.storage_config, + size_bytes=math.ceil(total_size), + pushdowns=pushdowns, + ) + ] + ) + def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: total_rows = self._get_num_rows() estimate_row_size_bytes = self.schema().estimate_row_size_bytes() @@ -128,36 +145,26 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: ) if num_scan_tasks == 1 or self._partition_col is None: - file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) - return iter( - [ - ScanTask.sql_scan_task( - url=self.url, - file_format=file_format_config, - schema=self._schema._schema, - num_rows=total_rows, - storage_config=self.storage_config, - size_bytes=math.ceil(total_size), - pushdowns=pushdowns, - ) - ] - ) + return self._single_scan_task(pushdowns, total_rows, total_size) partition_bounds, strategy = self._get_partition_bounds_and_strategy(num_scan_tasks) + partition_bounds = [lit(bound)._to_sql() for bound in partition_bounds] + + if any(bound is None for bound in partition_bounds): + warnings.warn("Unable to partion the data using the specified column. 
Falling back to a single scan task.") + return self._single_scan_task(pushdowns, total_rows, total_size) + size_bytes = None if strategy == "min_max" else math.ceil(total_size / num_scan_tasks) scan_tasks = [] for i in range(num_scan_tasks): - left_bound = None if i == 0 else lit(partition_bounds[i - 1])._expr - right_bound = None if i == num_scan_tasks - 1 else lit(partition_bounds[i])._expr - - file_format_config = FileFormatConfig.from_database_config( - DatabaseSourceConfig( - self.sql, - partition_col=self._partition_col, - left_bound=left_bound, - right_bound=right_bound, - ) - ) + if i == 0: + sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} <= {partition_bounds[i]}" + elif i == num_scan_tasks - 1: + sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} > {partition_bounds[i - 1]}" + else: + sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} > {partition_bounds[i - 1]} AND {self._partition_col} <= {partition_bounds[i]}" + + file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(sql=sql)) scan_tasks.append( ScanTask.sql_scan_task( diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 16f552b4ad..b492fae967 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -233,9 +233,6 @@ def read_sql( pa_table = SQLReader( sql, url, - partition_col=sql_options.partition_col, - left_bound=sql_options.left_bound, - right_bound=sql_options.right_bound, # TODO(Colin): Enable pushdowns projection=None, predicate=None, diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index b8a1149f5f..53c7e35f69 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -220,15 +220,20 @@ impl LiteralValue { self.to_string().into() } Utf8(val) => format!("'{}'", val).into(), - Binary(val) => format!("x'{}'", val.len()).into(), Date(val) => format!("DATE '{}'", display_date32(*val)).into(), - Timestamp(val, tu, tz) => format!( - "TIMESTAMP '{}'", - display_timestamp(*val, tu, tz).replace('T', " ") - ) - .into(), + Timestamp(val, tu, tz) => { + if tz.is_some() { + // Different databases have different ways of handling timezones, so there's no reliable way to convert this to SQL. + return None; + } + format!( + "TIMESTAMP '{}'", + display_timestamp(*val, tu, tz).replace('T', " ") + ) + .into() + } // TODO(Colin): Implement the rest of the types in future work for SQL pushdowns. - Decimal(..) | Series(..) | Time(..) => None, + Decimal(..) | Series(..) | Time(..) | Binary(..) => None, #[cfg(feature = "python")] Python(..) => None, } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index a4e8c7091a..d23e3def34 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -324,6 +324,10 @@ impl PyExpr { Ok(self.expr.name()?) 
} + pub fn to_sql(&self) -> PyResult> { + Ok(self.expr.to_sql()) + } + pub fn to_field(&self, schema: &PySchema) -> PyResult { Ok(self.expr.to_field(&schema.schema)?.into()) } diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 5b30f003ad..090c095053 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -308,12 +308,7 @@ fn materialize_scan_task( }) .collect::>>() })?, - FileFormatConfig::Database(DatabaseSourceConfig { - sql, - partition_col, - left_bound, - right_bound, - }) => { + FileFormatConfig::Database(DatabaseSourceConfig { sql }) => { let predicate_sql = scan_task .pushdowns .filters @@ -324,17 +319,12 @@ fn materialize_scan_task( .filters .as_ref() .map(|p| (*p.as_ref()).clone().into()); - let left_bound = left_bound.as_ref().and_then(|lb| lb.to_sql()); - let right_bound = right_bound.as_ref().and_then(|rb| rb.to_sql()); Python::with_gil(|py| { urls.map(|url| { crate::python::read_sql_into_py_table( py, sql, url, - partition_col.as_deref(), - left_bound.as_deref(), - right_bound.as_deref(), predicate_sql.as_deref(), predicate_expr.clone(), scan_task.schema.clone().into(), diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 9222b469fa..0137078e95 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -792,9 +792,6 @@ pub(crate) fn read_sql_into_py_table( py: Python, sql: &str, url: &str, - partition_col: Option<&str>, - left_bound: Option<&str>, - right_bound: Option<&str>, predicate_sql: Option<&str>, predicate_expr: Option, schema: PySchema, @@ -818,13 +815,7 @@ pub(crate) fn read_sql_into_py_table( let sql_options = py .import(pyo3::intern!(py, "daft.runners.partitioning"))? .getattr(pyo3::intern!(py, "TableReadSQLOptions"))? - .call1(( - partition_col, - left_bound, - right_bound, - predicate_sql, - predicate_pyexpr, - ))?; + .call1((predicate_sql, predicate_pyexpr))?; let read_options = py .import(pyo3::intern!(py, "daft.runners.partitioning"))? .getattr(pyo3::intern!(py, "TableReadOptions"))? 
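The per-partition SQL construction introduced above replaces the old partition_col / left_bound / right_bound plumbing that these hunks remove: the N-1 computed bounds become N non-overlapping, half-open ranges baked directly into each scan task's query, so every row lands in exactly one task. Below is a minimal standalone sketch of that construction; build_partition_queries is a hypothetical helper (not part of the Daft API), and the bounds are assumed to already be rendered as SQL literal strings (as done via lit(bound)._to_sql() in to_scan_tasks).

# Minimal sketch of the per-partition query construction mirrored from
# SQLScanOperator.to_scan_tasks. `build_partition_queries` is a hypothetical
# helper; `bounds` are assumed to already be SQL literal strings.
from typing import List


def build_partition_queries(sql: str, partition_col: str, bounds: List[str]) -> List[str]:
    if not bounds:
        # No bounds: fall back to a single, unpartitioned query.
        return [f"SELECT * FROM ({sql}) AS subquery"]
    num_tasks = len(bounds) + 1  # N - 1 bounds define N half-open ranges
    queries = []
    for i in range(num_tasks):
        if i == 0:
            # First task: everything up to and including the first bound.
            where = f"{partition_col} <= {bounds[0]}"
        elif i == num_tasks - 1:
            # Last task: everything strictly above the last bound.
            where = f"{partition_col} > {bounds[-1]}"
        else:
            # Middle tasks: the half-open range (bounds[i-1], bounds[i]].
            where = f"{partition_col} > {bounds[i - 1]} AND {partition_col} <= {bounds[i]}"
        queries.append(f"SELECT * FROM ({sql}) AS subquery WHERE {where}")
    return queries


# Example: 3 scan tasks over an integer id column, using bounds [100, 200].
for query in build_partition_queries("SELECT * FROM my_table", "id", ["100", "200"]):
    print(query)

The asymmetric first and last branches match the diff above: only the inner boundaries come from the percentile or min/max bounds, while the outermost tasks are left open-ended so rows below the first bound or above the last bound are still read.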
diff --git a/src/daft-scan/src/file_format.rs b/src/daft-scan/src/file_format.rs index 1fe15ab972..d680f89250 100644 --- a/src/daft-scan/src/file_format.rs +++ b/src/daft-scan/src/file_format.rs @@ -1,13 +1,11 @@ use common_error::{DaftError, DaftResult}; use daft_core::{datatypes::TimeUnit, impl_bincode_py_state_serialization}; -use daft_dsl::ExprRef; use serde::{Deserialize, Serialize}; use std::{str::FromStr, sync::Arc}; #[cfg(feature = "python")] use { daft_core::python::datatype::PyTimeUnit, - daft_dsl::python::PyExpr, pyo3::{ pyclass, pyclass::CompareOp, pymethods, types::PyBytes, IntoPy, PyObject, PyResult, PyTypeInfo, Python, ToPyObject, @@ -265,38 +263,16 @@ impl_bincode_py_state_serialization!(JsonSourceConfig); #[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] pub struct DatabaseSourceConfig { pub sql: String, - pub partition_col: Option, - pub left_bound: Option, - pub right_bound: Option, } impl DatabaseSourceConfig { - pub fn new_internal( - sql: String, - partition_col: Option, - left_bound: Option, - right_bound: Option, - ) -> Self { - Self { - sql, - partition_col, - left_bound, - right_bound, - } + pub fn new_internal(sql: String) -> Self { + Self { sql } } pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("SQL = {}", self.sql)); - if let Some(partition_col) = &self.partition_col { - res.push(format!("Partition column = {}", partition_col)); - } - if let Some(left_bound) = &self.left_bound { - res.push(format!("Left bound = {}", left_bound)); - } - if let Some(right_bound) = &self.right_bound { - res.push(format!("Right bound = {}", right_bound)); - } res } } @@ -306,18 +282,8 @@ impl DatabaseSourceConfig { impl DatabaseSourceConfig { /// Create a config for a Database data source. #[new] - fn new( - sql: &str, - partition_col: Option<&str>, - left_bound: Option, - right_bound: Option, - ) -> Self { - Self::new_internal( - sql.to_string(), - partition_col.map(|s| s.to_string()), - left_bound.map(|e| e.expr.into()), - right_bound.map(|e| e.expr.into()), - ) + fn new(sql: &str) -> Self { + Self::new_internal(sql.to_string()) } } diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index c6285faae6..46dc3e36ed 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -11,6 +11,7 @@ Boolean, Column, Date, + DateTime, Engine, Float, Integer, @@ -39,9 +40,7 @@ def generated_data(request: pytest.FixtureRequest) -> pd.DataFrame: "string_col": [f"row_{i}" for i in range(num_rows)], "bool_col": [True for _ in range(num_rows // 2)] + [False for _ in range(num_rows // 2)], "date_col": [date(2021, 1, 1) + timedelta(days=i) for i in range(num_rows)], - # TODO(Colin): ConnectorX parses datetime as pyarrow date64 type, which we currently cast to Python, causing our assertions to fail. - # One possible solution is to cast date64 into Timestamp("ms") in our from_arrow code. 
- # "date_time_col": [datetime(2020, 1, 1, 10, 0, 0) + timedelta(hours=i) for i in range(num_rows)], + "date_time_col": [datetime(2020, 1, 1, 10, 0, 0) + timedelta(hours=i) for i in range(num_rows)], "time_col": [ (datetime.combine(datetime.today(), time(0, 0)) + timedelta(minutes=x)).time() for x in range(200) ], @@ -79,6 +78,7 @@ def create_and_populate(engine: Engine, data: pd.DataFrame) -> None: Column("string_col", String(50)), Column("bool_col", Boolean), Column("date_col", Date), + Column("date_time_col", DateTime), Column("null_col", String(50)), Column("non_uniformly_distributed_col", Integer), ) diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index 2b09c66366..b0ac17b6a3 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -36,7 +36,7 @@ def test_sql_partitioned_read(test_db, num_partitions) -> None: @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2, 3, 4]) -@pytest.mark.parametrize("partition_col", ["id", "float_col", "date_col"]) +@pytest.mark.parametrize("partition_col", ["id", "float_col", "date_col", "date_time_col"]) def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col( test_db, num_partitions, partition_col ) -> None: @@ -89,12 +89,7 @@ def test_sql_read_with_partition_num_without_partition_col(test_db) -> None: ("string_col", "row_100"), ("bool_col", True), ("date_col", datetime.date(2021, 1, 1)), - # TODO(Colin) - ConnectorX parses datetime as pyarrow date64 type, which we currently cast to Python, causing our assertions to fail. - # One possible solution is to cast date64 into Timestamp("ms") in our from_arrow code. - # ("date_time_col", datetime.datetime(2020, 1, 1, 10, 0, 0)), - # TODO(Colin) - Reading time from Postgres is parsed as Time(Nanoseconds), while from MySQL it is parsed as Duration(Microseconds) - # Need to fix our time comparison code to handle this. 
- # ("time_col", datetime.time(10, 0, 0)), + ("date_time_col", datetime.datetime(2020, 1, 1, 10, 0, 0)), ], ) @pytest.mark.parametrize("num_partitions", [1, 2]) @@ -160,6 +155,17 @@ def test_sql_read_with_if_else_filter_pushdown(test_db, num_partitions) -> None: assert_df_equals(df.to_pandas(), pdf, sort_key="id") +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_is_in_filter_pushdown(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) + df = df.where(df["id"].is_in([1, 2, 3])) + + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + pdf = pdf[pdf["id"].isin([1, 2, 3])] + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2]) def test_sql_read_with_all_pushdowns(test_db, num_partitions) -> None: diff --git a/tests/io/test_csv_roundtrip.py b/tests/io/test_csv_roundtrip.py index 2043e30a1b..dd288e6806 100644 --- a/tests/io/test_csv_roundtrip.py +++ b/tests/io/test_csv_roundtrip.py @@ -50,6 +50,12 @@ # NOTE: Seems like the inferred type is seconds because it's written with seconds resolution DataType.timestamp(TimeUnit.s()), ), + ( + [datetime.date(1994, 1, 1), datetime.date(1995, 1, 1), None], + pa.date64(), + DataType.timestamp(TimeUnit.ms()), + DataType.timestamp(TimeUnit.s()), + ), ( [datetime.timedelta(days=1), datetime.timedelta(days=2), None], pa.duration("ms"), diff --git a/tests/io/test_parquet_roundtrip.py b/tests/io/test_parquet_roundtrip.py index 81569f1be2..1b8f8df619 100644 --- a/tests/io/test_parquet_roundtrip.py +++ b/tests/io/test_parquet_roundtrip.py @@ -42,6 +42,7 @@ pa.timestamp("ms"), DataType.timestamp(TimeUnit.ms()), ), + ([datetime.date(1994, 1, 1), datetime.date(1995, 1, 1), None], pa.date64(), DataType.timestamp(TimeUnit.ms())), ( [datetime.timedelta(days=1), datetime.timedelta(days=2), None], pa.duration("ms"), diff --git a/tests/table/test_from_py.py b/tests/table/test_from_py.py index 95012309be..bc8d2be888 100644 --- a/tests/table/test_from_py.py +++ b/tests/table/test_from_py.py @@ -94,6 +94,7 @@ "binary": pa.array(PYTHON_TYPE_ARRAYS["binary"], pa.binary()), "boolean": pa.array(PYTHON_TYPE_ARRAYS["bool"], pa.bool_()), "date32": pa.array(PYTHON_TYPE_ARRAYS["date"], pa.date32()), + "date64": pa.array(PYTHON_TYPE_ARRAYS["date"], pa.date64()), "time64_microseconds": pa.array(PYTHON_TYPE_ARRAYS["time"], pa.time64("us")), "time64_nanoseconds": pa.array(PYTHON_TYPE_ARRAYS["time"], pa.time64("ns")), "list": pa.array(PYTHON_TYPE_ARRAYS["list"], pa.list_(pa.int64())), @@ -149,6 +150,7 @@ "binary": pa.large_binary(), "boolean": pa.bool_(), "date32": pa.date32(), + "date64": pa.timestamp("ms"), "time64_microseconds": pa.time64("us"), "time64_nanoseconds": pa.time64("ns"), "list": pa.large_list(pa.int64()), From 0a85439ac47d9e9ae41015d6da79c1d19198bd08 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 6 Mar 2024 14:45:08 -0800 Subject: [PATCH 29/30] comment about timestamp --- src/daft-dsl/src/lit.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 53c7e35f69..bb1f2c60af 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -222,10 +222,14 @@ impl LiteralValue { Utf8(val) => format!("'{}'", val).into(), Date(val) => format!("DATE '{}'", display_date32(*val)).into(), Timestamp(val, tu, tz) => { + // Different databases have different 
ways of handling timezones, so there's no reliable way to convert this to SQL. if tz.is_some() { - // Different databases have different ways of handling timezones, so there's no reliable way to convert this to SQL. return None; } + // Note: Our display_timestamp function returns a string in the ISO 8601 format "YYYY-MM-DDTHH:MM:SS.fffff". + // However, the ANSI SQL standard replaces the 'T' with a space. See: https://docs.actian.com/ingres/10s/index.html#page/SQLRef/Summary_of_ANSI_Date_2fTime_Data_Types.htm + // While many databases support the 'T', some such as Trino, do not. So we replace the 'T' with a space here. + // We also don't use the i64 unix timestamp directly, because different databases have different functions to convert unix timestamps to timestamps, e.g. Trino uses from_unixtime, while PostgreSQL uses to_timestamp. format!( "TIMESTAMP '{}'", display_timestamp(*val, tu, tz).replace('T', " ") From bd65b050a20e75913c4f643cbab47a7d994c59cc Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 13 Mar 2024 14:44:32 -0700 Subject: [PATCH 30/30] refactor and add limit pushdown --- daft/runners/partitioning.py | 2 - daft/sql/sql_reader.py | 17 +- daft/sql/sql_scan.py | 190 ++++++++++-------- daft/table/table_io.py | 28 ++- .../basic_concepts/read-and-write.rst | 17 ++ src/daft-dsl/src/expr.rs | 27 +-- src/daft-dsl/src/lit.rs | 23 ++- src/daft-micropartition/src/micropartition.rs | 6 - src/daft-micropartition/src/python.rs | 4 +- tests/integration/sql/test_sql.py | 6 + 10 files changed, 189 insertions(+), 131 deletions(-) diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index 2c42dae7ca..35f3f0fb6d 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -77,11 +77,9 @@ class TableReadSQLOptions: """Options for parsing SQL tables Args: - predicate_sql: SQL predicate to apply to the table predicate_expression: Expression predicate to apply to the table """ - predicate_sql: str | None = None predicate_expression: Expression | None = None diff --git a/daft/sql/sql_reader.py b/daft/sql/sql_reader.py index 111af64622..367fdcaeaf 100644 --- a/daft/sql/sql_reader.py +++ b/daft/sql/sql_reader.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import warnings from urllib.parse import urlparse import pyarrow as pa @@ -25,10 +26,15 @@ def __init__( self.predicate = predicate def read(self) -> pa.Table: - sql = self._construct_sql_query() - return self._execute_sql_query(sql) - - def _construct_sql_query(self) -> str: + try: + sql = self._construct_sql_query(apply_limit=True) + return self._execute_sql_query(sql) + except RuntimeError: + warnings.warn("Failed to execute the query with a limit, attempting to read the entire table.") + sql = self._construct_sql_query(apply_limit=False) + return self._execute_sql_query(sql) + + def _construct_sql_query(self, apply_limit: bool) -> str: clauses = [] if self.projection is not None: clauses.append(f"SELECT {', '.join(self.projection)}") @@ -40,7 +46,7 @@ def _construct_sql_query(self) -> str: if self.predicate is not None: clauses.append(f"WHERE {self.predicate}") - if self.limit is not None: + if self.limit is not None and apply_limit is True: clauses.append(f"LIMIT {self.limit}") return "\n".join(clauses) @@ -88,6 +94,7 @@ def _execute_sql_query_with_sqlalchemy(self, sql: str) -> pa.Table: rows = result.fetchall() pydict = {column_name: [row[i] for row in rows] for i, column_name in enumerate(result.keys())} + # TODO: Use type codes from cursor description to create pyarrow schema 
return pa.Table.from_pydict(pydict) except Exception as e: raise RuntimeError(f"Failed to execute sql: {sql} with url: {self.url}, error: {e}") from e diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 7ac4502419..93e1b0ab56 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -4,6 +4,7 @@ import math import warnings from collections.abc import Iterator +from enum import Enum, auto from typing import Any from daft.context import get_context @@ -22,6 +23,11 @@ logger = logging.getLogger(__name__) +class PartitionBoundStrategy(Enum): + PERCENTILE = auto() + MIN_MAX = auto() + + class SQLScanOperator(ScanOperator): def __init__( self, @@ -39,25 +45,6 @@ def __init__( self._num_partitions = num_partitions self._schema = self._attempt_schema_read() - def _attempt_schema_read(self) -> Schema: - try: - pa_table = SQLReader(self.sql, self.url, limit=1).read() - schema = Schema.from_pyarrow_schema(pa_table.schema) - return schema - except Exception: - # If limit fails, try to read the entire table - pa_table = SQLReader(self.sql, self.url).read() - schema = Schema.from_pyarrow_schema(pa_table.schema) - return schema - - def _get_num_rows(self) -> int: - pa_table = SQLReader( - self.sql, - self.url, - projection=["COUNT(*)"], - ).read() - return pa_table.column(0)[0].as_py() - def schema(self) -> Schema: return self._schema @@ -73,67 +60,6 @@ def multiline_display(self) -> list[str]: f"Schema = {self._schema}", ] - def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], str]: - if self._partition_col is None: - raise ValueError("Failed to get partition bounds: partition_col must be specified to partition the data.") - - if not ( - self._schema[self._partition_col].dtype._is_temporal_type() - or self._schema[self._partition_col].dtype._is_numeric_type() - ): - raise ValueError( - f"Failed to get partition bounds: {self._partition_col} is not a numeric or temporal type, and cannot be used for partitioning." - ) - - try: - # try to get percentiles using percentile_cont - percentiles = [i / num_scan_tasks for i in range(1, num_scan_tasks)] - pa_table = SQLReader( - self.sql, - self.url, - projection=[ - f"percentile_cont({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}" - for i, percentile in enumerate(percentiles) - ], - ).read() - bounds = [pa_table.column(i)[0].as_py() for i in range(num_scan_tasks - 1)] - return bounds, "percentile" - - except Exception as e: - # if the above fails, use the min and max of the partition column - logger.info("Failed to get percentiles using percentile_cont, falling back to min and max. Error: %s", e) - try: - pa_table = SQLReader( - self.sql, - self.url, - projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"], - ).read() - min_val = pa_table.column(0)[0].as_py() - max_val = pa_table.column(1)[0].as_py() - range_size = (max_val - min_val) / num_scan_tasks - return [min_val + range_size * i for i in range(1, num_scan_tasks)], "min_max" - - except Exception: - raise ValueError( - f"Failed to get partition bounds from {self._partition_col}. Please ensure that the column exists, and is numeric or temporal." 
- ) - - def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]: - file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) - return iter( - [ - ScanTask.sql_scan_task( - url=self.url, - file_format=file_format_config, - schema=self._schema._schema, - num_rows=total_rows, - storage_config=self.storage_config, - size_bytes=math.ceil(total_size), - pushdowns=pushdowns, - ) - ] - ) - def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: total_rows = self._get_num_rows() estimate_row_size_bytes = self.schema().estimate_row_size_bytes() @@ -154,7 +80,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: warnings.warn("Unable to partion the data using the specified column. Falling back to a single scan task.") return self._single_scan_task(pushdowns, total_rows, total_size) - size_bytes = None if strategy == "min_max" else math.ceil(total_size / num_scan_tasks) + size_bytes = math.ceil(total_size / num_scan_tasks) if strategy == PartitionBoundStrategy.PERCENTILE else None scan_tasks = [] for i in range(num_scan_tasks): if i == 0: @@ -188,3 +114,105 @@ def can_absorb_limit(self) -> bool: def can_absorb_select(self) -> bool: return False + + def _attempt_schema_read(self) -> Schema: + pa_table = SQLReader(self.sql, self.url, limit=1).read() + schema = Schema.from_pyarrow_schema(pa_table.schema) + return schema + + def _get_num_rows(self) -> int: + pa_table = SQLReader( + self.sql, + self.url, + projection=["COUNT(*)"], + ).read() + + if pa_table.num_rows != 1: + raise RuntimeError( + "Failed to get the number of rows: COUNT(*) query returned an unexpected number of rows." + ) + if pa_table.num_columns != 1: + raise RuntimeError( + "Failed to get the number of rows: COUNT(*) query returned an unexpected number of columns." + ) + + return pa_table.column(0)[0].as_py() + + def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]: + try: + # Try to get percentiles using percentile_cont + percentiles = [i / num_scan_tasks for i in range(1, num_scan_tasks)] + pa_table = SQLReader( + self.sql, + self.url, + projection=[ + f"percentile_cont({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}" + for i, percentile in enumerate(percentiles) + ], + ).read() + return pa_table, PartitionBoundStrategy.PERCENTILE + + except RuntimeError as e: + # If percentiles fails, use the min and max of the partition column + logger.info("Failed to get percentiles using percentile_cont, falling back to min and max. Error: %s", e) + + pa_table = SQLReader( + self.sql, + self.url, + projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"], + ).read() + return pa_table, PartitionBoundStrategy.MIN_MAX + + def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], PartitionBoundStrategy]: + if self._partition_col is None: + raise ValueError("Failed to get partition bounds: partition_col must be specified to partition the data.") + + if not ( + self._schema[self._partition_col].dtype._is_temporal_type() + or self._schema[self._partition_col].dtype._is_numeric_type() + ): + raise ValueError( + f"Failed to get partition bounds: {self._partition_col} is not a numeric or temporal type, and cannot be used for partitioning." 
+ ) + + pa_table, strategy = self._attempt_partition_bounds_read(num_scan_tasks) + + if pa_table.num_rows != 1: + raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.") + + if strategy == PartitionBoundStrategy.PERCENTILE: + if pa_table.num_columns != num_scan_tasks - 1: + raise RuntimeError( + f"Failed to get partition bounds: expected {num_scan_tasks - 1} percentiles, but got {pa_table.num_columns}." + ) + + bounds = [pa_table.column(i)[0].as_py() for i in range(num_scan_tasks - 1)] + + elif strategy == PartitionBoundStrategy.MIN_MAX: + if pa_table.num_columns != 2: + raise RuntimeError( + f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}." + ) + + min_val = pa_table.column(0)[0].as_py() + max_val = pa_table.column(1)[0].as_py() + range_size = (max_val - min_val) / num_scan_tasks + bounds = [min_val + range_size * i for i in range(1, num_scan_tasks)] + + return bounds, strategy + + def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]: + file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) + return iter( + [ + ScanTask.sql_scan_task( + url=self.url, + file_format=file_format_config, + schema=self._schema._schema, + num_rows=total_rows, + storage_config=self.storage_config, + size_bytes=math.ceil(total_size), + pushdowns=pushdowns, + ) + ] + ) diff --git a/daft/table/table_io.py b/daft/table/table_io.py index b492fae967..5eb0fcb332 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -226,24 +226,38 @@ def read_sql( sql (str): SQL query to execute url (str): URL to the database schema (Schema): Daft schema to read the SQL query into + sql_options (TableReadSQLOptions, optional): SQL-specific configs to apply when reading the file + read_options (TableReadOptions, optional): Options for reading the file Returns: MicroPartition: MicroPartition from SQL query """ + + if sql_options.predicate_expression is not None: + # If the predicate can be translated to SQL, we can apply all pushdowns to the SQL query + predicate_sql = sql_options.predicate_expression._to_sql() + apply_pushdowns_to_sql = predicate_sql is not None + else: + # If we don't have a predicate, we can still apply the limit and projection to the SQL query + predicate_sql = None + apply_pushdowns_to_sql = True + pa_table = SQLReader( sql, url, - # TODO(Colin): Enable pushdowns - projection=None, - predicate=None, + limit=read_options.num_rows if apply_pushdowns_to_sql else None, + projection=read_options.column_names if apply_pushdowns_to_sql else None, + predicate=predicate_sql, ).read() mp = MicroPartition.from_arrow(pa_table) - if sql_options.predicate_expression is not None: - mp = mp.filter(ExpressionsProjection([sql_options.predicate_expression])) + if len(mp) != 0 and apply_pushdowns_to_sql is False: + # If we have a non-empty table and we didn't apply pushdowns to SQL, we need to apply them in-memory + if sql_options.predicate_expression is not None: + mp = mp.filter(ExpressionsProjection([sql_options.predicate_expression])) - if read_options.num_rows is not None: - mp = mp.head(read_options.num_rows) + if read_options.num_rows is not None: + mp = mp.head(read_options.num_rows) return _cast_table_to_schema(mp, read_options=read_options, schema=schema) diff --git a/docs/source/user_guide/basic_concepts/read-and-write.rst b/docs/source/user_guide/basic_concepts/read-and-write.rst index e4528cd5d8..2eb04cde22 100644 --- 
a/docs/source/user_guide/basic_concepts/read-and-write.rst +++ b/docs/source/user_guide/basic_concepts/read-and-write.rst @@ -72,6 +72,23 @@ For testing, or small datasets that fit in memory, you may also create DataFrame To learn more, consult the API documentation on :ref:`creating DataFrames from in-memory data structures `. +From Databases +^^^^^^^^^^^^^^ + +Daft can also read data from a variety of databases, including PostgreSQL, MySQL, Trino, and SQLite using the :func:`daft.read_sql` method. +In order to partition the data, you can specify a partition column, which will allow Daft to read the data in parallel. + +.. code:: python + + # Read from a PostgreSQL database + uri = "postgresql://user:password@host:port/database" + df = daft.read_sql(uri, "SELECT * FROM my_table") + + # Read with a partition column + df = daft.read_sql(uri, "SELECT * FROM my_table", partition_col="date") + +To learn more, consult the API documentation on :func:`daft.read_sql`. + Writing Data ------------ diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 2ebdb02424..6ad61c9434 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -574,28 +574,19 @@ impl Expr { fn to_sql_inner(expr: &Expr, buffer: &mut W) -> io::Result<()> { match expr { Expr::Column(name) => write!(buffer, "{}", name), - Expr::Literal(lit) => { - if let Some(s) = lit.to_sql() { - write!(buffer, "{}", s) - } else { - Err(io::Error::new( - io::ErrorKind::Other, - "Unsupported literal for SQL translation", - )) - } - } + Expr::Literal(lit) => lit.display_sql(buffer), Expr::Alias(inner, ..) => to_sql_inner(inner, buffer), Expr::BinaryOp { op, left, right } => { to_sql_inner(left, buffer)?; let op = match op { - Operator::Eq => "=".to_string(), - Operator::NotEq => "!=".to_string(), - Operator::Lt => "<".to_string(), - Operator::LtEq => "<=".to_string(), - Operator::Gt => ">".to_string(), - Operator::GtEq => ">=".to_string(), - Operator::And => "AND".to_string(), - Operator::Or => "OR".to_string(), + Operator::Eq => "=", + Operator::NotEq => "!=", + Operator::Lt => "<", + Operator::LtEq => "<=", + Operator::Gt => ">", + Operator::GtEq => ">=", + Operator::And => "AND", + Operator::Or => "OR", _ => { return Err(io::Error::new( io::ErrorKind::Other, diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index bb1f2c60af..34c6d2ce03 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -13,6 +13,7 @@ use daft_core::{ utils::display_table::{display_date32, display_series_literal, display_timestamp}, }; use serde::{Deserialize, Serialize}; +use std::io::{self, Write}; use std::{ fmt::{Display, Formatter, Result}, hash::{Hash, Hasher}, @@ -213,33 +214,37 @@ impl LiteralValue { result } - pub fn to_sql(&self) -> Option { + pub fn display_sql(&self, buffer: &mut W) -> io::Result<()> { use LiteralValue::*; + let display_sql_err = Err(io::Error::new( + io::ErrorKind::Other, + "Unsupported literal for SQL translation", + )); match self { Null | Boolean(..) | Int32(..) | UInt32(..) | Int64(..) | UInt64(..) | Float64(..) => { - self.to_string().into() + write!(buffer, "{}", self) } - Utf8(val) => format!("'{}'", val).into(), - Date(val) => format!("DATE '{}'", display_date32(*val)).into(), + Utf8(val) => write!(buffer, "'{}'", val), + Date(val) => write!(buffer, "DATE '{}'", display_date32(*val)), Timestamp(val, tu, tz) => { // Different databases have different ways of handling timezones, so there's no reliable way to convert this to SQL. 
if tz.is_some() { - return None; + return display_sql_err; } // Note: Our display_timestamp function returns a string in the ISO 8601 format "YYYY-MM-DDTHH:MM:SS.fffff". // However, the ANSI SQL standard replaces the 'T' with a space. See: https://docs.actian.com/ingres/10s/index.html#page/SQLRef/Summary_of_ANSI_Date_2fTime_Data_Types.htm // While many databases support the 'T', some such as Trino, do not. So we replace the 'T' with a space here. // We also don't use the i64 unix timestamp directly, because different databases have different functions to convert unix timestamps to timestamps, e.g. Trino uses from_unixtime, while PostgreSQL uses to_timestamp. - format!( + write!( + buffer, "TIMESTAMP '{}'", display_timestamp(*val, tu, tz).replace('T', " ") ) - .into() } // TODO(Colin): Implement the rest of the types in future work for SQL pushdowns. - Decimal(..) | Series(..) | Time(..) | Binary(..) => None, + Decimal(..) | Series(..) | Time(..) | Binary(..) => display_sql_err, #[cfg(feature = "python")] - Python(..) => None, + Python(..) => display_sql_err, } } } diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 090c095053..a7ff7e3cf7 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -309,11 +309,6 @@ fn materialize_scan_task( .collect::>>() })?, FileFormatConfig::Database(DatabaseSourceConfig { sql }) => { - let predicate_sql = scan_task - .pushdowns - .filters - .as_ref() - .and_then(|p| p.to_sql()); let predicate_expr = scan_task .pushdowns .filters @@ -325,7 +320,6 @@ fn materialize_scan_task( py, sql, url, - predicate_sql.as_deref(), predicate_expr.clone(), scan_task.schema.clone().into(), scan_task diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 0137078e95..7ba80f7e89 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -787,12 +787,10 @@ pub(crate) fn read_parquet_into_py_table( .extract() } -#[allow(clippy::too_many_arguments)] pub(crate) fn read_sql_into_py_table( py: Python, sql: &str, url: &str, - predicate_sql: Option<&str>, predicate_expr: Option, schema: PySchema, include_columns: Option>, @@ -815,7 +813,7 @@ pub(crate) fn read_sql_into_py_table( let sql_options = py .import(pyo3::intern!(py, "daft.runners.partitioning"))? .getattr(pyo3::intern!(py, "TableReadSQLOptions"))? - .call1((predicate_sql, predicate_pyexpr))?; + .call1((predicate_pyexpr,))?; let read_options = py .import(pyo3::intern!(py, "daft.runners.partitioning"))? .getattr(pyo3::intern!(py, "TableReadOptions"))? diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index b0ac17b6a3..d664c2ef80 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -12,6 +12,12 @@ from tests.integration.sql.conftest import TEST_TABLE_NAME +@pytest.mark.integration() +def test_sql_show(test_db) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + df.show() + + @pytest.mark.integration() def test_sql_create_dataframe_ok(test_db) -> None: df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db)