From 459ba82bf71693c6a57f854a1ff4eda77abe433f Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Tue, 22 Oct 2024 17:17:00 -0700 Subject: [PATCH] [BUG] Add over clause in read_sql percentile reads (#3094) Addresses: https://github.com/Eventual-Inc/Daft/issues/3075 SQL Server requires an `OVER` clause to be specified in percentile queries (because it's a window function). `read_sql` uses percentiles to determine partition bounds. Adds AzureSqlEdge as a test database. Might as well, since a lot of people use us to read SQL Server and have had bugs with it. Kind of a pain to get it set up since it requires ODBC and drivers etc., but it works. It's also not much of a hit on CI times: installing drivers takes around 15s and the extra tests take around 5s. Additionally made some modifications to some tests and pushdowns, and left comments on the rationale. --------- Co-authored-by: Colin Ho Co-authored-by: Colin Ho --- .github/workflows/python-package.yml | 6 ++++++ daft/sql/sql_scan.py | 8 ++++++-- requirements-dev.txt | 1 + src/daft-dsl/src/expr/mod.rs | 16 ++-------------- tests/integration/sql/conftest.py | 1 + .../sql/docker-compose/docker-compose.yml | 12 ++++++++++++ tests/integration/sql/test_sql.py | 8 +++++++- 7 files changed, 35 insertions(+), 17 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 3affeecc4c..0a7d2de10a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -582,6 +582,12 @@ jobs: run: | uv pip install -r requirements-dev.txt dist/${{ env.package-name }}-*x86_64*.whl --force-reinstall rm -rf daft + - name: Install ODBC Driver 18 for SQL Server + run: | + curl https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add - + sudo add-apt-repository https://packages.microsoft.com/ubuntu/$(lsb_release -rs)/prod + sudo apt-get update + sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18 - name: Spin up services run: | pushd 
./tests/integration/sql/docker-compose/ diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 4f0f9a35c7..4d3156ae80 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -161,14 +161,18 @@ def _get_num_rows(self) -> int: def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]: try: - # Try to get percentiles using percentile_cont + # Try to get percentiles using percentile_disc. + # Favor percentile_disc over percentile_cont because we want exact values to do <= and >= comparisons. percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)] + # Use the OVER clause for SQL Server + over_clause = "OVER ()" if self.conn.dialect in ["mssql", "tsql"] else "" percentile_sql = self.conn.construct_sql_query( self.sql, projection=[ - f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}" + f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) {over_clause} AS bound_{i}" for i, percentile in enumerate(percentiles) ], + limit=1, ) pa_table = self.conn.execute_sql_query(percentile_sql) return pa_table, PartitionBoundStrategy.PERCENTILE diff --git a/requirements-dev.txt b/requirements-dev.txt index 9c7809ac80..3ab91623eb 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -66,6 +66,7 @@ trino[sqlalchemy]==0.328.0; python_version >= '3.8' PyMySQL==1.1.0; python_version >= '3.8' psycopg2-binary==2.9.9; python_version >= '3.8' sqlglot==23.3.0; python_version >= '3.8' +pyodbc==5.1.0; python_version >= '3.8' # AWS s3fs==2023.12.0; python_version >= '3.8' diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 873f9013bd..567a2d35d8 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -990,21 +990,9 @@ impl Expr { to_sql_inner(inner, buffer)?; write!(buffer, ") IS NOT NULL") } - Expr::IfElse { - if_true, - if_false, - predicate, - } => { - write!(buffer, "CASE WHEN ")?; - 
to_sql_inner(predicate, buffer)?; - write!(buffer, " THEN ")?; - to_sql_inner(if_true, buffer)?; - write!(buffer, " ELSE ")?; - to_sql_inner(if_false, buffer)?; - write!(buffer, " END") - } // TODO: Implement SQL translations for these expressions if possible - Expr::Agg(..) + Expr::IfElse { .. } + | Expr::Agg(..) | Expr::Cast(..) | Expr::IsIn(..) | Expr::Between(..) diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index f5c01dccc6..e202eed471 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -26,6 +26,7 @@ "trino://user@localhost:8080/memory/default", "postgresql://username:password@localhost:5432/postgres", "mysql+pymysql://username:password@localhost:3306/mysql", + "mssql+pyodbc://SA:StrongPassword!@127.0.0.1:1433/master?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes", ] TEST_TABLE_NAME = "example" EMPTY_TEST_TABLE_NAME = "empty_table" diff --git a/tests/integration/sql/docker-compose/docker-compose.yml b/tests/integration/sql/docker-compose/docker-compose.yml index 11c391b0d3..b8eb8c3eba 100644 --- a/tests/integration/sql/docker-compose/docker-compose.yml +++ b/tests/integration/sql/docker-compose/docker-compose.yml @@ -31,6 +31,18 @@ services: volumes: - mysql_data:/var/lib/mysql + azuresqledge: + image: mcr.microsoft.com/azure-sql-edge + container_name: azuresqledge + environment: + ACCEPT_EULA: "Y" + MSSQL_SA_PASSWORD: "StrongPassword!" 
+ ports: + - 1433:1433 + volumes: + - azuresqledge_data:/var/opt/mssql + volumes: postgres_data: mysql_data: + azuresqledge_data: diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py index ff02ebaac4..7983be00c7 100644 --- a/tests/integration/sql/test_sql.py +++ b/tests/integration/sql/test_sql.py @@ -141,6 +141,10 @@ def test_sql_read_with_partition_num_without_partition_col(test_db) -> None: ) @pytest.mark.parametrize("num_partitions", [1, 2]) def test_sql_read_with_binary_filter_pushdowns(test_db, column, operator, value, num_partitions, pdf) -> None: + # Skip invalid comparisons for bool_col + if column == "bool_col" and operator not in ("=", "!="): + pytest.skip(f"Operator {operator} not valid for bool_col") + df = daft.read_sql( f"SELECT * FROM {TEST_TABLE_NAME}", test_db, @@ -204,13 +208,15 @@ def test_sql_read_with_not_null_filter_pushdowns(test_db, num_partitions, pdf) - @pytest.mark.integration() @pytest.mark.parametrize("num_partitions", [1, 2]) -def test_sql_read_with_if_else_filter_pushdown(test_db, num_partitions, pdf) -> None: +def test_sql_read_with_non_pushdowned_predicate(test_db, num_partitions, pdf) -> None: df = daft.read_sql( f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id", num_partitions=num_partitions, ) + + # If_else is not supported as a pushdown to read_sql, but it should still work df = df.where((df["id"] > 100).if_else(df["float_col"] > 150, df["float_col"] < 50)) pdf = pdf[(pdf["id"] > 100) & (pdf["float_col"] > 150) | (pdf["float_col"] < 50)]