Skip to content

Commit

Permalink
feat: add query() to sync client (#18)
Browse files Browse the repository at this point in the history
This commit adds the ability to query the semantic layer using the sync
GraphQL client. The code is largely copied from the async client, because
the two are hard to generalize without hurting readability. If this
duplication grows, we can consider how to refactor it later.
  • Loading branch information
serramatutu authored Jun 19, 2024
1 parent 8480af8 commit 238cece
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .changes/unreleased/Features-20240619-142416.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Features
body: Add `query()` to sync client
time: 2024-06-19T14:24:16.542965+02:00
7 changes: 2 additions & 5 deletions dbtsl/api/graphql/client/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from dbtsl.api.graphql.client.base import BaseGraphQLClient
from dbtsl.api.graphql.protocol import (
GetQueryResultVariables,
GraphQLProtocol,
ProtocolOperation,
TResponse,
TVariables,
Expand Down Expand Up @@ -76,11 +75,11 @@ async def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVa

async def _create_query(self, **params: Unpack[QueryParameters]) -> QueryId:
"""Create a query that will run asynchronously."""
return await self._run(GraphQLProtocol.create_query, **params) # type: ignore
return await self._run(self.PROTOCOL.create_query, **params) # type: ignore

async def _get_query_result(self, **params: Unpack[GetQueryResultVariables]) -> QueryResult:
"""Fetch a query's results'."""
return await self._run(GraphQLProtocol.get_query_result, **params) # type: ignore
return await self._run(self.PROTOCOL.get_query_result, **params) # type: ignore

async def _poll_until_complete(
self,
Expand Down Expand Up @@ -114,8 +113,6 @@ async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
if first_page_results.status != QueryStatus.SUCCESSFUL:
raise QueryFailedError()

# Server should never return None if query is SUCCESSFUL.
# This is so pyright stops complaining
assert first_page_results.total_pages is not None

if first_page_results.total_pages == 1:
Expand Down
64 changes: 61 additions & 3 deletions dbtsl/api/graphql/client/sync.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import time
from contextlib import contextmanager
from typing import Dict, Iterator, Optional

import pyarrow as pa
from gql import gql
from gql.client import SyncClientSession
from gql.transport.requests import RequestsHTTPTransport
from typing_extensions import Self, override
from typing_extensions import Self, Unpack, override

from dbtsl.api.graphql.client.base import BaseGraphQLClient
from dbtsl.api.graphql.protocol import (
GetQueryResultVariables,
ProtocolOperation,
TResponse,
TVariables,
)
from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.backoff import ExponentialBackoff
from dbtsl.error import QueryFailedError
from dbtsl.models.query import QueryId, QueryResult, QueryStatus


class SyncGraphQLClient(BaseGraphQLClient[RequestsHTTPTransport, SyncClientSession]):
Expand Down Expand Up @@ -66,5 +73,56 @@ def _run(self, op: ProtocolOperation[TVariables, TResponse], **kwargs: TVariable

return op.parse_response(res)

# TODO: sync transport doesn't have `query` method. This should be OK since ADBC
# is the go-to method anyways. If people request it, we can implement later.
def _create_query(self, **params: Unpack[QueryParameters]) -> QueryId:
    """Submit a new query through the protocol's `create_query` operation.

    The keyword arguments are forwarded unchanged to `_run`. The returned ID
    identifies a query that runs asynchronously, so results must be fetched
    separately.
    """
    create_op = self.PROTOCOL.create_query
    return self._run(create_op, **params)  # type: ignore

def _get_query_result(self, **params: Unpack[GetQueryResultVariables]) -> QueryResult:
    """Fetch a query's results through the protocol's `get_query_result` operation.

    The keyword arguments are forwarded unchanged to `_run`. Callers in this
    class pass `query_id` and `page_num`.
    """
    result_op = self.PROTOCOL.get_query_result
    return self._run(result_op, **params)  # type: ignore

def _poll_until_complete(
    self,
    query_id: QueryId,
    backoff: Optional[ExponentialBackoff] = None,
) -> QueryResult:
    """Block until the query reaches a completed state (SUCCESSFUL or FAILED).

    Page 1 of the query's results is requested repeatedly, sleeping between
    attempts for the interval produced by `backoff` (or by
    `self._default_backoff()` when no backoff is given).

    Only that first page is returned. This function does NOT fetch the
    remaining pages of a SUCCESSFUL query; callers must do that themselves.
    """
    schedule = self._default_backoff() if backoff is None else backoff
    completed_states = (QueryStatus.SUCCESSFUL, QueryStatus.FAILED)

    for delay_ms in schedule.iter_ms():
        # TODO: add timeout param to all requests because technically the API could hang and
        # then we don't respect timeout.
        first_page = self._get_query_result(query_id=query_id, page_num=1)
        if first_page.status in completed_states:
            return first_page

        # The backoff yields milliseconds; time.sleep expects seconds.
        time.sleep(delay_ms / 1000)

    # This should be unreachable: presumably the backoff iterator never
    # finishes normally (TODO confirm against ExponentialBackoff.iter_ms).
    raise ValueError()

def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer and return the whole result as a single table.

    The query is created, polled until it completes, and then any pages after
    the first are fetched one at a time and concatenated with the first page.

    Raises:
        QueryFailedError: if the query completes with a status other than
            SUCCESSFUL.
    """
    query_id = self._create_query(**params)
    first_page = self._poll_until_complete(query_id)
    if first_page.status != QueryStatus.SUCCESSFUL:
        raise QueryFailedError()

    # A SUCCESSFUL query is expected to always report its page count; the
    # assert narrows the Optional so the type checker accepts the arithmetic.
    page_count = first_page.total_pages
    assert page_count is not None

    if page_count == 1:
        return first_page.result_table

    # Fetch every remaining page before touching any `result_table`.
    pages = [first_page]
    for page_num in range(2, page_count + 1):
        pages.append(self._get_query_result(query_id=query_id, page_num=page_num))

    return pa.concat_tables([page.result_table for page in pages])
6 changes: 5 additions & 1 deletion dbtsl/api/graphql/client/sync.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from contextlib import AbstractContextManager
from typing import Iterator, List, Optional

from typing_extensions import Self
import pyarrow as pa
from typing_extensions import Self, Unpack

from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.models import (
Dimension,
Measure,
Expand Down Expand Up @@ -31,3 +33,5 @@ class SyncGraphQLClient:
def measures(self, metrics: List[str]) -> List[Measure]:
    """Get a list of all available measures for the given metrics."""
    ...

def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer."""
    ...
11 changes: 11 additions & 0 deletions tests/integration/test_gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ async def test_async_client_lists_metrics_and_dimensions(async_client: AsyncGrap
assert len(dims) > 0


def test_sync_client_query_works(sync_client: SyncGraphQLClient) -> None:
    """The sync client can query the first listed metric grouped by `metric_time`."""
    available_metrics = sync_client.metrics()
    assert len(available_metrics) > 0

    first_metric_name = available_metrics[0].name
    result_table = sync_client.query(
        metrics=[first_metric_name],
        group_by=["metric_time"],
        limit=1,
    )
    assert len(result_table) > 0


async def test_async_client_query_works(async_client: AsyncGraphQLClient) -> None:
metrics = await async_client.metrics()
assert len(metrics) > 0
Expand Down

0 comments on commit 238cece

Please sign in to comment.