diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index af281d5f58..2ea4c1aaf5 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -491,6 +491,75 @@ jobs: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} SLACK_WEBHOOK_TYPE: INCOMING_WEBHOOK + integration-test-sql: + runs-on: ubuntu-latest + timeout-minutes: 30 + needs: + - integration-test-build + env: + package-name: getdaft + strategy: + fail-fast: false + matrix: + python-version: ['3.8'] # can't use 3.7 due to requiring anon mode for adlfs + daft-runner: [py, ray] + steps: + - uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: pip + cache-dependency-path: | + pyproject.toml + requirements-dev.txt + - name: Download built wheels + uses: actions/download-artifact@v3 + with: + name: wheels + path: dist + - name: Setup Virtual Env + run: | + python -m venv venv + echo "$GITHUB_WORKSPACE/venv/bin" >> $GITHUB_PATH + - name: Install Daft and dev dependencies + run: | + pip install --upgrade pip + pip install -r requirements-dev.txt dist/${{ env.package-name }}-*x86_64*.whl --force-reinstall + rm -rf daft + - name: Spin up services + run: | + pushd ./tests/integration/sql/docker-compose/ + docker-compose -f ./docker-compose.yml up -d + popd + - name: Run sql integration tests + run: | + pytest tests/integration/sql -m 'integration' --durations=50 + env: + DAFT_RUNNER: ${{ matrix.daft-runner }} + - name: Send Slack notification on failure + uses: slackapi/slack-github-action@v1.24.0 + if: ${{ failure() && (github.ref == 'refs/heads/main') }} + with: + payload: | + { + "blocks": [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": ":rotating_light: [CI] SQL Integration Tests <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|workflow> *FAILED on main* :rotating_light:" + } + } + ] + } + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} + SLACK_WEBHOOK_TYPE: INCOMING_WEBHOOK + rust-tests: runs-on: ${{ matrix.os }}-latest timeout-minutes: 30 diff --git a/daft/__init__.py b/daft/__init__.py index f2470e7a31..9069810c42 100644 --- a/daft/__init__.py +++ b/daft/__init__.py @@ -76,6 +76,7 @@ def get_build_type() -> str: read_iceberg, read_json, read_parquet, + read_sql, ) from daft.series import Series from daft.udf import udf @@ -94,6 +95,7 @@ def get_build_type() -> str: "read_parquet", "read_iceberg", "read_delta_lake", + "read_sql", "DataCatalogType", "DataCatalogTable", "DataFrame", diff --git a/daft/context.py b/daft/context.py index 48ef44df94..0a8ea18f25 100644 --- a/daft/context.py +++ b/daft/context.py @@ -254,6 +254,7 @@ def set_execution_config( csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, + read_sql_partition_size_bytes: int | None = None, ) -> DaftContext: """Globally sets various configuration parameters which control various aspects of Daft execution. These configuration values are used when a Dataframe is executed (e.g. calls to `.write_*`, `.collect()` or `.show()`) @@ -283,6 +284,7 @@ def set_execution_config( csv_target_filesize: Target File Size when writing out CSV Files. Defaults to 512MB csv_inflation_factor: Inflation Factor of CSV files (In-Memory-Size / File-Size) ratio. 
Defaults to 0.5 shuffle_aggregation_default_partitions: Minimum number of partitions to create when performing aggregations. Defaults to 200, unless the number of input partitions is less than 200. + read_sql_partition_size_bytes: Target size of partition when reading from SQL databases. Defaults to 512MB """ # Replace values in the DaftExecutionConfig with user-specified overrides ctx = get_context() @@ -302,6 +304,7 @@ def set_execution_config( csv_target_filesize=csv_target_filesize, csv_inflation_factor=csv_inflation_factor, shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions, + read_sql_partition_size_bytes=read_sql_partition_size_bytes, ) ctx._daft_execution_config = new_daft_execution_config diff --git a/daft/daft.pyi b/daft/daft.pyi index c66d34e83d..ff20ef1e2b 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -217,12 +217,21 @@ class JsonSourceConfig: chunk_size: int | None = None, ): ... +class DatabaseSourceConfig: + """ + Configuration of a database data source. + """ + + sql: str + + def __init__(self, sql: str): ... + class FileFormatConfig: """ Configuration for parsing a particular file format (Parquet, CSV, JSON). """ - config: ParquetSourceConfig | CsvSourceConfig | JsonSourceConfig + config: ParquetSourceConfig | CsvSourceConfig | JsonSourceConfig | DatabaseSourceConfig @staticmethod def from_parquet_config(config: ParquetSourceConfig) -> FileFormatConfig: @@ -242,6 +251,12 @@ class FileFormatConfig: Create a JSON file format config. """ ... + @staticmethod + def from_database_config(config: DatabaseSourceConfig) -> FileFormatConfig: + """ + Create a database file format config. + """ + ... def file_format(self) -> FileFormat: """ Get the file format for this config. @@ -583,6 +598,20 @@ class ScanTask: Create a Catalog Scan Task """ ... + @staticmethod + def sql_scan_task( + url: str, + file_format: FileFormatConfig, + schema: PySchema, + num_rows: int | None, + storage_config: StorageConfig, + size_bytes: int | None, + pushdowns: Pushdowns | None, + ) -> ScanTask: + """ + Create a SQL Scan Task + """ + ... class ScanOperatorHandle: """ @@ -800,6 +829,7 @@ class PyDataType: @staticmethod def python() -> PyDataType: ... def to_arrow(self, cast_tensor_type_for_ray: builtins.bool | None = None) -> pyarrow.DataType: ... + def is_numeric(self) -> builtins.bool: ... def is_image(self) -> builtins.bool: ... def is_fixed_shape_image(self) -> builtins.bool: ... def is_tensor(self) -> builtins.bool: ... @@ -826,6 +856,7 @@ class PySchema: def names(self) -> list[str]: ... def union(self, other: PySchema) -> PySchema: ... def eq(self, other: PySchema) -> bool: ... + def estimate_row_size_bytes(self) -> float: ... @staticmethod def from_field_name_and_types(names_and_types: list[tuple[str, PyDataType]]) -> PySchema: ... @staticmethod @@ -875,6 +906,7 @@ class PyExpr: def is_in(self, other: PyExpr) -> PyExpr: ... def name(self) -> str: ... def to_field(self, schema: PySchema) -> PyField: ... + def to_sql(self) -> str | None: ... def __repr__(self) -> str: ... def __hash__(self) -> int: ... def __reduce__(self) -> tuple: ... @@ -1218,6 +1250,7 @@ class PyDaftExecutionConfig: csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, + read_sql_partition_size_bytes: int | None = None, ) -> PyDaftExecutionConfig: ... @property def scan_tasks_min_size_bytes(self) -> int: ... @@ -1243,6 +1276,8 @@ class PyDaftExecutionConfig: def csv_inflation_factor(self) -> float: ... 
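Since the stubs above only show the new `read_sql_partition_size_bytes` knob in isolation, here is a minimal usage sketch of how it combines with the `daft.read_sql` API introduced in this diff; the connection URL, table name, and partition column are placeholders, not values taken from the patch.

```python
import daft

# Target roughly 128 MB per SQL scan task instead of the 512 MB default added in this patch.
daft.context.set_execution_config(read_sql_partition_size_bytes=128 * 1024 * 1024)

# Placeholder URL/table/column: partition_col must be a numeric or temporal
# column so Daft can split the query into parallel range scans.
df = daft.read_sql(
    "SELECT * FROM my_table",
    "postgresql://username:password@localhost:5432/postgres",
    partition_col="id",
)
df.show()
```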
@property def shuffle_aggregation_default_partitions(self) -> int: ... + @property + def read_sql_partition_size_bytes(self) -> int: ... class PyDaftPlanningConfig: def with_config_values( diff --git a/daft/datatype.py b/daft/datatype.py index 6c0cf6c1d4..6d7e25a6ce 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -375,6 +375,8 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType: return cls.decimal128(arrow_type.precision, arrow_type.scale) elif pa.types.is_date32(arrow_type): return cls.date() + elif pa.types.is_date64(arrow_type): + return cls.timestamp(TimeUnit.ms()) elif pa.types.is_time64(arrow_type): timeunit = TimeUnit.from_str(pa.type_for_alias(str(arrow_type)).unit) return cls.time(timeunit) @@ -479,6 +481,9 @@ def _is_image_type(self) -> builtins.bool: def _is_fixed_shape_image_type(self) -> builtins.bool: return self._dtype.is_fixed_shape_image() + def _is_numeric_type(self) -> builtins.bool: + return self._dtype.is_numeric() + def _is_map(self) -> builtins.bool: return self._dtype.is_map() diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 89e4ca8728..09b3ad6fc6 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -473,6 +473,9 @@ def name(self) -> builtins.str: def __repr__(self) -> builtins.str: return repr(self._expr) + def _to_sql(self) -> builtins.str | None: + return self._expr.to_sql() + def _to_field(self, schema: Schema) -> Field: return Field._from_pyfield(self._expr.to_field(schema._schema)) diff --git a/daft/io/__init__.py b/daft/io/__init__.py index 18330faaf3..ce21b86cf5 100644 --- a/daft/io/__init__.py +++ b/daft/io/__init__.py @@ -14,6 +14,7 @@ from daft.io._iceberg import read_iceberg from daft.io._json import read_json from daft.io._parquet import read_parquet +from daft.io._sql import read_sql from daft.io.catalog import DataCatalogTable, DataCatalogType from daft.io.file_path import from_glob_path @@ -39,6 +40,7 @@ def _set_linux_cert_paths(): "read_parquet", "read_iceberg", "read_delta_lake", + "read_sql", "IOConfig", "S3Config", "AzureConfig", diff --git a/daft/io/_sql.py b/daft/io/_sql.py new file mode 100644 index 0000000000..b140e42d51 --- /dev/null +++ b/daft/io/_sql.py @@ -0,0 +1,55 @@ +# isort: dont-add-import: from __future__ import annotations + + +from typing import Optional + +from daft import context +from daft.api_annotations import PublicAPI +from daft.daft import PythonStorageConfig, ScanOperatorHandle, StorageConfig +from daft.dataframe import DataFrame +from daft.logical.builder import LogicalPlanBuilder +from daft.sql.sql_scan import SQLScanOperator + + +@PublicAPI +def read_sql( + sql: str, url: str, partition_col: Optional[str] = None, num_partitions: Optional[int] = None +) -> DataFrame: + """Creates a DataFrame from a SQL query. + + Example: + >>> df = daft.read_sql("SELECT * FROM my_table", "sqlite:///my_database.db") + + .. NOTE:: + If partition_col is specified, this function will partition the query by the specified column. You may specify the number of partitions, or let Daft determine the number of partitions. + Daft will first calculate percentiles of the specified column. For example if num_partitions is 3, Daft will calculate the 33rd and 66th percentiles of the specified column, and use these values to partition the query. + If the database does not support the necessary SQL syntax to calculate percentiles, Daft will calculate the min and max of the specified column and partition the query into equal ranges. 
+ + Args: + sql (str): SQL query to execute + url (str): URL to the database + partition_col (Optional[str]): Column to partition the data by, defaults to None + num_partitions (Optional[int]): Number of partitions to read the data into, + defaults to None, which will lets Daft determine the number of partitions. + + Returns: + DataFrame: Dataframe containing the results of the query + """ + + if num_partitions is not None and partition_col is None: + raise ValueError("Failed to execute sql: partition_col must be specified when num_partitions is specified") + + io_config = context.get_context().daft_planning_config.default_io_config + storage_config = StorageConfig.python(PythonStorageConfig(io_config)) + + sql_operator = SQLScanOperator( + sql, + url, + storage_config, + partition_col=partition_col, + num_partitions=num_partitions, + ) + handle = ScanOperatorHandle.from_python_scan_operator(sql_operator) + builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle) + + return DataFrame(builder) diff --git a/daft/logical/schema.py b/daft/logical/schema.py index 2655885bb9..b9334acce8 100644 --- a/daft/logical/schema.py +++ b/daft/logical/schema.py @@ -116,6 +116,9 @@ def __len__(self) -> int: def column_names(self) -> list[str]: return list(self._schema.names()) + def estimate_row_size_bytes(self) -> float: + return self._schema.estimate_row_size_bytes() + def __iter__(self) -> Iterator[Field]: col_names = self.column_names() yield from (self[name] for name in col_names) diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index 56fda08a3a..35f3f0fb6d 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -72,6 +72,17 @@ class TableParseParquetOptions: coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns() +@dataclass(frozen=True) +class TableReadSQLOptions: + """Options for parsing SQL tables + + Args: + predicate_expression: Expression predicate to apply to the table + """ + + predicate_expression: Expression | None = None + + @dataclass(frozen=True) class PartialPartitionMetadata: num_rows: None | int diff --git a/daft/sql/__init__.py b/daft/sql/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/daft/sql/sql_reader.py b/daft/sql/sql_reader.py new file mode 100644 index 0000000000..367fdcaeaf --- /dev/null +++ b/daft/sql/sql_reader.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import logging +import warnings +from urllib.parse import urlparse + +import pyarrow as pa + +logger = logging.getLogger(__name__) + + +class SQLReader: + def __init__( + self, + sql: str, + url: str, + limit: int | None = None, + projection: list[str] | None = None, + predicate: str | None = None, + ) -> None: + + self.sql = sql + self.url = url + self.limit = limit + self.projection = projection + self.predicate = predicate + + def read(self) -> pa.Table: + try: + sql = self._construct_sql_query(apply_limit=True) + return self._execute_sql_query(sql) + except RuntimeError: + warnings.warn("Failed to execute the query with a limit, attempting to read the entire table.") + sql = self._construct_sql_query(apply_limit=False) + return self._execute_sql_query(sql) + + def _construct_sql_query(self, apply_limit: bool) -> str: + clauses = [] + if self.projection is not None: + clauses.append(f"SELECT {', '.join(self.projection)}") + else: + clauses.append("SELECT *") + + clauses.append(f"FROM ({self.sql}) AS subquery") + + if self.predicate is not None: + clauses.append(f"WHERE {self.predicate}") + + if self.limit is not None and 
apply_limit is True: + clauses.append(f"LIMIT {self.limit}") + + return "\n".join(clauses) + + def _execute_sql_query(self, sql: str) -> pa.Table: + # Supported DBs extracted from here https://github.com/sfu-db/connector-x/tree/7b3147436b7e20b96691348143d605e2249d6119?tab=readme-ov-file#sources + supported_dbs = { + "postgres", + "postgresql", + "mysql", + "mssql", + "oracle", + "bigquery", + "sqlite", + "clickhouse", + "redshift", + } + + db_type = urlparse(self.url).scheme + db_type = db_type.strip().lower() + + # Check if the database type is supported + if db_type in supported_dbs: + return self._execute_sql_query_with_connectorx(sql) + else: + return self._execute_sql_query_with_sqlalchemy(sql) + + def _execute_sql_query_with_connectorx(self, sql: str) -> pa.Table: + import connectorx as cx + + logger.info(f"Using connectorx to execute sql: {sql}") + try: + table = cx.read_sql(conn=self.url, query=sql, return_type="arrow") + return table + except Exception as e: + raise RuntimeError(f"Failed to execute sql: {sql} with url: {self.url}, error: {e}") from e + + def _execute_sql_query_with_sqlalchemy(self, sql: str) -> pa.Table: + from sqlalchemy import create_engine, text + + logger.info(f"Using sqlalchemy to execute sql: {sql}") + try: + with create_engine(self.url).connect() as connection: + result = connection.execute(text(sql)) + rows = result.fetchall() + pydict = {column_name: [row[i] for row in rows] for i, column_name in enumerate(result.keys())} + + # TODO: Use type codes from cursor description to create pyarrow schema + return pa.Table.from_pydict(pydict) + except Exception as e: + raise RuntimeError(f"Failed to execute sql: {sql} with url: {self.url}, error: {e}") from e diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py new file mode 100644 index 0000000000..93e1b0ab56 --- /dev/null +++ b/daft/sql/sql_scan.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import logging +import math +import warnings +from collections.abc import Iterator +from enum import Enum, auto +from typing import Any + +from daft.context import get_context +from daft.daft import ( + DatabaseSourceConfig, + FileFormatConfig, + Pushdowns, + ScanTask, + StorageConfig, +) +from daft.expressions.expressions import lit +from daft.io.scan import PartitionField, ScanOperator +from daft.logical.schema import Schema +from daft.sql.sql_reader import SQLReader + +logger = logging.getLogger(__name__) + + +class PartitionBoundStrategy(Enum): + PERCENTILE = auto() + MIN_MAX = auto() + + +class SQLScanOperator(ScanOperator): + def __init__( + self, + sql: str, + url: str, + storage_config: StorageConfig, + partition_col: str | None = None, + num_partitions: int | None = None, + ) -> None: + super().__init__() + self.sql = sql + self.url = url + self.storage_config = storage_config + self._partition_col = partition_col + self._num_partitions = num_partitions + self._schema = self._attempt_schema_read() + + def schema(self) -> Schema: + return self._schema + + def display_name(self) -> str: + return f"SQLScanOperator(sql={self.sql}, url={self.url})" + + def partitioning_keys(self) -> list[PartitionField]: + return [] + + def multiline_display(self) -> list[str]: + return [ + self.display_name(), + f"Schema = {self._schema}", + ] + + def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: + total_rows = self._get_num_rows() + estimate_row_size_bytes = self.schema().estimate_row_size_bytes() + total_size = total_rows * estimate_row_size_bytes + num_scan_tasks = ( + math.ceil(total_size / 
get_context().daft_execution_config.read_sql_partition_size_bytes) + if self._num_partitions is None + else self._num_partitions + ) + + if num_scan_tasks == 1 or self._partition_col is None: + return self._single_scan_task(pushdowns, total_rows, total_size) + + partition_bounds, strategy = self._get_partition_bounds_and_strategy(num_scan_tasks) + partition_bounds = [lit(bound)._to_sql() for bound in partition_bounds] + + if any(bound is None for bound in partition_bounds): + warnings.warn("Unable to partion the data using the specified column. Falling back to a single scan task.") + return self._single_scan_task(pushdowns, total_rows, total_size) + + size_bytes = math.ceil(total_size / num_scan_tasks) if strategy == PartitionBoundStrategy.PERCENTILE else None + scan_tasks = [] + for i in range(num_scan_tasks): + if i == 0: + sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} <= {partition_bounds[i]}" + elif i == num_scan_tasks - 1: + sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} > {partition_bounds[i - 1]}" + else: + sql = f"SELECT * FROM ({self.sql}) AS subquery WHERE {self._partition_col} > {partition_bounds[i - 1]} AND {self._partition_col} <= {partition_bounds[i]}" + + file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(sql=sql)) + + scan_tasks.append( + ScanTask.sql_scan_task( + url=self.url, + file_format=file_format_config, + schema=self._schema._schema, + num_rows=None, + storage_config=self.storage_config, + size_bytes=size_bytes, + pushdowns=pushdowns, + ) + ) + + return iter(scan_tasks) + + def can_absorb_filter(self) -> bool: + return False + + def can_absorb_limit(self) -> bool: + return False + + def can_absorb_select(self) -> bool: + return False + + def _attempt_schema_read(self) -> Schema: + pa_table = SQLReader(self.sql, self.url, limit=1).read() + schema = Schema.from_pyarrow_schema(pa_table.schema) + return schema + + def _get_num_rows(self) -> int: + pa_table = SQLReader( + self.sql, + self.url, + projection=["COUNT(*)"], + ).read() + + if pa_table.num_rows != 1: + raise RuntimeError( + "Failed to get the number of rows: COUNT(*) query returned an unexpected number of rows." + ) + if pa_table.num_columns != 1: + raise RuntimeError( + "Failed to get the number of rows: COUNT(*) query returned an unexpected number of columns." + ) + + return pa_table.column(0)[0].as_py() + + def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]: + try: + # Try to get percentiles using percentile_cont + percentiles = [i / num_scan_tasks for i in range(1, num_scan_tasks)] + pa_table = SQLReader( + self.sql, + self.url, + projection=[ + f"percentile_cont({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}" + for i, percentile in enumerate(percentiles) + ], + ).read() + return pa_table, PartitionBoundStrategy.PERCENTILE + + except RuntimeError as e: + # If percentiles fails, use the min and max of the partition column + logger.info("Failed to get percentiles using percentile_cont, falling back to min and max. 
Error: %s", e) + + pa_table = SQLReader( + self.sql, + self.url, + projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"], + ).read() + return pa_table, PartitionBoundStrategy.MIN_MAX + + def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], PartitionBoundStrategy]: + if self._partition_col is None: + raise ValueError("Failed to get partition bounds: partition_col must be specified to partition the data.") + + if not ( + self._schema[self._partition_col].dtype._is_temporal_type() + or self._schema[self._partition_col].dtype._is_numeric_type() + ): + raise ValueError( + f"Failed to get partition bounds: {self._partition_col} is not a numeric or temporal type, and cannot be used for partitioning." + ) + + pa_table, strategy = self._attempt_partition_bounds_read(num_scan_tasks) + + if pa_table.num_rows != 1: + raise RuntimeError(f"Failed to get partition bounds: expected 1 row, but got {pa_table.num_rows}.") + + if strategy == PartitionBoundStrategy.PERCENTILE: + if pa_table.num_columns != num_scan_tasks - 1: + raise RuntimeError( + f"Failed to get partition bounds: expected {num_scan_tasks - 1} percentiles, but got {pa_table.num_columns}." + ) + + bounds = [pa_table.column(i)[0].as_py() for i in range(num_scan_tasks - 1)] + + elif strategy == PartitionBoundStrategy.MIN_MAX: + if pa_table.num_columns != 2: + raise RuntimeError( + f"Failed to get partition bounds: expected 2 columns, but got {pa_table.num_columns}." + ) + + min_val = pa_table.column(0)[0].as_py() + max_val = pa_table.column(1)[0].as_py() + range_size = (max_val - min_val) / num_scan_tasks + bounds = [min_val + range_size * i for i in range(1, num_scan_tasks)] + + return bounds, strategy + + def _single_scan_task(self, pushdowns: Pushdowns, total_rows: int | None, total_size: float) -> Iterator[ScanTask]: + file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql)) + return iter( + [ + ScanTask.sql_scan_task( + url=self.url, + file_format=file_format_config, + schema=self._schema._schema, + num_rows=total_rows, + storage_config=self.storage_config, + size_bytes=math.ceil(total_size), + pushdowns=pushdowns, + ) + ] + ) diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 4097b58dce..5eb0fcb332 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -39,8 +39,10 @@ TableParseCSVOptions, TableParseParquetOptions, TableReadOptions, + TableReadSQLOptions, ) from daft.series import Series +from daft.sql.sql_reader import SQLReader from daft.table import MicroPartition FileInput = Union[pathlib.Path, str, IO[bytes]] @@ -73,7 +75,6 @@ def _cast_table_to_schema(table: MicroPartition, read_options: TableReadOptions, their corresponding dtype in `schema`, and column ordering/inclusion matches `read_options.column_names` (if provided). 
""" pruned_schema = schema - # If reading only a subset of fields, prune the schema if read_options.column_names is not None: pruned_schema = Schema._from_fields([schema[name] for name in read_options.column_names]) @@ -212,6 +213,55 @@ def read_parquet( return _cast_table_to_schema(MicroPartition.from_arrow(table), read_options=read_options, schema=schema) +def read_sql( + sql: str, + url: str, + schema: Schema, + sql_options: TableReadSQLOptions = TableReadSQLOptions(), + read_options: TableReadOptions = TableReadOptions(), +) -> MicroPartition: + """Reads a MicroPartition from a SQL query + + Args: + sql (str): SQL query to execute + url (str): URL to the database + schema (Schema): Daft schema to read the SQL query into + sql_options (TableReadSQLOptions, optional): SQL-specific configs to apply when reading the file + read_options (TableReadOptions, optional): Options for reading the file + + Returns: + MicroPartition: MicroPartition from SQL query + """ + + if sql_options.predicate_expression is not None: + # If the predicate can be translated to SQL, we can apply all pushdowns to the SQL query + predicate_sql = sql_options.predicate_expression._to_sql() + apply_pushdowns_to_sql = predicate_sql is not None + else: + # If we don't have a predicate, we can still apply the limit and projection to the SQL query + predicate_sql = None + apply_pushdowns_to_sql = True + + pa_table = SQLReader( + sql, + url, + limit=read_options.num_rows if apply_pushdowns_to_sql else None, + projection=read_options.column_names if apply_pushdowns_to_sql else None, + predicate=predicate_sql, + ).read() + mp = MicroPartition.from_arrow(pa_table) + + if len(mp) != 0 and apply_pushdowns_to_sql is False: + # If we have a non-empty table and we didn't apply pushdowns to SQL, we need to apply them in-memory + if sql_options.predicate_expression is not None: + mp = mp.filter(ExpressionsProjection([sql_options.predicate_expression])) + + if read_options.num_rows is not None: + mp = mp.head(read_options.num_rows) + + return _cast_table_to_schema(mp, read_options=read_options, schema=schema) + + class PACSVStreamHelper: def __init__(self, stream: pa.CSVStreamReader) -> None: self.stream = stream diff --git a/docs/source/api_docs/creation.rst b/docs/source/api_docs/creation.rst index d83fa84fcd..b02b3ee3a0 100644 --- a/docs/source/api_docs/creation.rst +++ b/docs/source/api_docs/creation.rst @@ -20,6 +20,24 @@ Python Objects from_pylist from_pydict +Arrow +~~~~~ + +.. autosummary:: + :nosignatures: + :toctree: doc_gen/io_functions + + from_arrow + +Pandas +~~~~~~ + +.. autosummary:: + :nosignatures: + :toctree: doc_gen/io_functions + + from_pandas + Files ----- @@ -54,6 +72,15 @@ JSON read_json +File Paths +~~~~~~~~~~ + +.. autosummary:: + :nosignatures: + :toctree: doc_gen/io_functions + + from_glob_path + Data Catalogs ------------- @@ -75,40 +102,6 @@ Delta Lake read_delta_lake -In-memory ---------- - -Arrow -~~~~~ - -.. autosummary:: - :nosignatures: - :toctree: doc_gen/io_functions - - -.. autosummary:: - :nosignatures: - :toctree: doc_gen/io_functions - - from_arrow - -Pandas -~~~~~~ - -.. autosummary:: - :nosignatures: - :toctree: doc_gen/io_functions - - from_pandas - -File Paths -~~~~~~~~~~ - -.. autosummary:: - :nosignatures: - :toctree: doc_gen/io_functions - - from_glob_path Integrations ------------ @@ -132,3 +125,12 @@ Dask :toctree: doc_gen/io_functions from_dask_dataframe + +Databases +~~~~~~~~~ + +.. 
autosummary:: + :nosignatures: + :toctree: doc_gen/io_functions + + read_sql \ No newline at end of file diff --git a/docs/source/user_guide/basic_concepts/read-and-write.rst b/docs/source/user_guide/basic_concepts/read-and-write.rst index e4528cd5d8..2eb04cde22 100644 --- a/docs/source/user_guide/basic_concepts/read-and-write.rst +++ b/docs/source/user_guide/basic_concepts/read-and-write.rst @@ -72,6 +72,23 @@ For testing, or small datasets that fit in memory, you may also create DataFrame To learn more, consult the API documentation on :ref:`creating DataFrames from in-memory data structures `. +From Databases +^^^^^^^^^^^^^^ + +Daft can also read data from a variety of databases, including PostgreSQL, MySQL, Trino, and SQLite using the :func:`daft.read_sql` method. +In order to partition the data, you can specify a partition column, which will allow Daft to read the data in parallel. + +.. code:: python + + # Read from a PostgreSQL database + uri = "postgresql://user:password@host:port/database" + df = daft.read_sql(uri, "SELECT * FROM my_table") + + # Read with a partition column + df = daft.read_sql(uri, "SELECT * FROM my_table", partition_col="date") + +To learn more, consult the API documentation on :func:`daft.read_sql`. + Writing Data ------------ diff --git a/pyproject.toml b/pyproject.toml index 965dbf5a14..fe21174dd2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ readme = "README.rst" requires-python = ">=3.7" [project.optional-dependencies] -all = ["getdaft[aws, azure, gcp, ray, pandas, numpy, iceberg, deltalake]"] +all = ["getdaft[aws, azure, gcp, ray, pandas, numpy, iceberg, deltalake, sql]"] aws = ["boto3"] azure = [] deltalake = ["deltalake"] @@ -39,6 +39,7 @@ ray = [ # Explicitly install packaging. See issue: https://github.com/ray-project/ray/issues/34806 "packaging" ] +sql = ["connectorx", "sqlalchemy", "trino[sqlalchemy]"] viz = [] [project.urls] diff --git a/requirements-dev.txt b/requirements-dev.txt index 42f49446c8..a2f8a44b77 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -50,6 +50,13 @@ deltalake==0.13.0; platform_system != "Windows" and python_version < '3.8' # Databricks databricks-sdk==0.12.0 +#SQL +sqlalchemy==2.0.25; python_version >= '3.8' +connectorx==0.3.2; python_version >= '3.8' +trino[sqlalchemy]==0.328.0; python_version >= '3.8' +PyMySQL==1.1.0; python_version >= '3.8' +psycopg2==2.9.9; python_version >= '3.8' + # AWS s3fs==2023.1.0; python_version < '3.8' s3fs==2023.12.0; python_version >= '3.8' diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index a3dd531d3a..ac53ad30a5 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -35,6 +35,7 @@ pub struct DaftExecutionConfig { pub csv_target_filesize: usize, pub csv_inflation_factor: f64, pub shuffle_aggregation_default_partitions: usize, + pub read_sql_partition_size_bytes: usize, } impl Default for DaftExecutionConfig { @@ -53,6 +54,7 @@ impl Default for DaftExecutionConfig { csv_target_filesize: 512 * 1024 * 1024, // 512MB csv_inflation_factor: 0.5, shuffle_aggregation_default_partitions: 200, + read_sql_partition_size_bytes: 512 * 1024 * 1024, // 512MB } } } diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 8d2124c7ba..9071c878d1 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -90,6 +90,7 @@ impl PyDaftExecutionConfig { csv_target_filesize: Option, csv_inflation_factor: Option, 
shuffle_aggregation_default_partitions: Option, + read_sql_partition_size_bytes: Option, ) -> PyResult { let mut config = self.config.as_ref().clone(); @@ -136,6 +137,9 @@ impl PyDaftExecutionConfig { { config.shuffle_aggregation_default_partitions = shuffle_aggregation_default_partitions; } + if let Some(read_sql_partition_size_bytes) = read_sql_partition_size_bytes { + config.read_sql_partition_size_bytes = read_sql_partition_size_bytes; + } Ok(PyDaftExecutionConfig { config: Arc::new(config), @@ -202,6 +206,11 @@ impl PyDaftExecutionConfig { Ok(self.config.shuffle_aggregation_default_partitions) } + #[getter] + fn get_read_sql_partition_size_bytes(&self) -> PyResult { + Ok(self.config.read_sql_partition_size_bytes) + } + fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (Vec,))> { let bin_data = bincode::serialize(self.config.as_ref()) .expect("DaftExecutionConfig should be serializable to bytes"); diff --git a/src/daft-core/src/python/datatype.rs b/src/daft-core/src/python/datatype.rs index 4e4fa3de87..598f4b1625 100644 --- a/src/daft-core/src/python/datatype.rs +++ b/src/daft-core/src/python/datatype.rs @@ -343,6 +343,10 @@ impl PyDataType { } } + pub fn is_numeric(&self) -> PyResult { + Ok(self.dtype.is_numeric()) + } + pub fn is_image(&self) -> PyResult { Ok(self.dtype.is_image()) } diff --git a/src/daft-core/src/python/schema.rs b/src/daft-core/src/python/schema.rs index 975aeffc5c..74c9dc1e81 100644 --- a/src/daft-core/src/python/schema.rs +++ b/src/daft-core/src/python/schema.rs @@ -37,6 +37,10 @@ impl PySchema { Ok(self.schema.fields.eq(&other.schema.fields)) } + pub fn estimate_row_size_bytes(&self) -> PyResult { + Ok(self.schema.estimate_row_size_bytes()) + } + #[staticmethod] pub fn from_field_name_and_types( names_and_types: Vec<(String, PyDataType)>, diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index d66ce7df55..6ad61c9434 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -17,6 +17,7 @@ use common_error::{DaftError, DaftResult}; use serde::{Deserialize, Serialize}; use std::{ fmt::{Debug, Display, Formatter, Result}, + io::{self, Write}, sync::Arc, }; @@ -568,6 +569,77 @@ impl Expr { _ => None, } } + + pub fn to_sql(&self) -> Option { + fn to_sql_inner(expr: &Expr, buffer: &mut W) -> io::Result<()> { + match expr { + Expr::Column(name) => write!(buffer, "{}", name), + Expr::Literal(lit) => lit.display_sql(buffer), + Expr::Alias(inner, ..) 
=> to_sql_inner(inner, buffer), + Expr::BinaryOp { op, left, right } => { + to_sql_inner(left, buffer)?; + let op = match op { + Operator::Eq => "=", + Operator::NotEq => "!=", + Operator::Lt => "<", + Operator::LtEq => "<=", + Operator::Gt => ">", + Operator::GtEq => ">=", + Operator::And => "AND", + Operator::Or => "OR", + _ => { + return Err(io::Error::new( + io::ErrorKind::Other, + "Unsupported operator for SQL translation", + )) + } + }; + write!(buffer, " {} ", op)?; + to_sql_inner(right, buffer) + } + Expr::Not(inner) => { + write!(buffer, "NOT (")?; + to_sql_inner(inner, buffer)?; + write!(buffer, ")") + } + Expr::IsNull(inner) => { + write!(buffer, "(")?; + to_sql_inner(inner, buffer)?; + write!(buffer, ") IS NULL") + } + Expr::NotNull(inner) => { + write!(buffer, "(")?; + to_sql_inner(inner, buffer)?; + write!(buffer, ") IS NOT NULL") + } + Expr::IfElse { + if_true, + if_false, + predicate, + } => { + write!(buffer, "CASE WHEN ")?; + to_sql_inner(predicate, buffer)?; + write!(buffer, " THEN ")?; + to_sql_inner(if_true, buffer)?; + write!(buffer, " ELSE ")?; + to_sql_inner(if_false, buffer)?; + write!(buffer, " END") + } + // TODO: Implement SQL translations for these expressions if possible + Expr::Agg(..) | Expr::Cast(..) | Expr::IsIn(..) | Expr::Function { .. } => { + Err(io::Error::new( + io::ErrorKind::Other, + "Unsupported expression for SQL translation", + )) + } + } + } + + let mut buffer = Vec::new(); + to_sql_inner(self, &mut buffer) + .ok() + .and_then(|_| String::from_utf8(buffer).ok()) + } } impl Display for Expr { diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 93be85dacb..34c6d2ce03 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -13,6 +13,7 @@ use daft_core::{ utils::display_table::{display_date32, display_series_literal, display_timestamp}, }; use serde::{Deserialize, Serialize}; +use std::io::{self, Write}; use std::{ fmt::{Display, Formatter, Result}, hash::{Hash, Hasher}, @@ -212,6 +213,40 @@ impl LiteralValue { }; result } + + pub fn display_sql(&self, buffer: &mut W) -> io::Result<()> { + use LiteralValue::*; + let display_sql_err = Err(io::Error::new( + io::ErrorKind::Other, + "Unsupported literal for SQL translation", + )); + match self { + Null | Boolean(..) | Int32(..) | UInt32(..) | Int64(..) | UInt64(..) | Float64(..) => { + write!(buffer, "{}", self) + } + Utf8(val) => write!(buffer, "'{}'", val), + Date(val) => write!(buffer, "DATE '{}'", display_date32(*val)), + Timestamp(val, tu, tz) => { + // Different databases have different ways of handling timezones, so there's no reliable way to convert this to SQL. + if tz.is_some() { + return display_sql_err; + } + // Note: Our display_timestamp function returns a string in the ISO 8601 format "YYYY-MM-DDTHH:MM:SS.fffff". + // However, the ANSI SQL standard replaces the 'T' with a space. See: https://docs.actian.com/ingres/10s/index.html#page/SQLRef/Summary_of_ANSI_Date_2fTime_Data_Types.htm + // While many databases support the 'T', some such as Trino, do not. So we replace the 'T' with a space here. + // We also don't use the i64 unix timestamp directly, because different databases have different functions to convert unix timestamps to timestamps, e.g. Trino uses from_unixtime, while PostgreSQL uses to_timestamp. + write!( + buffer, + "TIMESTAMP '{}'", + display_timestamp(*val, tu, tz).replace('T', " ") + ) + } + // TODO(Colin): Implement the rest of the types in future work for SQL pushdowns. + Decimal(..) | Series(..) | Time(..) | Binary(..) 
=> display_sql_err, + #[cfg(feature = "python")] + Python(..) => display_sql_err, + } + } } pub trait Literal { diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index a4e8c7091a..d23e3def34 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -324,6 +324,10 @@ impl PyExpr { Ok(self.expr.name()?) } + pub fn to_sql(&self) -> PyResult> { + Ok(self.expr.to_sql()) + } + pub fn to_field(&self, schema: &PySchema) -> PyResult { Ok(self.expr.to_field(&schema.schema)?.into()) } diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 654fe2695f..a7ff7e3cf7 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -15,7 +15,9 @@ use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions}; use daft_parquet::read::{ read_parquet_bulk, read_parquet_metadata_bulk, ParquetSchemaInferenceOptions, }; -use daft_scan::file_format::{CsvSourceConfig, FileFormatConfig, ParquetSourceConfig}; +use daft_scan::file_format::{ + CsvSourceConfig, DatabaseSourceConfig, FileFormatConfig, ParquetSourceConfig, +}; use daft_scan::storage_config::{NativeStorageConfig, StorageConfig}; use daft_scan::{ChunkSpec, DataFileSource, Pushdowns, ScanTask}; use daft_table::Table; @@ -225,6 +227,12 @@ fn materialize_scan_task( ) .context(DaftCoreComputeSnafu)? } + FileFormatConfig::Database(_) => { + return Err(common_error::DaftError::TypeError( + "Native reads for Database file format not implemented".to_string(), + )) + .context(DaftCoreComputeSnafu); + } } } #[cfg(feature = "python")] @@ -300,6 +308,33 @@ fn materialize_scan_task( }) .collect::>>() })?, + FileFormatConfig::Database(DatabaseSourceConfig { sql }) => { + let predicate_expr = scan_task + .pushdowns + .filters + .as_ref() + .map(|p| (*p.as_ref()).clone().into()); + Python::with_gil(|py| { + urls.map(|url| { + crate::python::read_sql_into_py_table( + py, + sql, + url, + predicate_expr.clone(), + scan_task.schema.clone().into(), + scan_task + .pushdowns + .columns + .as_ref() + .map(|cols| cols.as_ref().clone()), + scan_task.pushdowns.limit, + ) + .map(|t| t.into()) + .context(PyIOSnafu) + }) + .collect::>>() + })? + } } } }; diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 2a7c487f27..7ba80f7e89 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -787,6 +787,46 @@ pub(crate) fn read_parquet_into_py_table( .extract() } +pub(crate) fn read_sql_into_py_table( + py: Python, + sql: &str, + url: &str, + predicate_expr: Option, + schema: PySchema, + include_columns: Option>, + num_rows: Option, +) -> PyResult { + let py_schema = py + .import(pyo3::intern!(py, "daft.logical.schema"))? + .getattr(pyo3::intern!(py, "Schema"))? + .getattr(pyo3::intern!(py, "_from_pyschema"))? + .call1((schema,))?; + let predicate_pyexpr = match predicate_expr { + Some(p) => Some( + py.import(pyo3::intern!(py, "daft.expressions.expressions"))? + .getattr(pyo3::intern!(py, "Expression"))? + .getattr(pyo3::intern!(py, "_from_pyexpr"))? + .call1((p,))?, + ), + None => None, + }; + let sql_options = py + .import(pyo3::intern!(py, "daft.runners.partitioning"))? + .getattr(pyo3::intern!(py, "TableReadSQLOptions"))? + .call1((predicate_pyexpr,))?; + let read_options = py + .import(pyo3::intern!(py, "daft.runners.partitioning"))? + .getattr(pyo3::intern!(py, "TableReadOptions"))? 
+ .call1((num_rows, include_columns))?; + py.import(pyo3::intern!(py, "daft.table.table_io"))? + .getattr(pyo3::intern!(py, "read_sql"))? + .call1((sql, url, py_schema, sql_options, read_options))? + .getattr(pyo3::intern!(py, "to_table"))? + .call0()? + .getattr(pyo3::intern!(py, "_table"))? + .extract() +} + impl From for PyMicroPartition { fn from(value: MicroPartition) -> Self { PyMicroPartition { diff --git a/src/daft-plan/src/lib.rs b/src/daft-plan/src/lib.rs index 74bc2ccc7a..4c02484b7b 100644 --- a/src/daft-plan/src/lib.rs +++ b/src/daft-plan/src/lib.rs @@ -40,12 +40,15 @@ use { #[cfg(feature = "python")] pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { + use daft_scan::file_format::DatabaseSourceConfig; + parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; + parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 2f8ae4cc15..9c7b667bd9 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -715,6 +715,9 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftRe input_physical.into(), ))) } + FileFormat::Database => Err(common_error::DaftError::ValueError( + "Database sink not yet implemented".to_string(), + )), } } } diff --git a/src/daft-scan/src/file_format.rs b/src/daft-scan/src/file_format.rs index bf9d0175f1..d680f89250 100644 --- a/src/daft-scan/src/file_format.rs +++ b/src/daft-scan/src/file_format.rs @@ -19,6 +19,7 @@ pub enum FileFormat { Parquet, Csv, Json, + Database, } impl FromStr for FileFormat { @@ -33,6 +34,8 @@ impl FromStr for FileFormat { Ok(Csv) } else if file_format.trim().eq_ignore_ascii_case("json") { Ok(Json) + } else if file_format.trim().eq_ignore_ascii_case("database") { + Ok(Database) } else { Err(DaftError::TypeError(format!( "FileFormat {} not supported!", @@ -50,6 +53,7 @@ impl From<&FileFormatConfig> for FileFormat { FileFormatConfig::Parquet(_) => Self::Parquet, FileFormatConfig::Csv(_) => Self::Csv, FileFormatConfig::Json(_) => Self::Json, + FileFormatConfig::Database(_) => Self::Database, } } } @@ -60,6 +64,7 @@ pub enum FileFormatConfig { Parquet(ParquetSourceConfig), Csv(CsvSourceConfig), Json(JsonSourceConfig), + Database(DatabaseSourceConfig), } impl FileFormatConfig { @@ -70,6 +75,7 @@ impl FileFormatConfig { Parquet(_) => "Parquet", Csv(_) => "Csv", Json(_) => "Json", + Database(_) => "Database", } } @@ -78,6 +84,7 @@ impl FileFormatConfig { Self::Parquet(source) => source.multiline_display(), Self::Csv(source) => source.multiline_display(), Self::Json(source) => source.multiline_display(), + Self::Database(source) => source.multiline_display(), } } } @@ -251,6 +258,37 @@ impl JsonSourceConfig { impl_bincode_py_state_serialization!(JsonSourceConfig); +/// Configuration for a Database data source. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] +#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] +pub struct DatabaseSourceConfig { + pub sql: String, +} + +impl DatabaseSourceConfig { + pub fn new_internal(sql: String) -> Self { + Self { sql } + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + res.push(format!("SQL = {}", self.sql)); + res + } +} + +#[cfg(feature = "python")] +#[pymethods] +impl DatabaseSourceConfig { + /// Create a config for a Database data source. 
+ #[new] + fn new(sql: &str) -> Self { + Self::new_internal(sql.to_string()) + } +} + +impl_bincode_py_state_serialization!(DatabaseSourceConfig); + /// Configuration for parsing a particular file format. #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(transparent)] @@ -281,6 +319,12 @@ impl PyFileFormatConfig { Self(Arc::new(FileFormatConfig::Json(config))) } + /// Create a Database file format config. + #[staticmethod] + fn from_database_config(config: DatabaseSourceConfig) -> Self { + Self(Arc::new(FileFormatConfig::Database(config))) + } + /// Get the underlying data source config. #[getter] fn get_config(&self, py: Python) -> PyObject { @@ -290,6 +334,7 @@ impl PyFileFormatConfig { Parquet(config) => config.clone().into_py(py), Csv(config) => config.clone().into_py(py), Json(config) => config.clone().into_py(py), + Database(config) => config.clone().into_py(py), } } diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 26e9773df8..426fe289c8 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -201,6 +201,11 @@ impl GlobScanOperator { io_client, Some(io_stats), )?, + FileFormatConfig::Database(_) => { + return Err(DaftError::ValueError( + "Cannot glob a database source".to_string(), + )) + } }; let schema = match schema_hint { diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index 1f7023a2b8..3353813f95 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -135,32 +135,45 @@ pub enum DataFileSource { partition_spec: PartitionSpec, statistics: Option, }, + DatabaseDataSource { + path: String, + chunk_spec: Option, + size_bytes: Option, + metadata: Option, + partition_spec: Option, + statistics: Option, + }, } impl DataFileSource { pub fn get_path(&self) -> &str { match self { - Self::AnonymousDataFile { path, .. } | Self::CatalogDataFile { path, .. } => path, + Self::AnonymousDataFile { path, .. } + | Self::CatalogDataFile { path, .. } + | Self::DatabaseDataSource { path, .. } => path, } } pub fn get_chunk_spec(&self) -> Option<&ChunkSpec> { match self { Self::AnonymousDataFile { chunk_spec, .. } - | Self::CatalogDataFile { chunk_spec, .. } => chunk_spec.as_ref(), + | Self::CatalogDataFile { chunk_spec, .. } + | Self::DatabaseDataSource { chunk_spec, .. } => chunk_spec.as_ref(), } } pub fn get_size_bytes(&self) -> Option { match self { Self::AnonymousDataFile { size_bytes, .. } - | Self::CatalogDataFile { size_bytes, .. } => *size_bytes, + | Self::CatalogDataFile { size_bytes, .. } + | Self::DatabaseDataSource { size_bytes, .. } => *size_bytes, } } pub fn get_metadata(&self) -> Option<&TableMetadata> { match self { - Self::AnonymousDataFile { metadata, .. } => metadata.as_ref(), + Self::AnonymousDataFile { metadata, .. } + | Self::DatabaseDataSource { metadata, .. } => metadata.as_ref(), Self::CatalogDataFile { metadata, .. } => Some(metadata), } } @@ -168,13 +181,15 @@ impl DataFileSource { pub fn get_statistics(&self) -> Option<&TableStatistics> { match self { Self::AnonymousDataFile { statistics, .. } - | Self::CatalogDataFile { statistics, .. } => statistics.as_ref(), + | Self::CatalogDataFile { statistics, .. } + | Self::DatabaseDataSource { statistics, .. } => statistics.as_ref(), } } pub fn get_partition_spec(&self) -> Option<&PartitionSpec> { match self { - Self::AnonymousDataFile { partition_spec, .. } => partition_spec.as_ref(), + Self::AnonymousDataFile { partition_spec, .. } + | Self::DatabaseDataSource { partition_spec, .. 
} => partition_spec.as_ref(), Self::CatalogDataFile { partition_spec, .. } => Some(partition_spec), } } @@ -189,6 +204,14 @@ impl DataFileSource { metadata, partition_spec, statistics, + } + | Self::DatabaseDataSource { + path, + chunk_spec, + size_bytes, + metadata, + partition_spec, + statistics, } => { res.push(format!("Path = {}", path)); if let Some(chunk_spec) = chunk_spec { @@ -428,6 +451,7 @@ impl ScanTask { FileFormatConfig::Csv(_) | FileFormatConfig::Json(_) => { config.csv_inflation_factor } + FileFormatConfig::Database(_) => 0.0, }; // estimate number of rows from read schema diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index 1c28639905..8d68616386 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -319,6 +319,35 @@ pub mod pylib { Ok(Some(PyScanTask(scan_task.into()))) } + #[staticmethod] + pub fn sql_scan_task( + url: String, + file_format: PyFileFormatConfig, + schema: PySchema, + storage_config: PyStorageConfig, + num_rows: Option, + size_bytes: Option, + pushdowns: Option, + ) -> PyResult { + let data_source = DataFileSource::DatabaseDataSource { + path: url, + chunk_spec: None, + size_bytes, + metadata: num_rows.map(|n| TableMetadata { length: n as usize }), + partition_spec: None, + statistics: None, + }; + + let scan_task = ScanTask::new( + vec![data_source], + file_format.into(), + schema.schema, + storage_config.into(), + pushdowns.map(|p| p.0.as_ref().clone()).unwrap_or_default(), + ); + Ok(PyScanTask(scan_task.into())) + } + pub fn __repr__(&self) -> PyResult { Ok(format!("{:?}", self.0)) } diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs index fe3d470a2f..b8662219a9 100644 --- a/src/daft-scan/src/scan_task_iters.rs +++ b/src/daft-scan/src/scan_task_iters.rs @@ -184,7 +184,7 @@ pub fn split_by_row_groups( chunk_spec, size_bytes, .. - } => { + } | DataFileSource::DatabaseDataSource { chunk_spec, size_bytes, .. 
} => { *chunk_spec = Some(ChunkSpec::Parquet(curr_row_groups)); *size_bytes = Some(curr_size_bytes as u64); diff --git a/tests/integration/sql/__init__.py b/tests/integration/sql/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py new file mode 100644 index 0000000000..46dc3e36ed --- /dev/null +++ b/tests/integration/sql/conftest.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from typing import Generator + +import numpy as np +import pandas as pd +import pytest +import tenacity +from sqlalchemy import ( + Boolean, + Column, + Date, + DateTime, + Engine, + Float, + Integer, + MetaData, + String, + Table, + create_engine, + text, +) + +URLS = [ + "trino://user@localhost:8080/memory/default", + "postgresql://username:password@localhost:5432/postgres", + "mysql+pymysql://username:password@localhost:3306/mysql", +] +TEST_TABLE_NAME = "example" + + +@pytest.fixture(scope="session", params=[{"num_rows": 200}]) +def generated_data(request: pytest.FixtureRequest) -> pd.DataFrame: + num_rows = request.param["num_rows"] + + data = { + "id": np.arange(num_rows), + "float_col": np.arange(num_rows, dtype=float), + "string_col": [f"row_{i}" for i in range(num_rows)], + "bool_col": [True for _ in range(num_rows // 2)] + [False for _ in range(num_rows // 2)], + "date_col": [date(2021, 1, 1) + timedelta(days=i) for i in range(num_rows)], + "date_time_col": [datetime(2020, 1, 1, 10, 0, 0) + timedelta(hours=i) for i in range(num_rows)], + "time_col": [ + (datetime.combine(datetime.today(), time(0, 0)) + timedelta(minutes=x)).time() for x in range(200) + ], + "null_col": [None if i % 2 == 1 else f"not_null" for i in range(num_rows)], + "non_uniformly_distributed_col": [1 for _ in range(num_rows)], + } + return pd.DataFrame(data) + + +@pytest.fixture(scope="session", params=URLS) +def test_db(request: pytest.FixtureRequest, generated_data: pd.DataFrame) -> Generator[str, None, None]: + db_url = request.param + setup_database(db_url, generated_data) + yield db_url + + +@tenacity.retry(stop=tenacity.stop_after_delay(10), wait=tenacity.wait_fixed(5), reraise=True) +def setup_database(db_url: str, data: pd.DataFrame) -> None: + engine = create_engine(db_url) + create_and_populate(engine, data) + + # Ensure the table is created and populated + with engine.connect() as conn: + result = conn.execute(text(f"SELECT COUNT(*) FROM {TEST_TABLE_NAME}")).fetchone()[0] + assert result == len(data) + + +def create_and_populate(engine: Engine, data: pd.DataFrame) -> None: + metadata = MetaData() + table = Table( + TEST_TABLE_NAME, + metadata, + Column("id", Integer), + Column("float_col", Float), + Column("string_col", String(50)), + Column("bool_col", Boolean), + Column("date_col", Date), + Column("date_time_col", DateTime), + Column("null_col", String(50)), + Column("non_uniformly_distributed_col", Integer), + ) + metadata.create_all(engine) + data.to_sql(table.name, con=engine, if_exists="replace", index=False) diff --git a/tests/integration/sql/docker-compose/docker-compose.yml b/tests/integration/sql/docker-compose/docker-compose.yml new file mode 100644 index 0000000000..11c391b0d3 --- /dev/null +++ b/tests/integration/sql/docker-compose/docker-compose.yml @@ -0,0 +1,36 @@ +version: '3.7' +services: + trino: + image: trinodb/trino + container_name: trino + ports: + - 8080:8080 + + postgres: + image: postgres:latest + container_name: postgres + environment: + 
POSTGRES_DB: postgres + POSTGRES_USER: username + POSTGRES_PASSWORD: password + ports: + - 5432:5432 + volumes: + - postgres_data:/var/lib/postgresql/data + + mysql: + image: mysql:latest + container_name: mysql + environment: + MYSQL_DATABASE: mysql + MYSQL_USER: username + MYSQL_PASSWORD: password + MYSQL_ROOT_PASSWORD: rootpassword + ports: + - 3306:3306 + volumes: + - mysql_data:/var/lib/mysql + +volumes: + postgres_data: + mysql_data: diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py new file mode 100644 index 0000000000..d664c2ef80 --- /dev/null +++ b/tests/integration/sql/test_sql.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import datetime +import math + +import pandas as pd +import pytest + +import daft +from daft.context import set_execution_config +from tests.conftest import assert_df_equals +from tests.integration.sql.conftest import TEST_TABLE_NAME + + +@pytest.mark.integration() +def test_sql_show(test_db) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + df.show() + + +@pytest.mark.integration() +def test_sql_create_dataframe_ok(test_db) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [2, 3, 4]) +def test_sql_partitioned_read(test_db, num_partitions) -> None: + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + + row_size_bytes = daft.from_pandas(pdf).schema().estimate_row_size_bytes() + num_rows_per_partition = len(pdf) / num_partitions + set_execution_config(read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition)) + + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id") + assert df.num_partitions() == num_partitions + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2, 3, 4]) +@pytest.mark.parametrize("partition_col", ["id", "float_col", "date_col", "date_time_col"]) +def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col( + test_db, num_partitions, partition_col +) -> None: + df = daft.read_sql( + f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col=partition_col, num_partitions=num_partitions + ) + assert df.num_partitions() == num_partitions + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2, 3, 4]) +def test_sql_partitioned_read_with_non_uniformly_distributed_column(test_db, num_partitions) -> None: + df = daft.read_sql( + f"SELECT * FROM {TEST_TABLE_NAME}", + test_db, + partition_col="non_uniformly_distributed_col", + num_partitions=num_partitions, + ) + assert df.num_partitions() == num_partitions + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + +@pytest.mark.integration() +@pytest.mark.parametrize("partition_col", ["string_col", "time_col", "null_col"]) +def test_sql_partitioned_read_with_non_partionable_column(test_db, partition_col) -> None: + with pytest.raises(ValueError, match="Failed to get partition bounds"): + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col=partition_col, num_partitions=2) + df = df.collect() + + 
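Before the pushdown tests below, a short illustrative sketch (not code from this patch) of what `SQLScanOperator.to_scan_tasks` emits once partition bounds are known may help reviewers; the helper name and the sample bounds are invented for illustration.

```python
from __future__ import annotations


def build_partition_queries(sql: str, partition_col: str, bounds: list[str]) -> list[str]:
    # Mirrors the bound handling in sql_scan.py: the first task takes everything
    # up to the first bound, the last task everything above the last bound, and
    # middle tasks a half-open range between consecutive bounds.
    num_tasks = len(bounds) + 1
    queries = []
    for i in range(num_tasks):
        if i == 0:
            predicate = f"{partition_col} <= {bounds[0]}"
        elif i == num_tasks - 1:
            predicate = f"{partition_col} > {bounds[-1]}"
        else:
            predicate = f"{partition_col} > {bounds[i - 1]} AND {partition_col} <= {bounds[i]}"
        queries.append(f"SELECT * FROM ({sql}) AS subquery WHERE {predicate}")
    return queries


# With percentile bounds at id=66 and id=133, three bounded scan tasks are produced.
for q in build_partition_queries("SELECT * FROM example", "id", ["66", "133"]):
    print(q)
```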
+@pytest.mark.integration() +def test_sql_read_with_partition_num_without_partition_col(test_db) -> None: + with pytest.raises(ValueError, match="Failed to execute sql"): + daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, num_partitions=2) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "operator", + ["<", ">", "=", "!=", ">=", "<="], +) +@pytest.mark.parametrize( + "column, value", + [ + ("id", 100), + ("float_col", 100.0), + ("string_col", "row_100"), + ("bool_col", True), + ("date_col", datetime.date(2021, 1, 1)), + ("date_time_col", datetime.datetime(2020, 1, 1, 10, 0, 0)), + ], +) +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + + if operator == ">": + df = df.where(df[column] > value) + pdf = pdf[pdf[column] > value] + elif operator == "<": + df = df.where(df[column] < value) + pdf = pdf[pdf[column] < value] + elif operator == "=": + df = df.where(df[column] == value) + pdf = pdf[pdf[column] == value] + elif operator == "!=": + df = df.where(df[column] != value) + pdf = pdf[pdf[column] != value] + elif operator == ">=": + df = df.where(df[column] >= value) + pdf = pdf[pdf[column] >= value] + elif operator == "<=": + df = df.where(df[column] <= value) + pdf = pdf[pdf[column] <= value] + + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_is_null_filter_pushdowns(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) + df = df.where(df["null_col"].is_null()) + + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + pdf = pdf[pdf["null_col"].isnull()] + + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_not_null_filter_pushdowns(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) + df = df.where(df["null_col"].not_null()) + + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + pdf = pdf[pdf["null_col"].notnull()] + + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_if_else_filter_pushdown(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) + df = df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50)) + + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", test_db) + pdf = pdf[(pdf["id"] > 100) & (pdf["float_col"] > 150) | (pdf["float_col"] < 50)] + + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_is_in_filter_pushdown(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) + df = df.where(df["id"].is_in([1, 2, 3])) + + pdf = pd.read_sql_query(f"SELECT * FROM {TEST_TABLE_NAME}", 
test_db) + pdf = pdf[pdf["id"].isin([1, 2, 3])] + assert_df_equals(df.to_pandas(), pdf, sort_key="id") + + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_all_pushdowns(test_db, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) + df = df.where(~(df["id"] < 1)) + df = df.where(df["string_col"].is_in([f"row_{i}" for i in range(10)])) + df = df.select(df["id"], df["float_col"], df["string_col"]) + df = df.limit(5) + + df = df.collect() + assert len(df) == 5 + assert df.column_names == ["id", "float_col", "string_col"] + + pydict = df.to_pydict() + assert all(i >= 1 for i in pydict["id"]) + assert all(s in [f"row_{i}" for i in range(10)] for s in pydict["string_col"]) + + +@pytest.mark.integration() +@pytest.mark.parametrize("limit", [0, 1, 10, 100, 200]) +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_limit_pushdown(test_db, limit, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) + df = df.limit(limit) + + df = df.collect() + assert len(df) == limit + + +@pytest.mark.integration() +@pytest.mark.parametrize("num_partitions", [1, 2]) +def test_sql_read_with_projection_pushdown(test_db, generated_data, num_partitions) -> None: + df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions) + df = df.select(df["id"], df["string_col"]) + + df = df.collect() + assert df.column_names == ["id", "string_col"] + assert len(df) == len(generated_data) + + +@pytest.mark.integration() +def test_sql_bad_url() -> None: + with pytest.raises(RuntimeError, match="Failed to execute sql"): + daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", "bad_url://") diff --git a/tests/io/test_csv_roundtrip.py b/tests/io/test_csv_roundtrip.py index 2043e30a1b..dd288e6806 100644 --- a/tests/io/test_csv_roundtrip.py +++ b/tests/io/test_csv_roundtrip.py @@ -50,6 +50,12 @@ # NOTE: Seems like the inferred type is seconds because it's written with seconds resolution DataType.timestamp(TimeUnit.s()), ), + ( + [datetime.date(1994, 1, 1), datetime.date(1995, 1, 1), None], + pa.date64(), + DataType.timestamp(TimeUnit.ms()), + DataType.timestamp(TimeUnit.s()), + ), ( [datetime.timedelta(days=1), datetime.timedelta(days=2), None], pa.duration("ms"), diff --git a/tests/io/test_parquet_roundtrip.py b/tests/io/test_parquet_roundtrip.py index 81569f1be2..1b8f8df619 100644 --- a/tests/io/test_parquet_roundtrip.py +++ b/tests/io/test_parquet_roundtrip.py @@ -42,6 +42,7 @@ pa.timestamp("ms"), DataType.timestamp(TimeUnit.ms()), ), + ([datetime.date(1994, 1, 1), datetime.date(1995, 1, 1), None], pa.date64(), DataType.timestamp(TimeUnit.ms())), ( [datetime.timedelta(days=1), datetime.timedelta(days=2), None], pa.duration("ms"), diff --git a/tests/table/test_from_py.py b/tests/table/test_from_py.py index 95012309be..bc8d2be888 100644 --- a/tests/table/test_from_py.py +++ b/tests/table/test_from_py.py @@ -94,6 +94,7 @@ "binary": pa.array(PYTHON_TYPE_ARRAYS["binary"], pa.binary()), "boolean": pa.array(PYTHON_TYPE_ARRAYS["bool"], pa.bool_()), "date32": pa.array(PYTHON_TYPE_ARRAYS["date"], pa.date32()), + "date64": pa.array(PYTHON_TYPE_ARRAYS["date"], pa.date64()), "time64_microseconds": pa.array(PYTHON_TYPE_ARRAYS["time"], pa.time64("us")), "time64_nanoseconds": pa.array(PYTHON_TYPE_ARRAYS["time"], pa.time64("ns")), "list": 
pa.array(PYTHON_TYPE_ARRAYS["list"], pa.list_(pa.int64())), @@ -149,6 +150,7 @@ "binary": pa.large_binary(), "boolean": pa.bool_(), "date32": pa.date32(), + "date64": pa.timestamp("ms"), "time64_microseconds": pa.time64("us"), "time64_nanoseconds": pa.time64("ns"), "list": pa.large_list(pa.int64()),
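As a closing reference, here is an illustrative mirror (not the patch's own code) of how `SQLReader._construct_sql_query` combines the user's query with pushdowns: the original SQL is wrapped as a subquery, and the projection, the predicate rendered by `Expr::to_sql`, and the limit become clauses around it. The helper name and sample values are invented; the table name `example` only echoes the integration-test fixture.

```python
from __future__ import annotations


def construct_pushdown_sql(
    sql: str,
    projection: list[str] | None = None,
    predicate: str | None = None,
    limit: int | None = None,
) -> str:
    # SELECT clause comes from the column pushdown, or falls back to SELECT *.
    clauses = [f"SELECT {', '.join(projection)}" if projection else "SELECT *"]
    # The user's query is always wrapped, so pushdowns compose with arbitrary SQL.
    clauses.append(f"FROM ({sql}) AS subquery")
    if predicate is not None:
        clauses.append(f"WHERE {predicate}")
    if limit is not None:
        clauses.append(f"LIMIT {limit}")
    return "\n".join(clauses)


# A Daft filter like df["id"] > 100 renders to "id > 100", so a filtered,
# projected, limited read issues roughly:
#   SELECT id, float_col
#   FROM (SELECT * FROM example) AS subquery
#   WHERE id > 100
#   LIMIT 5
print(construct_pushdown_sql("SELECT * FROM example", ["id", "float_col"], "id > 100", 5))
```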