Skip to content

Commit

Permalink
refactor: change how clients poll and create queries (#21)
Browse files Browse the repository at this point in the history
* internal: remove unnecessary logic for `create_query` and `get_query_result`

* refactor: make `_poll_until_complete` generic

This will allow us to poll for the completion of many types of jobs, not
only queries.

* refactor: refactor internal polling for sync client also

Same as previous commit, but now for the sync client.
  • Loading branch information
serramatutu authored Jun 20, 2024
1 parent 75066cd commit 6bf878e
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 65 deletions.
3 changes: 3 additions & 0 deletions .changes/unreleased/Under the Hood-20240620-130850.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Under the Hood
body: Change how clients create and poll for query results
time: 2024-06-20T13:08:50.067409+02:00
37 changes: 16 additions & 21 deletions dbtsl/api/graphql/client/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@

from dbtsl.api.graphql.client.base import BaseGraphQLClient
from dbtsl.api.graphql.protocol import (
GetQueryResultVariables,
ProtocolOperation,
TJobStatusResult,
TJobStatusVariables,
TResponse,
TVariables,
)
from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.backoff import ExponentialBackoff
from dbtsl.error import QueryFailedError
from dbtsl.models.query import QueryId, QueryResult, QueryStatus
from dbtsl.models.query import QueryId, QueryStatus


class AsyncGraphQLClient(BaseGraphQLClient[AIOHTTPTransport, AsyncClientSession]):
Expand Down Expand Up @@ -69,35 +70,29 @@ async def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVa
variables = op.get_request_variables(environment_id=self.environment_id, **kwargs)
gql_query = gql(raw_query)

res = await self._gql_session.execute(gql_query, variable_values=variables)
try:
res = await self._gql_session.execute(gql_query, variable_values=variables)
except Exception as err:
raise self._refine_err(err)

return op.parse_response(res)

async def _create_query(self, **params: Unpack[QueryParameters]) -> QueryId:
"""Create a query that will run asynchronously."""
return await self._run(self.PROTOCOL.create_query, **params) # type: ignore

async def _get_query_result(self, **params: Unpack[GetQueryResultVariables]) -> QueryResult:
"""Fetch a query's results'."""
return await self._run(self.PROTOCOL.get_query_result, **params) # type: ignore

async def _poll_until_complete(
self,
query_id: QueryId,
poll_op: ProtocolOperation[TJobStatusVariables, TJobStatusResult],
backoff: Optional[ExponentialBackoff] = None,
) -> QueryResult:
"""Poll for a query's results until it is in a completed state (SUCCESSFUL or FAILED).
Note that this function does NOT fetch all pages in case the query is SUCCESSFUL. It only
returns once the query is done. Callers must implement this logic themselves.
"""
**kwargs,
) -> TJobStatusResult:
"""Poll for a job's results until it is in a completed state (SUCCESSFUL or FAILED)."""
if backoff is None:
backoff = self._default_backoff()

for sleep_ms in backoff.iter_ms():
# TODO: add timeout param to all requests because technically the API could hang and
# then we don't respect timeout.
qr = await self._get_query_result(query_id=query_id, page_num=1)
kwargs["query_id"] = query_id
qr = await self._run(poll_op, **kwargs)
if qr.status in (QueryStatus.SUCCESSFUL, QueryStatus.FAILED):
return qr

Expand All @@ -108,8 +103,8 @@ async def _poll_until_complete(

async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
"""Query the Semantic Layer."""
query_id = await self._create_query(**params)
first_page_results = await self._poll_until_complete(query_id)
query_id = await self.create_query(**params)
first_page_results = await self._poll_until_complete(query_id, self.PROTOCOL.get_query_result, page_num=1)
if first_page_results.status != QueryStatus.SUCCESSFUL:
raise QueryFailedError()

Expand All @@ -119,7 +114,7 @@ async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
return first_page_results.result_table

tasks = [
self._get_query_result(query_id=query_id, page_num=page)
self.get_query_result(query_id=query_id, page_num=page)
for page in range(2, first_page_results.total_pages + 1)
]
all_page_results = [first_page_results] + await asyncio.gather(*tasks)
Expand Down
27 changes: 10 additions & 17 deletions dbtsl/api/graphql/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
import dbtsl.env as env
from dbtsl.api.graphql.protocol import (
GraphQLProtocol,
ProtocolOperation,
TResponse,
TVariables,
)
from dbtsl.backoff import ExponentialBackoff
from dbtsl.error import AuthError
Expand Down Expand Up @@ -72,20 +69,16 @@ def _create_transport(self, url: str, headers: Dict[str, str]) -> TTransport:
"""Create the underlying transport to be used by the gql Client."""
raise NotImplementedError()

@abstractmethod
def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVariables) -> TResponse:
raise NotImplementedError()

def _run_err_wrapper(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVariables) -> TResponse:
try:
return self._run(op, **kwargs)
except TransportQueryError as err:
# TODO: we should probably return an error type that has an Enum from GraphQL
# instead of depending on error messages
if err.errors is not None and err.errors[0]["message"] == "User is not authorized":
raise AuthError(err.args)
def _refine_err(self, err: Exception) -> Exception:
    """Refine a generic exception that might have happened during `_run`.

    Maps a gql `TransportQueryError` carrying the API's "User is not
    authorized" message onto the library's `AuthError`. Any other exception
    is returned unchanged so the caller can re-raise it as-is.

    Note: returns (rather than raises) the exception — callers are expected
    to `raise self._refine_err(err)`.
    """
    if (
        isinstance(err, TransportQueryError)
        and err.errors is not None
        and err.errors[0]["message"] == "User is not authorized"
    ):
        return AuthError(err.args)

    return err

@property
def _gql_session(self) -> TSession:
Expand All @@ -110,7 +103,7 @@ def __getattr__(self, attr: str) -> Any:
raise AttributeError()

return functools.partial(
self._run_err_wrapper,
self._run,
op=op,
)

Expand Down
31 changes: 15 additions & 16 deletions dbtsl/api/graphql/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@

from dbtsl.api.graphql.client.base import BaseGraphQLClient
from dbtsl.api.graphql.protocol import (
GetQueryResultVariables,
ProtocolOperation,
TJobStatusResult,
TJobStatusVariables,
TResponse,
TVariables,
)
from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.backoff import ExponentialBackoff
from dbtsl.error import QueryFailedError
from dbtsl.models.query import QueryId, QueryResult, QueryStatus
from dbtsl.models.query import QueryId, QueryStatus


class SyncGraphQLClient(BaseGraphQLClient[RequestsHTTPTransport, SyncClientSession]):
Expand Down Expand Up @@ -69,23 +70,20 @@ def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVariable
variables = op.get_request_variables(environment_id=self.environment_id, **kwargs)
gql_query = gql(raw_query)

res = self._gql_session.execute(gql_query, variable_values=variables)
try:
res = self._gql_session.execute(gql_query, variable_values=variables)
except Exception as err:
raise self._refine_err(err)

return op.parse_response(res)

def _create_query(self, **params: Unpack[QueryParameters]) -> QueryId:
"""Create a query that will run asynchronously."""
return self._run(self.PROTOCOL.create_query, **params) # type: ignore

def _get_query_result(self, **params: Unpack[GetQueryResultVariables]) -> QueryResult:
"""Fetch a query's results'."""
return self._run(self.PROTOCOL.get_query_result, **params) # type: ignore

def _poll_until_complete(
self,
query_id: QueryId,
poll_op: ProtocolOperation[TJobStatusVariables, TJobStatusResult],
backoff: Optional[ExponentialBackoff] = None,
) -> QueryResult:
**kwargs,
) -> TJobStatusResult:
"""Poll for a query's results until it is in a completed state (SUCCESSFUL or FAILED).
Note that this function does NOT fetch all pages in case the query is SUCCESSFUL. It only
Expand All @@ -97,7 +95,8 @@ def _poll_until_complete(
for sleep_ms in backoff.iter_ms():
# TODO: add timeout param to all requests because technically the API could hang and
# then we don't respect timeout.
qr = self._get_query_result(query_id=query_id, page_num=1)
kwargs["query_id"] = query_id
qr = self._run(poll_op, **kwargs)
if qr.status in (QueryStatus.SUCCESSFUL, QueryStatus.FAILED):
return qr

Expand All @@ -108,8 +107,8 @@ def _poll_until_complete(

def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
"""Query the Semantic Layer."""
query_id = self._create_query(**params)
first_page_results = self._poll_until_complete(query_id)
query_id = self.create_query(**params)
first_page_results = self._poll_until_complete(query_id, self.PROTOCOL.get_query_result, page_num=1)
if first_page_results.status != QueryStatus.SUCCESSFUL:
raise QueryFailedError()

Expand All @@ -119,7 +118,7 @@ def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
return first_page_results.result_table

results = [
self._get_query_result(query_id=query_id, page_num=page)
self.get_query_result(query_id=query_id, page_num=page)
for page in range(2, first_page_results.total_pages + 1)
]
all_page_results = [first_page_results] + results
Expand Down
29 changes: 26 additions & 3 deletions dbtsl/api/graphql/protocol.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Mapping, TypedDict, TypeVar
from typing import Any, Dict, Generic, List, Mapping, Protocol, TypedDict, TypeVar

from mashumaro.codecs.basic import decode as decode_to_dataclass
from typing_extensions import NotRequired, override

from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.models import Dimension, Measure, Metric
from dbtsl.models.query import QueryId, QueryResult
from dbtsl.models.query import QueryId, QueryResult, QueryStatus


class JobStatusVariables(TypedDict):
    """Variables of operations that will get a job's status.

    Any polling operation passed to `_poll_until_complete` must accept at
    least these request variables.
    """

    # ID of the job (e.g. a query) whose status is being polled.
    query_id: QueryId


class JobStatusResult(Protocol):
    """Result of operations that fetch a job's status.

    Structural (duck) type: any response object that exposes a `status`
    attribute satisfies it, which lets `_poll_until_complete` poll for the
    completion of many kinds of jobs, not only queries.
    """

    @property
    def status(self) -> QueryStatus:
        """The job's current status."""
        raise NotImplementedError()


TJobStatusVariables = TypeVar("TJobStatusVariables", bound=JobStatusVariables, covariant=True)

TJobStatusResult = TypeVar("TJobStatusResult", bound=JobStatusResult, covariant=True)


TVariables = TypeVar("TVariables", bound=Mapping[str, Any])

# TResponse is covariant so that we can annotate something like
#   def func(op: ProtocolOperation[JobStatusVariables, JobStatusResult]) -> JobStatusResult
# and still pass operations whose response type is a subtype of JobStatusResult.
# (The previous non-covariant `TypeVar("TResponse")` definition was removed —
# it was immediately shadowed and served no purpose.)
TResponse = TypeVar("TResponse", covariant=True)


class ProtocolOperation(Generic[TVariables, TResponse], ABC):
Expand Down
103 changes: 95 additions & 8 deletions tests/api/graphql/test_client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import base64
import io
from unittest.mock import call
from unittest.mock import AsyncMock, MagicMock, call

import pyarrow as pa
from pytest_mock import MockerFixture

from dbtsl.api.graphql.client.asyncio import AsyncGraphQLClient
from dbtsl.api.graphql.client.sync import SyncGraphQLClient
from dbtsl.api.graphql.protocol import GraphQLProtocol, ProtocolOperation
from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.models.query import QueryId, QueryResult, QueryStatus

# The following 2 tests are near-copies of each other, since exercising the same
# functionality through both the sync and async clients is hard to share.
# TODO: find a way to deduplicate them later.
#
# NOTE: these tests reach into a lot of client internals; consider refactoring
# them to go through the public API instead.

async def test_query_multiple_pages(mocker: MockerFixture) -> None:

async def test_async_query_multiple_pages(mocker: MockerFixture) -> None:
"""Test that querying a dataframe with multiple pages works."""
client = AsyncGraphQLClient(server_host="test", environment_id=0, auth_token="test")

Expand All @@ -20,7 +27,7 @@ async def test_query_multiple_pages(mocker: MockerFixture) -> None:
)

async def gqr_behavior(query_id: QueryId, page_num: int) -> QueryResult:
"""Behaves like `_get_query_result` but without talking to any servers."""
"""Behaves like `get_query_result` but without talking to any servers."""
call_table = table.slice(offset=page_num - 1, length=1)

byte_stream = io.BytesIO()
Expand All @@ -36,18 +43,98 @@ async def gqr_behavior(query_id: QueryId, page_num: int) -> QueryResult:
arrow_result=base64.b64encode(byte_stream.getvalue()).decode("utf-8"),
)

cq_mock = mocker.patch.object(client, "_create_query", return_value=query_id)
gqr_mock = mocker.patch.object(client, "_get_query_result", side_effect=gqr_behavior)
async def run_behavior(op: ProtocolOperation, query_id: QueryId, page_num: int) -> QueryResult:
    # Adapter matching `_run`'s call shape: ignores `op` and delegates to the
    # local fake so no server is contacted.
    return await gqr_behavior(query_id, page_num)

kwargs: QueryParameters = {"metrics": ["m1", "m2"], "group_by": ["gb"], "limit": 1}
cq_mock = mocker.patch.object(client, "create_query", return_value=query_id, new_callable=AsyncMock)

result_table = await client.query(**kwargs)
run_mock = AsyncMock(side_effect=run_behavior)
mocker.patch.object(client, "_run", new=run_mock)
gqr_mock = AsyncMock(side_effect=gqr_behavior)
mocker.patch.object(client, "get_query_result", new=gqr_mock)

gql_mock = mocker.patch.object(client, "_gql")
mocker.patch.object(gql_mock, "__aenter__", new_callable=AsyncMock)
mocker.patch("dbtsl.api.graphql.client.asyncio.isinstance", return_value=True)

kwargs: QueryParameters = {"metrics": ["m1", "m2"], "group_by": ["gb"], "limit": 1}
async with client.session():
result_table = await client.query(**kwargs)

cq_mock.assert_awaited_once_with(**kwargs)

run_mock.assert_has_awaits(
[
call(GraphQLProtocol.get_query_result, query_id=query_id, page_num=1),
]
)

gqr_mock.assert_has_awaits(
[
call(query_id=query_id, page_num=1),
call(query_id=query_id, page_num=2),
call(query_id=query_id, page_num=3),
call(query_id=query_id, page_num=4),
]
)

assert result_table.equals(table, check_metadata=True)


def test_sync_query_multiple_pages(mocker: MockerFixture) -> None:
"""Test that querying a dataframe with multiple pages works."""
client = SyncGraphQLClient(server_host="test", environment_id=0, auth_token="test")

query_id = QueryId("test-query-id")
table = pa.Table.from_arrays(
[pa.array([2, 4, 6, 100]), pa.array(["Chicken", "Dog", "Ant", "Centipede"])], names=["num_legs", "animal"]
)

def gqr_behavior(query_id: QueryId, page_num: int) -> QueryResult:
    """Stand-in for `get_query_result` that serves pages of the fixture `table` locally."""
    # Each "page" is a single row of the fixture table.
    page_slice = table.slice(offset=page_num - 1, length=1)

    buf = io.BytesIO()
    writer = pa.ipc.new_stream(buf, page_slice.schema)
    try:
        writer.write_table(page_slice)
    finally:
        writer.close()

    encoded_page = base64.b64encode(buf.getvalue()).decode("utf-8")
    return QueryResult(
        query_id=query_id,
        status=QueryStatus.SUCCESSFUL,
        sql=None,
        error=None,
        total_pages=len(table),
        arrow_result=encoded_page,
    )

def run_behavior(op: ProtocolOperation, query_id: QueryId, page_num: int) -> QueryResult:
    # Adapter matching `_run`'s call shape: ignores `op` and delegates to the
    # local fake so no server is contacted.
    return gqr_behavior(query_id, page_num)

cq_mock = mocker.patch.object(client, "create_query", return_value=query_id)

run_mock = MagicMock(side_effect=run_behavior)
mocker.patch.object(client, "_run", new=run_mock)
gqr_mock = MagicMock(side_effect=gqr_behavior)
mocker.patch.object(client, "get_query_result", new=gqr_mock)

gql_mock = mocker.patch.object(client, "_gql")
mocker.patch.object(gql_mock, "__aenter__")
mocker.patch("dbtsl.api.graphql.client.sync.isinstance", return_value=True)

kwargs: QueryParameters = {"metrics": ["m1", "m2"], "group_by": ["gb"], "limit": 1}

with client.session():
result_table = client.query(**kwargs)

cq_mock.assert_called_once_with(**kwargs)

run_mock.assert_has_calls(
[
call(GraphQLProtocol.get_query_result, query_id=query_id, page_num=1),
]
)

gqr_mock.assert_has_calls(
[
call(query_id=query_id, page_num=2),
call(query_id=query_id, page_num=3),
call(query_id=query_id, page_num=4),
Expand Down

0 comments on commit 6bf878e

Please sign in to comment.