diff --git a/tests/integration/sql/conftest.py b/tests/integration/sql/conftest.py index 82ca03e8d7..f20e7bb2f6 100644 --- a/tests/integration/sql/conftest.py +++ b/tests/integration/sql/conftest.py @@ -1,13 +1,14 @@ from __future__ import annotations import random +from typing import Generator import numpy as np import pytest import sqlalchemy import tenacity -TRINO_URL = "trino://user@localhost:8080/tpch" +URLS = {"trino": "trino://user@localhost:8080/tpch"} NUM_TEST_ROWS = 200 @@ -36,10 +37,14 @@ def check_database_connection(url) -> None: @pytest.fixture(scope="session") -@pytest.mark.parametrize("url", [TRINO_URL]) -def check_db_server_initialized(url) -> bool: - try: - check_database_connection(url) - return True - except Exception as e: - pytest.fail(f"Failed to connect to {url}: {e}") +def db_url() -> Generator[str, None, None]: + for url in URLS.values(): + try: + check_database_connection(url) + except Exception as e: + pytest.fail(f"Failed to connect to {url}: {e}") + + def db_url(db): + return URLS[db] + + yield db_url diff --git a/tests/integration/sql/test_databases.py b/tests/integration/sql/test_databases.py new file mode 100644 index 0000000000..e98f11d72e --- /dev/null +++ b/tests/integration/sql/test_databases.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import pandas as pd +import pytest + +import daft + + +@pytest.mark.integration() +def test_trino_create_dataframe_ok(db_url) -> None: + url = db_url("trino") + df = daft.read_sql("SELECT * FROM tpch.sf1.nation", url) + pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", url) + + assert df.equals(pd_df) diff --git a/tests/integration/sql/test_sql_lite.py b/tests/integration/sql/test_operations.py similarity index 99% rename from tests/integration/sql/test_sql_lite.py rename to tests/integration/sql/test_operations.py index 990ce87603..c62df97214 100644 --- a/tests/integration/sql/test_sql_lite.py +++ b/tests/integration/sql/test_operations.py @@ -36,7 +36,7 @@ def temp_sqllite_db(test_items): def test_sqllite_create_dataframe_ok(temp_sqllite_db) -> None: df = daft.read_sql( "SELECT * FROM iris", f"sqlite://{temp_sqllite_db}" - ) # path here only has 2 slashes instead of 3 because connectorx is used + ) # path here only has 2 slashes instead of 3 because connectorx uses 2 slashes pd_df = pd.read_sql("SELECT * FROM iris", f"sqlite:///{temp_sqllite_db}") assert df.to_pandas().equals(pd_df) diff --git a/tests/integration/sql/test_trino.py b/tests/integration/sql/test_trino.py deleted file mode 100644 index 99efa8e990..0000000000 --- a/tests/integration/sql/test_trino.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -import pandas as pd -import pytest - -import daft -from tests.integration.sql.conftest import TRINO_URL - - -@pytest.mark.integration() -def test_trino_create_dataframe_ok(check_db_server_initialized) -> None: - if check_db_server_initialized: - df = daft.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL) - pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL) - assert df.to_pandas().equals(pd_df)