Skip to content

Commit

Permalink
feat: add saved queries
Browse files Browse the repository at this point in the history
This commit adds the possibility to list saved queries from the GraphQL
API.
  • Loading branch information
serramatutu committed Jul 9, 2024
1 parent 2f4889f commit abf305c
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 8 deletions.
11 changes: 8 additions & 3 deletions dbtsl/api/graphql/client/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ from dbtsl.models import (
Entity,
Measure,
Metric,
SavedQuery,
)

class AsyncGraphQLClient:
Expand All @@ -28,15 +29,19 @@ class AsyncGraphQLClient:
...

async def dimensions(self, metrics: List[str]) -> List[Dimension]:
"""Get a list of all available dimensions for a given metric."""
"""Get a list of all available dimensions for a given set of metrics."""
...

async def measures(self, metrics: List[str]) -> List[Measure]:
"""Get a list of all available measures for a given metric."""
"""Get a list of all available measures for a given set of metrics."""
...

async def entities(self, metrics: List[str]) -> List[Entity]:
"""Get a list of all available entities for a given metric."""
"""Get a list of all available entities for a given set of metrics."""
...

async def saved_queries(self) -> List[SavedQuery]:
"""Get a list of all available saved queries."""
...

async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": ...
11 changes: 8 additions & 3 deletions dbtsl/api/graphql/client/sync.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ from dbtsl.models import (
Entity,
Measure,
Metric,
SavedQuery,
)

class SyncGraphQLClient:
Expand All @@ -28,15 +29,19 @@ class SyncGraphQLClient:
...

def dimensions(self, metrics: List[str]) -> List[Dimension]:
"""Get a list of all available dimensions for a given metric."""
"""Get a list of all available dimensions for a given set of metrics."""
...

def measures(self, metrics: List[str]) -> List[Measure]:
"""Get a list of all available measures for a given metric."""
"""Get a list of all available measures for a given set of metrics."""
...

def entities(self, metrics: List[str]) -> List[Entity]:
"""Get a list of all available entities for a given metric."""
"""Get a list of all available entities for a given set of metrics."""
...

def saved_queries(self) -> List[SavedQuery]:
"""Get a list of all available saved queries."""
...

def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": ...
25 changes: 25 additions & 0 deletions dbtsl/api/graphql/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.models import Dimension, Entity, Measure, Metric
from dbtsl.models.query import QueryId, QueryResult, QueryStatus
from dbtsl.models.saved_query import SavedQuery


class JobStatusVariables(TypedDict):
Expand Down Expand Up @@ -168,6 +169,29 @@ def parse_response(self, data: Dict[str, Any]) -> List[Entity]:
return decode_to_dataclass(data["entities"], List[Entity])


class ListSavedQueriesOperation(ProtocolOperation[EmptyVariables, List[SavedQuery]]):
"""List all saved queries."""

@override
def get_request_text(self) -> str:
query = """
query getSavedQueries($environmentId: BigInt!) {
savedQueries(environmentId: $environmentId) {
...&fragment
}
}
"""
return render_query(query, SavedQuery.gql_fragments())

@override
def get_request_variables(self, environment_id: int, **kwargs: ListEntitiesOperationVariables) -> Dict[str, Any]:
return {"environmentId": environment_id}

@override
def parse_response(self, data: Dict[str, Any]) -> List[SavedQuery]:
return decode_to_dataclass(data["savedQueries"], List[SavedQuery])


class CreateQueryOperation(ProtocolOperation[QueryParameters, QueryId]):
"""Create a query that will be processed asynchronously."""

Expand Down Expand Up @@ -253,5 +277,6 @@ class GraphQLProtocol:
dimensions = ListDimensionsOperation()
measures = ListMeasuresOperation()
entities = ListEntitiesOperation()
saved_queries = ListSavedQueriesOperation()
create_query = CreateQueryOperation()
get_query_result = GetQueryResultOperation()
10 changes: 9 additions & 1 deletion dbtsl/client/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import pyarrow as pa
from typing_extensions import Self, Unpack

from dbtsl.api.adbc.protocol import QueryParameters
from dbtsl.models import Dimension, Measure, Metric
from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery

class AsyncSemanticLayerClient:
def __init__(
Expand All @@ -30,6 +30,14 @@ class AsyncSemanticLayerClient:
"""List all the measures available for a given set of metrics."""
...

async def entities(self, metrics: List[str]) -> List[Entity]:
"""Get a list of all available entities for a given set of metrics."""
...

async def saved_queries(self) -> List[SavedQuery]:
"""Get a list of all available saved queries."""
...

def session(self) -> AbstractAsyncContextManager[AsyncIterator[Self]]:
"""Establish a connection with the dbt Semantic Layer's servers."""
...
10 changes: 9 additions & 1 deletion dbtsl/client/sync.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import pyarrow as pa
from typing_extensions import Self, Unpack

from dbtsl.api.adbc.protocol import QueryParameters
from dbtsl.models import Dimension, Measure, Metric
from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery

class SyncSemanticLayerClient:
def __init__(
Expand All @@ -30,6 +30,14 @@ class SyncSemanticLayerClient:
"""List all the measures available for a given set of metrics."""
...

def entities(self, metrics: List[str]) -> List[Entity]:
"""Get a list of all available entities for a given set of metrics."""
...

async def saved_queries(self) -> List[SavedQuery]:
"""Get a list of all available saved queries."""
...

def session(self) -> AbstractContextManager[Iterator[Self]]:
"""Establish a connection with the dbt Semantic Layer's servers."""
...
2 changes: 2 additions & 0 deletions dbtsl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .measure import AggregationType, Measure
from .metric import Metric, MetricType
from .query import QueryResult
from .saved_query import SavedQuery
from .time_granularity import TimeGranularity

# Only importing this so it registers aliases
Expand All @@ -26,5 +27,6 @@
"Measure",
"Metric",
"MetricType",
"SavedQuery",
"TimeGranularity",
]
13 changes: 13 additions & 0 deletions dbtsl/models/saved_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from dataclasses import dataclass
from typing import Optional

from dbtsl.models.base import BaseModel, GraphQLFragmentMixin


@dataclass(frozen=True)
class SavedQuery(BaseModel, GraphQLFragmentMixin):
"""A saved query."""

name: str
description: Optional[str]
label: Optional[str]
38 changes: 38 additions & 0 deletions examples/list_saved_queries_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Fetch all available saved queries from the metadata API and display them."""

import asyncio
from argparse import ArgumentParser

from dbtsl.asyncio import AsyncSemanticLayerClient


def get_arg_parser() -> ArgumentParser:
p = ArgumentParser()

p.add_argument("--env-id", required=True, help="The dbt environment ID", type=int)
p.add_argument("--token", required=True, help="The API auth token")
p.add_argument("--host", required=True, help="The API host")

return p


async def main() -> None:
arg_parser = get_arg_parser()
args = arg_parser.parse_args()

client = AsyncSemanticLayerClient(
environment_id=args.env_id,
auth_token=args.token,
host=args.host,
)

async with client.session():
saved_queries = await client.saved_queries()
for sq in saved_queries:
print(f"{sq.name}:")
print(f" label: {sq.label}")
print(f" description: {sq.description}")


if __name__ == "__main__":
asyncio.run(main())
10 changes: 10 additions & 0 deletions tests/integration/test_gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ async def test_async_client_lists_metrics_dimensions_entities(async_client: Asyn
assert entities == metrics[0].entities


def test_sync_client_lists_saved_queries(sync_client: SyncGraphQLClient) -> None:
sqs = sync_client.saved_queries()
assert len(sqs) > 0


async def test_async_client_lists_saved_queries(async_client: AsyncGraphQLClient) -> None:
sqs = await async_client.saved_queries()
assert len(sqs) > 0


def test_sync_client_query_works(sync_client: SyncGraphQLClient) -> None:
metrics = sync_client.metrics()
assert len(metrics) > 0
Expand Down

0 comments on commit abf305c

Please sign in to comment.