Skip to content

Commit

Permalink
move retries out of fixture
Browse files Browse the repository at this point in the history
  • Loading branch information
colin-ho committed Feb 24, 2024
1 parent f0edfde commit c4fbef6
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 49 deletions.
33 changes: 28 additions & 5 deletions tests/integration/sql/conftest.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,45 @@
from __future__ import annotations

import random

import numpy as np
import pytest
import sqlalchemy
import tenacity

TRINO_URL = "trino://user@localhost:8080/tpch"

NUM_TEST_ROWS = 200


@pytest.fixture(scope="session")
def test_items():
np.random.seed(42)
data = {
"sepal_length": np.round(np.random.uniform(4.3, 7.9, NUM_TEST_ROWS), 1),
"sepal_width": np.round(np.random.uniform(2.0, 4.4, NUM_TEST_ROWS), 1),
"petal_length": np.round(np.random.uniform(1.0, 6.9, NUM_TEST_ROWS), 1),
"petal_width": np.round(np.random.uniform(0.1, 2.5, NUM_TEST_ROWS), 1),
"variety": [random.choice(["Setosa", "Versicolor", "Virginica"]) for _ in range(NUM_TEST_ROWS)],
}
return data


@tenacity.retry(
stop=tenacity.stop_after_delay(60),
wait=tenacity.wait_fixed(5),
reraise=True,
)
def check_database_connection(url) -> None:
with sqlalchemy.create_engine(url).connect() as conn:
conn.execute("SELECT 1")


@pytest.fixture(scope="session")
def check_db_server_initialized() -> None:
@pytest.mark.parametrize("url", [TRINO_URL])
def check_db_server_initialized(url) -> bool:
try:
with sqlalchemy.create_engine(TRINO_URL).connect() as conn:
conn.execute(sqlalchemy.text("SELECT 1"))
check_database_connection(url)
return True
except Exception as e:
print(f"Connection failed with exception: {e}")
raise
pytest.fail(f"Failed to connect to {url}: {e}")
55 changes: 14 additions & 41 deletions tests/integration/sql/test_sql_lite.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,32 @@
from __future__ import annotations

import random
import sqlite3
import tempfile

import numpy as np
import pandas as pd
import pytest

import daft

COL_NAMES = ["sepal_length", "sepal_width", "petal_length", "petal_width", "variety"]
VARIETIES = ["Setosa", "Versicolor", "Virginica"]
CREATE_TABLE_SQL = """
CREATE TABLE iris (
sepal_length REAL,
sepal_width REAL,
petal_length REAL,
petal_width REAL,
variety TEXT
)
"""
INSERT_SQL = "INSERT INTO iris VALUES (?, ?, ?, ?, ?)"
NUM_ITEMS = 200


def generate_test_items(num_items):
np.random.seed(42)
data = {
"sepal_length": np.round(np.random.uniform(4.3, 7.9, num_items), 1),
"sepal_width": np.round(np.random.uniform(2.0, 4.4, num_items), 1),
"petal_length": np.round(np.random.uniform(1.0, 6.9, num_items), 1),
"petal_width": np.round(np.random.uniform(0.1, 2.5, num_items), 1),
"variety": [random.choice(VARIETIES) for _ in range(num_items)],
}
return [
(
data["sepal_length"][i],
data["sepal_width"][i],
data["petal_length"][i],
data["petal_width"][i],
data["variety"][i],
)
for i in range(num_items)
]


# Fixture for temporary SQLite database
@pytest.fixture(scope="module")
def temp_sqllite_db():
test_items = generate_test_items(NUM_ITEMS)
def temp_sqllite_db(test_items):
data = list(
zip(
test_items["sepal_length"],
test_items["sepal_width"],
test_items["petal_length"],
test_items["petal_width"],
test_items["variety"],
)
)
with tempfile.NamedTemporaryFile(suffix=".db") as file:
connection = sqlite3.connect(file.name)
connection.execute(CREATE_TABLE_SQL)
connection.executemany(INSERT_SQL, test_items)
connection.execute(
"CREATE TABLE iris (sepal_length REAL, sepal_width REAL, petal_length REAL, petal_width REAL, variety TEXT)"
)
connection.executemany("INSERT INTO iris VALUES (?, ?, ?, ?, ?)", data)
connection.commit()
connection.close()
yield file.name
Expand Down
7 changes: 4 additions & 3 deletions tests/integration/sql/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

@pytest.mark.integration()
def test_trino_create_dataframe_ok(check_db_server_initialized) -> None:
df = daft.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL)
pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL)
assert df.to_pandas().equals(pd_df)
if check_db_server_initialized:
df = daft.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL)
pd_df = pd.read_sql("SELECT * FROM tpch.sf1.nation", TRINO_URL)
assert df.to_pandas().equals(pd_df)

0 comments on commit c4fbef6

Please sign in to comment.