Skip to content

Commit

Permalink
test: merge sync/async tests and fix warnings (#28)
Browse files Browse the repository at this point in the history
This commit merges our sync and async tests so that now they use the
`maybe_await` function to reuse logic.
  • Loading branch information
serramatutu authored Jul 9, 2024
1 parent ff875b2 commit 6e39e83
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 54 deletions.
3 changes: 3 additions & 0 deletions .changes/unreleased/Test-20240709-114648.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Test
body: Fix warnings and failing integration test
time: 2024-07-09T11:46:48.042286+02:00
3 changes: 3 additions & 0 deletions tests/api/graphql/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import AsyncMock, MagicMock, call

import pyarrow as pa
import pytest
from pytest_mock import MockerFixture

from dbtsl.api.graphql.client.asyncio import AsyncGraphQLClient
Expand Down Expand Up @@ -80,6 +81,8 @@ async def run_behavior(op: ProtocolOperation, query_id: QueryId, page_num: int)
assert result_table.equals(table, check_metadata=True)


# avoid raising mock warning related to mocking a context manager
@pytest.mark.filterwarnings("ignore::pytest_mock.PytestMockWarning")
def test_sync_query_multiple_pages(mocker: MockerFixture) -> None:
"""Test that querying a dataframe with multiple pages works."""
client = SyncGraphQLClient(server_host="test", environment_id=0, auth_token="test")
Expand Down
15 changes: 1 addition & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Callable, Iterator, Union, cast
from typing import Callable, Union, cast

import pytest
from gql import Client, gql
Expand Down Expand Up @@ -72,15 +71,3 @@ def from_env(cls) -> "Credentials":
@pytest.fixture(scope="session")
def credentials() -> Credentials:
return Credentials.from_env()


@pytest.fixture(scope="session")
def event_loop() -> Iterator[asyncio.AbstractEventLoop]:
"""Override pytest-asyncio's default `event_loop` fixture.
We add scope='session' to make all tests share the same event loop.
This avoids concurrency issues related to opening and closing sessions.
"""
loop = asyncio.get_event_loop()
yield loop
loop.close()
84 changes: 44 additions & 40 deletions tests/integration/test_gql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import AsyncIterator, Iterator
import inspect
from typing import AsyncIterator, Awaitable, Iterator, TypeVar, Union

import pytest

Expand All @@ -7,8 +8,25 @@

from ..conftest import Credentials

BothClients = Union[SyncGraphQLClient, AsyncGraphQLClient]

@pytest.fixture(scope="session")

def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
if "client" in metafunc.fixturenames:
metafunc.parametrize("client", ["sync", "async"], indirect=True)


T = TypeVar("T")


async def maybe_await(coro: Union[Awaitable[T], T]) -> T:
if inspect.iscoroutine(coro):
return await coro

return coro # type: ignore


@pytest.fixture(scope="module")
async def async_client(credentials: Credentials) -> AsyncIterator[AsyncGraphQLClient]:
client = AsyncGraphQLClient(
environment_id=credentials.environment_id,
Expand All @@ -19,7 +37,7 @@ async def async_client(credentials: Credentials) -> AsyncIterator[AsyncGraphQLCl
yield client


@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def sync_client(credentials: Credentials) -> Iterator[SyncGraphQLClient]:
client = SyncGraphQLClient(
environment_id=credentials.environment_id,
Expand All @@ -30,59 +48,45 @@ def sync_client(credentials: Credentials) -> Iterator[SyncGraphQLClient]:
yield client


def test_sync_client_lists_metrics_dimensions_entities(sync_client: SyncGraphQLClient) -> None:
metrics = sync_client.metrics()
assert len(metrics) > 0
@pytest.fixture(scope="module")
async def client(
request: pytest.FixtureRequest, sync_client: SyncGraphQLClient, async_client: AsyncGraphQLClient
) -> BothClients:
if request.param == "sync":
return sync_client

dims = sync_client.dimensions(metrics=[metrics[0].name])
assert len(dims) > 0
assert dims == metrics[0].dimensions
return async_client

entities = sync_client.entities(metrics=[metrics[0].name])
assert len(entities) > 0
assert entities == metrics[0].entities

pytestmark = pytest.mark.asyncio(scope="module")


async def test_async_client_lists_metrics_dimensions_entities(async_client: AsyncGraphQLClient) -> None:
metrics = await async_client.metrics()
async def test_client_lists_metrics_dimensions_entities(client: BothClients) -> None:
metrics = await maybe_await(client.metrics())
assert len(metrics) > 0

dims = await async_client.dimensions(metrics=[metrics[0].name])
dims = await maybe_await(client.dimensions(metrics=[metrics[0].name]))
assert len(dims) > 0
assert dims == metrics[0].dimensions

entities = await async_client.entities(metrics=[metrics[0].name])
entities = await maybe_await(client.entities(metrics=[metrics[0].name]))
assert len(entities) > 0
assert entities == metrics[0].entities


def test_sync_client_lists_saved_queries(sync_client: SyncGraphQLClient) -> None:
sqs = sync_client.saved_queries()
assert len(sqs) > 0


async def test_async_client_lists_saved_queries(async_client: AsyncGraphQLClient) -> None:
sqs = await async_client.saved_queries()
async def test_client_lists_saved_queries(client: BothClients) -> None:
sqs = await maybe_await(client.saved_queries())
assert len(sqs) > 0


def test_sync_client_query_works(sync_client: SyncGraphQLClient) -> None:
metrics = sync_client.metrics()
assert len(metrics) > 0
table = sync_client.query(
metrics=[metrics[0].name],
group_by=["metric_time"],
limit=1,
)
assert len(table) > 0


async def test_async_client_query_works(async_client: AsyncGraphQLClient) -> None:
metrics = await async_client.metrics()
async def test_client_query_works(client: BothClients) -> None:
metrics = await maybe_await(client.metrics())
assert len(metrics) > 0
table = await async_client.query(
metrics=[metrics[0].name],
group_by=["metric_time"],
limit=1,
table = await maybe_await(
client.query(
metrics=[metrics[0].name],
group_by=["metric_time"],
limit=1,
)
)
assert len(table) > 0

0 comments on commit 6e39e83

Please sign in to comment.