From 6bf878e27c15af100b40bc6ed9ac7786fd874ccb Mon Sep 17 00:00:00 2001
From: Lucas Valente
Date: Thu, 20 Jun 2024 19:29:10 +0200
Subject: [PATCH] refactor: change how clients poll and create queries (#21)

* internal: remove unnecessary logic for `create_query` and `get_query_result`

* refactor: make `_poll_until_complete` generic

This will allow us to poll for the completion of many types of jobs, not
only queries.

* refactor: refactor internal polling for sync client also

Same as previous commit, but now for the sync client.
---
 .../Under the Hood-20240620-130850.yaml |   3 +
 dbtsl/api/graphql/client/asyncio.py     |  37 +++----
 dbtsl/api/graphql/client/base.py        |  27 ++---
 dbtsl/api/graphql/client/sync.py        |  31 +++---
 dbtsl/api/graphql/protocol.py           |  29 ++++-
 tests/api/graphql/test_client.py        | 103 ++++++++++++++++--
 6 files changed, 165 insertions(+), 65 deletions(-)
 create mode 100644 .changes/unreleased/Under the Hood-20240620-130850.yaml

diff --git a/.changes/unreleased/Under the Hood-20240620-130850.yaml b/.changes/unreleased/Under the Hood-20240620-130850.yaml
new file mode 100644
index 0000000..4591642
--- /dev/null
+++ b/.changes/unreleased/Under the Hood-20240620-130850.yaml
@@ -0,0 +1,3 @@
+kind: Under the Hood
+body: Change how clients create and poll for query results
+time: 2024-06-20T13:08:50.067409+02:00
diff --git a/dbtsl/api/graphql/client/asyncio.py b/dbtsl/api/graphql/client/asyncio.py
index 10a9d47..a000885 100644
--- a/dbtsl/api/graphql/client/asyncio.py
+++ b/dbtsl/api/graphql/client/asyncio.py
@@ -10,15 +10,16 @@
 
 from dbtsl.api.graphql.client.base import BaseGraphQLClient
 from dbtsl.api.graphql.protocol import (
-    GetQueryResultVariables,
     ProtocolOperation,
+    TJobStatusResult,
+    TJobStatusVariables,
     TResponse,
     TVariables,
 )
 from dbtsl.api.shared.query_params import QueryParameters
 from dbtsl.backoff import ExponentialBackoff
 from dbtsl.error import QueryFailedError
-from dbtsl.models.query import QueryId, QueryResult, QueryStatus
+from dbtsl.models.query import QueryId, QueryStatus
 
 
 class AsyncGraphQLClient(BaseGraphQLClient[AIOHTTPTransport, AsyncClientSession]):
@@ -69,35 +70,29 @@ async def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVa
         variables = op.get_request_variables(environment_id=self.environment_id, **kwargs)
         gql_query = gql(raw_query)
 
-        res = await self._gql_session.execute(gql_query, variable_values=variables)
+        try:
+            res = await self._gql_session.execute(gql_query, variable_values=variables)
+        except Exception as err:
+            raise self._refine_err(err)
 
         return op.parse_response(res)
 
-    async def _create_query(self, **params: Unpack[QueryParameters]) -> QueryId:
-        """Create a query that will run asynchronously."""
-        return await self._run(self.PROTOCOL.create_query, **params)  # type: ignore
-
-    async def _get_query_result(self, **params: Unpack[GetQueryResultVariables]) -> QueryResult:
-        """Fetch a query's results'."""
-        return await self._run(self.PROTOCOL.get_query_result, **params)  # type: ignore
-
     async def _poll_until_complete(
         self,
         query_id: QueryId,
+        poll_op: ProtocolOperation[TJobStatusVariables, TJobStatusResult],
         backoff: Optional[ExponentialBackoff] = None,
-    ) -> QueryResult:
-        """Poll for a query's results until it is in a completed state (SUCCESSFUL or FAILED).
-
-        Note that this function does NOT fetch all pages in case the query is SUCCESSFUL. It only
-        returns once the query is done. Callers must implement this logic themselves.
-        """
+        **kwargs,
+    ) -> TJobStatusResult:
+        """Poll for a job's results until it is in a completed state (SUCCESSFUL or FAILED)."""
         if backoff is None:
             backoff = self._default_backoff()
 
         for sleep_ms in backoff.iter_ms():
             # TODO: add timeout param to all requests because technically the API could hang and
             # then we don't respect timeout.
-            qr = await self._get_query_result(query_id=query_id, page_num=1)
+            kwargs["query_id"] = query_id
+            qr = await self._run(poll_op, **kwargs)
 
             if qr.status in (QueryStatus.SUCCESSFUL, QueryStatus.FAILED):
                 return qr
@@ -108,8 +103,8 @@ async def _poll_until_complete(
 
     async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
         """Query the Semantic Layer."""
-        query_id = await self._create_query(**params)
-        first_page_results = await self._poll_until_complete(query_id)
+        query_id = await self.create_query(**params)
+        first_page_results = await self._poll_until_complete(query_id, self.PROTOCOL.get_query_result, page_num=1)
 
         if first_page_results.status != QueryStatus.SUCCESSFUL:
             raise QueryFailedError()
@@ -119,7 +114,7 @@ async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
             return first_page_results.result_table
 
         tasks = [
-            self._get_query_result(query_id=query_id, page_num=page)
+            self.get_query_result(query_id=query_id, page_num=page)
             for page in range(2, first_page_results.total_pages + 1)
         ]
         all_page_results = [first_page_results] + await asyncio.gather(*tasks)
diff --git a/dbtsl/api/graphql/client/base.py b/dbtsl/api/graphql/client/base.py
index ca90d6b..2c8ea19 100644
--- a/dbtsl/api/graphql/client/base.py
+++ b/dbtsl/api/graphql/client/base.py
@@ -10,9 +10,6 @@
 import dbtsl.env as env
 from dbtsl.api.graphql.protocol import (
     GraphQLProtocol,
-    ProtocolOperation,
-    TResponse,
-    TVariables,
 )
 from dbtsl.backoff import ExponentialBackoff
 from dbtsl.error import AuthError
@@ -72,20 +69,16 @@ def _create_transport(self, url: str, headers: Dict[str, str]) -> TTransport:
         """Create the underlying transport to be used by the gql Client."""
         raise NotImplementedError()
 
-    @abstractmethod
-    def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVariables) -> TResponse:
-        raise NotImplementedError()
-
-    def _run_err_wrapper(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVariables) -> TResponse:
-        try:
-            return self._run(op, **kwargs)
-        except TransportQueryError as err:
-            # TODO: we should probably return an error type that has an Enum from GraphQL
-            # instead of depending on error messages
-            if err.errors is not None and err.errors[0]["message"] == "User is not authorized":
-                raise AuthError(err.args)
+    def _refine_err(self, err: Exception) -> Exception:
+        """Refine a generic exception that might have happened during `_run`."""
+        if (
+            isinstance(err, TransportQueryError)
+            and err.errors is not None
+            and err.errors[0]["message"] == "User is not authorized"
+        ):
+            return AuthError(err.args)
 
-            raise err
+        return err
 
     @property
     def _gql_session(self) -> TSession:
@@ -110,7 +103,7 @@ def __getattr__(self, attr: str) -> Any:
             raise AttributeError()
 
         return functools.partial(
-            self._run_err_wrapper,
+            self._run,
             op=op,
         )
 
diff --git a/dbtsl/api/graphql/client/sync.py b/dbtsl/api/graphql/client/sync.py
index 5c2a6c3..e7e8bde 100644
--- a/dbtsl/api/graphql/client/sync.py
+++ b/dbtsl/api/graphql/client/sync.py
@@ -10,15 +10,16 @@
 
 from dbtsl.api.graphql.client.base import BaseGraphQLClient
 from dbtsl.api.graphql.protocol import (
-    GetQueryResultVariables,
     ProtocolOperation,
+    TJobStatusResult,
+    TJobStatusVariables,
     TResponse,
     TVariables,
 )
 from dbtsl.api.shared.query_params import QueryParameters
 from dbtsl.backoff import ExponentialBackoff
 from dbtsl.error import QueryFailedError
-from dbtsl.models.query import QueryId, QueryResult, QueryStatus
+from dbtsl.models.query import QueryId, QueryStatus
 
 
 class SyncGraphQLClient(BaseGraphQLClient[RequestsHTTPTransport, SyncClientSession]):
@@ -69,23 +70,20 @@ def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVariable
         variables = op.get_request_variables(environment_id=self.environment_id, **kwargs)
         gql_query = gql(raw_query)
 
-        res = self._gql_session.execute(gql_query, variable_values=variables)
+        try:
+            res = self._gql_session.execute(gql_query, variable_values=variables)
+        except Exception as err:
+            raise self._refine_err(err)
 
         return op.parse_response(res)
 
-    def _create_query(self, **params: Unpack[QueryParameters]) -> QueryId:
-        """Create a query that will run asynchronously."""
-        return self._run(self.PROTOCOL.create_query, **params)  # type: ignore
-
-    def _get_query_result(self, **params: Unpack[GetQueryResultVariables]) -> QueryResult:
-        """Fetch a query's results'."""
-        return self._run(self.PROTOCOL.get_query_result, **params)  # type: ignore
-
     def _poll_until_complete(
         self,
         query_id: QueryId,
+        poll_op: ProtocolOperation[TJobStatusVariables, TJobStatusResult],
         backoff: Optional[ExponentialBackoff] = None,
-    ) -> QueryResult:
+        **kwargs,
+    ) -> TJobStatusResult:
         """Poll for a query's results until it is in a completed state (SUCCESSFUL or FAILED).
 
         Note that this function does NOT fetch all pages in case the query is SUCCESSFUL. It only
@@ -97,7 +95,8 @@ def _poll_until_complete(
         for sleep_ms in backoff.iter_ms():
             # TODO: add timeout param to all requests because technically the API could hang and
             # then we don't respect timeout.
-            qr = self._get_query_result(query_id=query_id, page_num=1)
+            kwargs["query_id"] = query_id
+            qr = self._run(poll_op, **kwargs)
 
             if qr.status in (QueryStatus.SUCCESSFUL, QueryStatus.FAILED):
                 return qr
@@ -108,8 +107,8 @@ def _poll_until_complete(
 
     def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
         """Query the Semantic Layer."""
-        query_id = self._create_query(**params)
-        first_page_results = self._poll_until_complete(query_id)
+        query_id = self.create_query(**params)
+        first_page_results = self._poll_until_complete(query_id, self.PROTOCOL.get_query_result, page_num=1)
 
         if first_page_results.status != QueryStatus.SUCCESSFUL:
             raise QueryFailedError()
@@ -119,7 +118,7 @@ def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
             return first_page_results.result_table
 
         results = [
-            self._get_query_result(query_id=query_id, page_num=page)
+            self.get_query_result(query_id=query_id, page_num=page)
             for page in range(2, first_page_results.total_pages + 1)
         ]
         all_page_results = [first_page_results] + results
diff --git a/dbtsl/api/graphql/protocol.py b/dbtsl/api/graphql/protocol.py
index 8b3d103..c2ded77 100644
--- a/dbtsl/api/graphql/protocol.py
+++ b/dbtsl/api/graphql/protocol.py
@@ -1,15 +1,38 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, List, Mapping, TypedDict, TypeVar
+from typing import Any, Dict, Generic, List, Mapping, Protocol, TypedDict, TypeVar
 
 from mashumaro.codecs.basic import decode as decode_to_dataclass
 from typing_extensions import NotRequired, override
 
 from dbtsl.api.shared.query_params import QueryParameters
 from dbtsl.models import Dimension, Measure, Metric
-from dbtsl.models.query import QueryId, QueryResult
+from dbtsl.models.query import QueryId, QueryResult, QueryStatus
+
+
+class JobStatusVariables(TypedDict):
+    """Variables of operations that will get a job's status."""
+
+    query_id: QueryId
+
+
+class JobStatusResult(Protocol):
+    """Result of operations that fetch a job's status."""
+
+    @property
+    def status(self) -> QueryStatus:
+        """The job status."""
+        raise NotImplementedError()
+
+
+TJobStatusVariables = TypeVar("TJobStatusVariables", bound=JobStatusVariables, covariant=True)
+
+TJobStatusResult = TypeVar("TJobStatusResult", bound=JobStatusResult, covariant=True)
+
 
 TVariables = TypeVar("TVariables", bound=Mapping[str, Any])
-TResponse = TypeVar("TResponse")
+# Need to make TResponse covariant otherwise we can't annotate something like
+# def func(a: ProtocolOperation[JobStatusVariables, JobStatusResult]) -> JobStatusResult:
+TResponse = TypeVar("TResponse", covariant=True)
 
 
 class ProtocolOperation(Generic[TVariables, TResponse], ABC):
diff --git a/tests/api/graphql/test_client.py b/tests/api/graphql/test_client.py
index 0fe8591..2deb116 100644
--- a/tests/api/graphql/test_client.py
+++ b/tests/api/graphql/test_client.py
@@ -1,16 +1,23 @@
 import base64
 import io
-from unittest.mock import call
+from unittest.mock import AsyncMock, MagicMock, call
 
 import pyarrow as pa
 from pytest_mock import MockerFixture
 
 from dbtsl.api.graphql.client.asyncio import AsyncGraphQLClient
+from dbtsl.api.graphql.client.sync import SyncGraphQLClient
+from dbtsl.api.graphql.protocol import GraphQLProtocol, ProtocolOperation
 from dbtsl.api.shared.query_params import QueryParameters
 from dbtsl.models.query import QueryId, QueryResult, QueryStatus
 
+# The following 2 tests are copies of each other since testing the same sync/async functionality is
+# a pain. I should probably find how to fix this later
+#
+# These tests are so bad and test a bunch of internals, I hate my life
 
-async def test_query_multiple_pages(mocker: MockerFixture) -> None:
+
+async def test_async_query_multiple_pages(mocker: MockerFixture) -> None:
     """Test that querying a dataframe with multiple pages works."""
     client = AsyncGraphQLClient(server_host="test", environment_id=0, auth_token="test")
 
@@ -20,7 +27,7 @@ async def test_query_multiple_pages(mocker: MockerFixture) -> None:
     )
 
     async def gqr_behavior(query_id: QueryId, page_num: int) -> QueryResult:
-        """Behaves like `_get_query_result` but without talking to any servers."""
+        """Behaves like `get_query_result` but without talking to any servers."""
         call_table = table.slice(offset=page_num - 1, length=1)
 
         byte_stream = io.BytesIO()
@@ -36,18 +43,98 @@ async def gqr_behavior(query_id: QueryId, page_num: int) -> QueryResult:
             arrow_result=base64.b64encode(byte_stream.getvalue()).decode("utf-8"),
         )
 
-    cq_mock = mocker.patch.object(client, "_create_query", return_value=query_id)
-    gqr_mock = mocker.patch.object(client, "_get_query_result", side_effect=gqr_behavior)
+    async def run_behavior(op: ProtocolOperation, query_id: QueryId, page_num: int) -> QueryResult:
+        return await gqr_behavior(query_id, page_num)
 
-    kwargs: QueryParameters = {"metrics": ["m1", "m2"], "group_by": ["gb"], "limit": 1}
+    cq_mock = mocker.patch.object(client, "create_query", return_value=query_id, new_callable=AsyncMock)
 
-    result_table = await client.query(**kwargs)
+    run_mock = AsyncMock(side_effect=run_behavior)
+    mocker.patch.object(client, "_run", new=run_mock)
+    gqr_mock = AsyncMock(side_effect=gqr_behavior)
+    mocker.patch.object(client, "get_query_result", new=gqr_mock)
+
+    gql_mock = mocker.patch.object(client, "_gql")
+    mocker.patch.object(gql_mock, "__aenter__", new_callable=AsyncMock)
+    mocker.patch("dbtsl.api.graphql.client.asyncio.isinstance", return_value=True)
+
+    kwargs: QueryParameters = {"metrics": ["m1", "m2"], "group_by": ["gb"], "limit": 1}
+    async with client.session():
+        result_table = await client.query(**kwargs)
 
     cq_mock.assert_awaited_once_with(**kwargs)
+
+    run_mock.assert_has_awaits(
+        [
+            call(GraphQLProtocol.get_query_result, query_id=query_id, page_num=1),
+        ]
+    )
+
     gqr_mock.assert_has_awaits(
         [
-            call(query_id=query_id, page_num=1),
+            call(query_id=query_id, page_num=2),
+            call(query_id=query_id, page_num=3),
+            call(query_id=query_id, page_num=4),
+        ]
+    )
+
+    assert result_table.equals(table, check_metadata=True)
+
+
+def test_sync_query_multiple_pages(mocker: MockerFixture) -> None:
+    """Test that querying a dataframe with multiple pages works."""
+    client = SyncGraphQLClient(server_host="test", environment_id=0, auth_token="test")
+
+    query_id = QueryId("test-query-id")
+    table = pa.Table.from_arrays(
+        [pa.array([2, 4, 6, 100]), pa.array(["Chicken", "Dog", "Ant", "Centipede"])], names=["num_legs", "animal"]
+    )
+
+    def gqr_behavior(query_id: QueryId, page_num: int) -> QueryResult:
+        """Behaves like `get_query_result` but without talking to any servers."""
+        call_table = table.slice(offset=page_num - 1, length=1)
+
+        byte_stream = io.BytesIO()
+        with pa.ipc.new_stream(byte_stream, call_table.schema) as writer:
+            writer.write_table(call_table)
+
+        return QueryResult(
+            query_id=query_id,
+            status=QueryStatus.SUCCESSFUL,
+            sql=None,
+            error=None,
+            total_pages=len(table),
+            arrow_result=base64.b64encode(byte_stream.getvalue()).decode("utf-8"),
+        )
+
+    def run_behavior(op: ProtocolOperation, query_id: QueryId, page_num: int) -> QueryResult:
+        return gqr_behavior(query_id, page_num)
+
+    cq_mock = mocker.patch.object(client, "create_query", return_value=query_id)
+
+    run_mock = MagicMock(side_effect=run_behavior)
+    mocker.patch.object(client, "_run", new=run_mock)
+    gqr_mock = MagicMock(side_effect=gqr_behavior)
+    mocker.patch.object(client, "get_query_result", new=gqr_mock)
+
+    gql_mock = mocker.patch.object(client, "_gql")
+    mocker.patch.object(gql_mock, "__aenter__")
+    mocker.patch("dbtsl.api.graphql.client.sync.isinstance", return_value=True)
+
+    kwargs: QueryParameters = {"metrics": ["m1", "m2"], "group_by": ["gb"], "limit": 1}
+
+    with client.session():
+        result_table = client.query(**kwargs)
+
+    cq_mock.assert_called_once_with(**kwargs)
+
+    run_mock.assert_has_calls(
+        [
+            call(GraphQLProtocol.get_query_result, query_id=query_id, page_num=1),
+        ]
+    )
+
+    gqr_mock.assert_has_calls(
+        [
             call(query_id=query_id, page_num=2),
             call(query_id=query_id, page_num=3),
             call(query_id=query_id, page_num=4),