diff --git a/.changes/unreleased/Features-20240709-103228.yaml b/.changes/unreleased/Features-20240709-103228.yaml new file mode 100644 index 0000000..27941f9 --- /dev/null +++ b/.changes/unreleased/Features-20240709-103228.yaml @@ -0,0 +1,3 @@ +kind: Features +body: Added `saved_query` fetching via GraphQL +time: 2024-07-09T10:32:28.124763+02:00 diff --git a/.changes/unreleased/Features-20240709-103239.yaml b/.changes/unreleased/Features-20240709-103239.yaml new file mode 100644 index 0000000..24e5423 --- /dev/null +++ b/.changes/unreleased/Features-20240709-103239.yaml @@ -0,0 +1,3 @@ +kind: Features +body: Added `entity` fetching via GraphQL +time: 2024-07-09T10:32:39.659482+02:00 diff --git a/.changes/unreleased/Features-20240709-103250.yaml b/.changes/unreleased/Features-20240709-103250.yaml new file mode 100644 index 0000000..7fe6e68 --- /dev/null +++ b/.changes/unreleased/Features-20240709-103250.yaml @@ -0,0 +1,3 @@ +kind: Features +body: Added more fields to `dimension` +time: 2024-07-09T10:32:50.697778+02:00 diff --git a/.changes/unreleased/Features-20240709-103258.yaml b/.changes/unreleased/Features-20240709-103258.yaml new file mode 100644 index 0000000..fd8de08 --- /dev/null +++ b/.changes/unreleased/Features-20240709-103258.yaml @@ -0,0 +1,3 @@ +kind: Features +body: Added more fields to `metric` +time: 2024-07-09T10:32:58.618167+02:00 diff --git a/.changes/unreleased/Under the Hood-20240709-103311.yaml b/.changes/unreleased/Under the Hood-20240709-103311.yaml new file mode 100644 index 0000000..6e26185 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240709-103311.yaml @@ -0,0 +1,3 @@ +kind: Under the Hood +body: Improved how GraphQL gets generated under the hood +time: 2024-07-09T10:33:11.963657+02:00 diff --git a/dbtsl/api/graphql/client/asyncio.pyi b/dbtsl/api/graphql/client/asyncio.pyi index 6f6ea31..37464e0 100644 --- a/dbtsl/api/graphql/client/asyncio.pyi +++ b/dbtsl/api/graphql/client/asyncio.pyi @@ -7,8 +7,10 @@ from typing_extensions import AsyncIterator, Unpack from dbtsl.api.shared.query_params import QueryParameters from dbtsl.models import ( Dimension, + Entity, Measure, Metric, + SavedQuery, ) class AsyncGraphQLClient: @@ -27,11 +29,19 @@ class AsyncGraphQLClient: ... async def dimensions(self, metrics: List[str]) -> List[Dimension]: - """Get a list of all available dimensions for a given metric.""" + """Get a list of all available dimensions for a given set of metrics.""" ... async def measures(self, metrics: List[str]) -> List[Measure]: - """Get a list of all available measures for a given metric.""" + """Get a list of all available measures for a given set of metrics.""" + ... + + async def entities(self, metrics: List[str]) -> List[Entity]: + """Get a list of all available entities for a given set of metrics.""" + ... + + async def saved_queries(self) -> List[SavedQuery]: + """Get a list of all available saved queries.""" ... async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": ... diff --git a/dbtsl/api/graphql/client/sync.pyi b/dbtsl/api/graphql/client/sync.pyi index 8c46295..f37bb14 100644 --- a/dbtsl/api/graphql/client/sync.pyi +++ b/dbtsl/api/graphql/client/sync.pyi @@ -7,8 +7,10 @@ from typing_extensions import Self, Unpack from dbtsl.api.shared.query_params import QueryParameters from dbtsl.models import ( Dimension, + Entity, Measure, Metric, + SavedQuery, ) class SyncGraphQLClient: @@ -27,11 +29,19 @@ class SyncGraphQLClient: ... def dimensions(self, metrics: List[str]) -> List[Dimension]: - """Get a list of all available dimensions for a given metric.""" + """Get a list of all available dimensions for a given set of metrics.""" ... def measures(self, metrics: List[str]) -> List[Measure]: - """Get a list of all available measures for a given metric.""" + """Get a list of all available measures for a given set of metrics.""" + ... + + def entities(self, metrics: List[str]) -> List[Entity]: + """Get a list of all available entities for a given set of metrics.""" + ... + + def saved_queries(self) -> List[SavedQuery]: + """Get a list of all available saved queries.""" ... def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": ... diff --git a/dbtsl/api/graphql/protocol.py b/dbtsl/api/graphql/protocol.py index c2ded77..08c665b 100644 --- a/dbtsl/api/graphql/protocol.py +++ b/dbtsl/api/graphql/protocol.py @@ -4,9 +4,11 @@ from mashumaro.codecs.basic import decode as decode_to_dataclass from typing_extensions import NotRequired, override +from dbtsl.api.graphql.util import render_query from dbtsl.api.shared.query_params import QueryParameters -from dbtsl.models import Dimension, Measure, Metric +from dbtsl.models import Dimension, Entity, Measure, Metric from dbtsl.models.query import QueryId, QueryResult, QueryStatus +from dbtsl.models.saved_query import SavedQuery class JobStatusVariables(TypedDict): @@ -68,13 +70,11 @@ def get_request_text(self) -> str: query = """ query getMetrics($environmentId: BigInt!) { metrics(environmentId: $environmentId) { - name - description - type + ...&fragment } } """ - return query + return render_query(query, Metric.gql_fragments()) @override def get_request_variables(self, environment_id: int, **kwargs: EmptyVariables) -> Dict[str, Any]: @@ -99,13 +99,11 @@ def get_request_text(self) -> str: query = """ query getDimensions($environmentId: BigInt!, $metrics: [MetricInput!]!) { dimensions(environmentId: $environmentId, metrics: $metrics) { - name - description - type + ...&fragment } } """ - return query + return render_query(query, Dimension.gql_fragments()) @override def get_request_variables(self, environment_id: int, **kwargs: ListEntitiesOperationVariables) -> Dict[str, Any]: @@ -127,14 +125,11 @@ def get_request_text(self) -> str: query = """ query getMeasures($environmentId: BigInt!, $metrics: [MetricInput!]!) { measures(environmentId: $environmentId, metrics: $metrics) { - name - aggTimeDimension - agg - expr + ...&fragment } } """ - return query + return render_query(query, Measure.gql_fragments()) @override def get_request_variables(self, environment_id: int, **kwargs: ListEntitiesOperationVariables) -> Dict[str, Any]: @@ -148,6 +143,55 @@ def parse_response(self, data: Dict[str, Any]) -> List[Measure]: return decode_to_dataclass(data["measures"], List[Measure]) +class ListEntitiesOperation(ProtocolOperation[ListEntitiesOperationVariables, List[Entity]]): + """List all entities for a given set of metrics.""" + + @override + def get_request_text(self) -> str: + query = """ + query getEntities($environmentId: BigInt!, $metrics: [MetricInput!]!) { + entities(environmentId: $environmentId, metrics: $metrics) { + ...&fragment + } + } + """ + return render_query(query, Entity.gql_fragments()) + + @override + def get_request_variables(self, environment_id: int, **kwargs: ListEntitiesOperationVariables) -> Dict[str, Any]: + return { + "environmentId": environment_id, + "metrics": [{"name": m} for m in kwargs["metrics"]], + } + + @override + def parse_response(self, data: Dict[str, Any]) -> List[Entity]: + return decode_to_dataclass(data["entities"], List[Entity]) + + +class ListSavedQueriesOperation(ProtocolOperation[EmptyVariables, List[SavedQuery]]): + """List all saved queries.""" + + @override + def get_request_text(self) -> str: + query = """ + query getSavedQueries($environmentId: BigInt!) { + savedQueries(environmentId: $environmentId) { + ...&fragment + } + } + """ + return render_query(query, SavedQuery.gql_fragments()) + + @override + def get_request_variables(self, environment_id: int, **kwargs: ListEntitiesOperationVariables) -> Dict[str, Any]: + return {"environmentId": environment_id} + + @override + def parse_response(self, data: Dict[str, Any]) -> List[SavedQuery]: + return decode_to_dataclass(data["savedQueries"], List[SavedQuery]) + + class CreateQueryOperation(ProtocolOperation[QueryParameters, QueryId]): """Create a query that will be processed asynchronously.""" @@ -203,16 +247,11 @@ def get_request_text(self) -> str: $pageNum: Int! ) { query(environmentId: $environmentId, queryId: $queryId, pageNum: $pageNum) { - queryId, - status, - sql, - error, - totalPages, - arrowResult + ...&fragment } } """ - return query + return render_query(query, QueryResult.gql_fragments()) @override def get_request_variables(self, environment_id: int, **kwargs: GetQueryResultVariables) -> Dict[str, Any]: @@ -237,5 +276,7 @@ class GraphQLProtocol: metrics = ListMetricsOperation() dimensions = ListDimensionsOperation() measures = ListMeasuresOperation() + entities = ListEntitiesOperation() + saved_queries = ListSavedQueriesOperation() create_query = CreateQueryOperation() get_query_result = GetQueryResultOperation() diff --git a/dbtsl/api/graphql/util.py b/dbtsl/api/graphql/util.py new file mode 100644 index 0000000..ca58c4e --- /dev/null +++ b/dbtsl/api/graphql/util.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import re +from string import Template +from typing import TYPE_CHECKING, List + +if TYPE_CHECKING: + from dbtsl.models.base import GraphQLFragment + +query_sub_pat = re.compile(r"[ \t\n]+") + + +def normalize_query(s: str) -> str: + """Return a normalized query string. + + This strips newlines, too many whitespaces etc so we can + make assertions that queries equal each other regarless of indentation. + """ + return query_sub_pat.subn(" ", s.strip("\n"))[0].strip() + + +class QueryTemplate(Template): + """Subclass Template since $ is reserved in GraphQL.""" + + delimiter = "&" + + +def render_query(template_str: str, dependencies: List[GraphQLFragment]) -> str: + """Return a rendered query from a template and its fragment dependencies. + + The template must have a &fragment which indicates where the main + fragment should be replaced in the query. + + The main fragment will be dependencies[0]. + """ + template = QueryTemplate(template_str) + assert len(dependencies) > 0 + template_render = normalize_query(template.substitute(fragment=dependencies[0].name)) + segments = [template_render] + [normalize_query(frag.body) for frag in dependencies] + return " ".join(segments) diff --git a/dbtsl/client/asyncio.pyi b/dbtsl/client/asyncio.pyi index a033edb..b5705a5 100644 --- a/dbtsl/client/asyncio.pyi +++ b/dbtsl/client/asyncio.pyi @@ -5,7 +5,7 @@ import pyarrow as pa from typing_extensions import Self, Unpack from dbtsl.api.adbc.protocol import QueryParameters -from dbtsl.models import Dimension, Measure, Metric +from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery class AsyncSemanticLayerClient: def __init__( @@ -30,6 +30,14 @@ class AsyncSemanticLayerClient: """List all the measures available for a given set of metrics.""" ... + async def entities(self, metrics: List[str]) -> List[Entity]: + """Get a list of all available entities for a given set of metrics.""" + ... + + async def saved_queries(self) -> List[SavedQuery]: + """Get a list of all available saved queries.""" + ... + def session(self) -> AbstractAsyncContextManager[AsyncIterator[Self]]: """Establish a connection with the dbt Semantic Layer's servers.""" ... diff --git a/dbtsl/client/sync.pyi b/dbtsl/client/sync.pyi index 2727d5d..9862cdc 100644 --- a/dbtsl/client/sync.pyi +++ b/dbtsl/client/sync.pyi @@ -5,7 +5,7 @@ import pyarrow as pa from typing_extensions import Self, Unpack from dbtsl.api.adbc.protocol import QueryParameters -from dbtsl.models import Dimension, Measure, Metric +from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery class SyncSemanticLayerClient: def __init__( @@ -30,6 +30,14 @@ class SyncSemanticLayerClient: """List all the measures available for a given set of metrics.""" ... + def entities(self, metrics: List[str]) -> List[Entity]: + """Get a list of all available entities for a given set of metrics.""" + ... + + async def saved_queries(self) -> List[SavedQuery]: + """Get a list of all available saved queries.""" + ... + def session(self) -> AbstractContextManager[Iterator[Self]]: """Establish a connection with the dbt Semantic Layer's servers.""" ... diff --git a/dbtsl/models/__init__.py b/dbtsl/models/__init__.py index 70f97b8..4453ec6 100644 --- a/dbtsl/models/__init__.py +++ b/dbtsl/models/__init__.py @@ -6,16 +6,27 @@ from .base import BaseModel from .dimension import Dimension, DimensionType +from .entity import Entity, EntityType from .measure import AggregationType, Measure from .metric import Metric, MetricType +from .query import QueryResult +from .saved_query import SavedQuery +from .time_granularity import TimeGranularity + +# Only importing this so it registers aliases +_ = QueryResult BaseModel._apply_aliases() __all__ = [ + "AggregationType", "Dimension", "DimensionType", + "Entity", + "EntityType", "Measure", - "AggregationType", "Metric", "MetricType", + "SavedQuery", + "TimeGranularity", ] diff --git a/dbtsl/models/base.py b/dbtsl/models/base.py index c3eda0b..5b05b01 100644 --- a/dbtsl/models/base.py +++ b/dbtsl/models/base.py @@ -1,9 +1,17 @@ -from dataclasses import fields, is_dataclass +import inspect +from dataclasses import dataclass, fields, is_dataclass +from dataclasses import field as dc_field +from functools import cache from types import MappingProxyType +from typing import List, Set, Type, Union +from typing import get_args as get_type_args +from typing import get_origin as get_type_origin from mashumaro import DataClassDictMixin, field_options from mashumaro.config import BaseConfig +from dbtsl.api.graphql.util import normalize_query + def snake_case_to_camel_case(s: str) -> str: """Convert a snake_case_string into a camelCaseString.""" @@ -30,3 +38,71 @@ def _apply_aliases(cls) -> None: camel_name = snake_case_to_camel_case(field.name) if field.name != camel_name: field.metadata = MappingProxyType(field_options(alias=camel_name)) + + +@dataclass(frozen=True, eq=True) +class GraphQLFragment: + """Represent a model as a GraphQL fragment.""" + + name: str + body: str = dc_field(hash=False) + + +class GraphQLFragmentMixin: + """Add this to any model that needs to be fetched from GraphQL.""" + + @classmethod + def gql_model_name(cls) -> str: + """The model's name in the GraphQL schema. Defaults to same as class name.""" + return cls.__name__ + + # NOTE: this will overflow the stack if we add any circular dependencies in our GraphQL schema, like + # Metric -> Dimension -> Metric -> Dimension ... + # + # If we do that, we need to modify this method to memoize what fragments were already created + # so that we exit the recursion gracefully + @staticmethod + def _get_fragments_for_field(type: Type, field_name: str) -> Union[str, List[GraphQLFragment]]: + if inspect.isclass(type) and issubclass(type, GraphQLFragmentMixin): + return type.gql_fragments() + + if get_type_origin(type) == list: + inner_type = get_type_args(type)[0] + return GraphQLFragmentMixin._get_fragments_for_field(inner_type, field_name) + + return snake_case_to_camel_case(field_name) + + @classmethod + @cache + def gql_fragments(cls) -> List[GraphQLFragment]: + """Get the GraphQL fragments needed to query for this model. + + The first (0th) fragment is always the fragment that represents the model itself. + The remaining fragments are dependencies of the model, if any. + """ + gql_model_name = cls.gql_model_name() + fragment_name = f"fragment{cls.__name__}" + + assert is_dataclass(cls), "Subclass of GraphQLFragmentMixin must be dataclass" + + query_elements: List[str] = [] + dependencies: Set[GraphQLFragment] = set() + for field in fields(cls): + frag_or_field = GraphQLFragmentMixin._get_fragments_for_field(field.type, field.name) + if isinstance(frag_or_field, str): + query_elements.append(frag_or_field) + else: + frag = frag_or_field[0] + field_query = snake_case_to_camel_case(field.name) + " { ..." + frag.name + " }" + query_elements.append(field_query) + dependencies.update(frag_or_field) + + query_str = " \n".join(query_elements) + + fragment_body = normalize_query(f""" + fragment {fragment_name} on {gql_model_name} {{ + {query_str} + }} + """) + fragment = GraphQLFragment(name=fragment_name, body=fragment_body) + return [fragment] + list(dependencies) diff --git a/dbtsl/models/dimension.py b/dbtsl/models/dimension.py index 54bc72f..898ce7b 100644 --- a/dbtsl/models/dimension.py +++ b/dbtsl/models/dimension.py @@ -1,7 +1,9 @@ from dataclasses import dataclass from enum import Enum +from typing import List, Optional -from dbtsl.models.base import BaseModel +from dbtsl.models.base import BaseModel, GraphQLFragmentMixin +from dbtsl.models.time_granularity import TimeGranularity class DimensionType(str, Enum): @@ -12,9 +14,14 @@ class DimensionType(str, Enum): @dataclass(frozen=True) -class Dimension(BaseModel): +class Dimension(BaseModel, GraphQLFragmentMixin): """A metric dimension.""" name: str - description: str + qualified_name: str + description: Optional[str] type: DimensionType + label: Optional[str] + is_partition: bool + expr: Optional[str] + queryable_granularities: List[TimeGranularity] diff --git a/dbtsl/models/entity.py b/dbtsl/models/entity.py new file mode 100644 index 0000000..d482cd5 --- /dev/null +++ b/dbtsl/models/entity.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from dbtsl.models.base import BaseModel, GraphQLFragmentMixin + + +class EntityType(str, Enum): + """All supported entity types.""" + + FOREIGN = "FOREIGN" + NATURAL = "NATURAL" + PRIMARY = "PRIMARY" + UNIQUE = "UNIQUE" + + +@dataclass(frozen=True) +class Entity(BaseModel, GraphQLFragmentMixin): + """An entity.""" + + name: str + description: Optional[str] + type: EntityType + role: str + expr: str diff --git a/dbtsl/models/measure.py b/dbtsl/models/measure.py index 57d0f50..850f715 100644 --- a/dbtsl/models/measure.py +++ b/dbtsl/models/measure.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Optional -from dbtsl.models.base import BaseModel +from dbtsl.models.base import BaseModel, GraphQLFragmentMixin class AggregationType(str, Enum): @@ -20,7 +20,7 @@ class AggregationType(str, Enum): @dataclass(frozen=True) -class Measure(BaseModel): +class Measure(BaseModel, GraphQLFragmentMixin): """A measure.""" name: str diff --git a/dbtsl/models/metric.py b/dbtsl/models/metric.py index 7e2620c..6ce2035 100644 --- a/dbtsl/models/metric.py +++ b/dbtsl/models/metric.py @@ -1,7 +1,12 @@ from dataclasses import dataclass from enum import Enum +from typing import List, Optional -from dbtsl.models.base import BaseModel +from dbtsl.models.base import BaseModel, GraphQLFragmentMixin +from dbtsl.models.dimension import Dimension +from dbtsl.models.entity import Entity +from dbtsl.models.measure import Measure +from dbtsl.models.time_granularity import TimeGranularity class MetricType(str, Enum): @@ -12,22 +17,18 @@ class MetricType(str, Enum): CUMULATIVE = "CUMULATIVE" DERIVED = "DERIVED" CONVERSION = "CONVERSION" - UNKNOWN = "UNKNOWN" - - @classmethod - def missing(cls, _: str) -> "MetricType": - """Return UNKNOWN by default. - - Prevents client from breaking in case a new unknown type is introduced - by the server. - """ - return cls.UNKNOWN @dataclass(frozen=True) -class Metric(BaseModel): +class Metric(BaseModel, GraphQLFragmentMixin): """A metric.""" name: str - description: str + description: Optional[str] type: MetricType + dimensions: List[Dimension] + measures: List[Measure] + entities: List[Entity] + queryable_granularities: List[TimeGranularity] + label: str + requires_metric_time: bool diff --git a/dbtsl/models/query.py b/dbtsl/models/query.py index cd53342..2f86188 100644 --- a/dbtsl/models/query.py +++ b/dbtsl/models/query.py @@ -6,7 +6,7 @@ import pyarrow as pa -from dbtsl.models.base import BaseModel +from dbtsl.models.base import BaseModel, GraphQLFragmentMixin QueryId = NewType("QueryId", str) @@ -22,7 +22,7 @@ class QueryStatus(str, Enum): @dataclass(frozen=True) -class QueryResult(BaseModel): +class QueryResult(BaseModel, GraphQLFragmentMixin): """A query result containing its status, SQL and error/results.""" query_id: QueryId diff --git a/dbtsl/models/saved_query.py b/dbtsl/models/saved_query.py new file mode 100644 index 0000000..41e390e --- /dev/null +++ b/dbtsl/models/saved_query.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from typing import Optional + +from dbtsl.models.base import BaseModel, GraphQLFragmentMixin + + +@dataclass(frozen=True) +class SavedQuery(BaseModel, GraphQLFragmentMixin): + """A saved query.""" + + name: str + description: Optional[str] + label: Optional[str] diff --git a/dbtsl/models/time_granularity.py b/dbtsl/models/time_granularity.py new file mode 100644 index 0000000..3891083 --- /dev/null +++ b/dbtsl/models/time_granularity.py @@ -0,0 +1,17 @@ +from enum import Enum + + +class TimeGranularity(str, Enum): + """A time granularity.""" + + NANOSECOND = "NANOSECOND" + MICROSECOND = "MICROSECOND" + MILLISECOND = "MILLISECOND" + SECOND = "SECOND" + MINUTE = "MINUTE" + HOUR = "HOUR" + DAY = "DAY" + WEEK = "WEEK" + MONTH = "MONTH" + QUARTER = "QUARTER" + YEAR = "YEAR" diff --git a/examples/list_metrics_sync.py b/examples/list_metrics_sync.py index 9dda6b7..95256bc 100644 --- a/examples/list_metrics_sync.py +++ b/examples/list_metrics_sync.py @@ -32,6 +32,21 @@ def main() -> None: print(f" type={m.type}") print(f" description={m.description}") + print(" dimensions=[") + for dim in m.dimensions: + print(f" {dim.name},") + print(" ]") + + print(" measures=[") + for measure in m.measures: + print(f" {measure.name},") + print(" ]") + + print(" entities=[") + for entity in m.entities: + print(f" {entity.name},") + print(" ]") + if __name__ == "__main__": main() diff --git a/examples/list_saved_queries_async.py b/examples/list_saved_queries_async.py new file mode 100644 index 0000000..b993a95 --- /dev/null +++ b/examples/list_saved_queries_async.py @@ -0,0 +1,38 @@ +"""Fetch all available saved queries from the metadata API and display them.""" + +import asyncio +from argparse import ArgumentParser + +from dbtsl.asyncio import AsyncSemanticLayerClient + + +def get_arg_parser() -> ArgumentParser: + p = ArgumentParser() + + p.add_argument("--env-id", required=True, help="The dbt environment ID", type=int) + p.add_argument("--token", required=True, help="The API auth token") + p.add_argument("--host", required=True, help="The API host") + + return p + + +async def main() -> None: + arg_parser = get_arg_parser() + args = arg_parser.parse_args() + + client = AsyncSemanticLayerClient( + environment_id=args.env_id, + auth_token=args.token, + host=args.host, + ) + + async with client.session(): + saved_queries = await client.saved_queries() + for sq in saved_queries: + print(f"{sq.name}:") + print(f" label: {sq.label}") + print(f" description: {sq.description}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/api/graphql/test_util.py b/tests/api/graphql/test_util.py new file mode 100644 index 0000000..7ae7dcd --- /dev/null +++ b/tests/api/graphql/test_util.py @@ -0,0 +1,63 @@ +from dbtsl.api.graphql.util import normalize_query, render_query +from dbtsl.models.base import GraphQLFragment + + +def test_normalize_query() -> None: + q = """ + + myQuery { + foo { + baz + bar + } } + """ + + assert normalize_query(q) == "myQuery { foo { baz bar } }" + + +def test_render_query() -> None: + template = """ + myQuery { + ...&fragment + } + """ + dependencies = [ + GraphQLFragment( + name="mainFrag", + body=""" + fragment mainFrag on Test { + foo + bar + dep { + ...depFrag + } + } + """, + ), + GraphQLFragment( + name="depFrag", + body=""" + fragment depFrag on Dep { + baz + } + """, + ), + ] + + expect = normalize_query(""" + myQuery { + ...mainFrag + } + fragment mainFrag on Test { + foo + bar + dep { + ...depFrag + } + } + fragment depFrag on Dep { + baz + } + """) + rendered = render_query(template, dependencies) + assert normalize_query(expect) == rendered diff --git a/tests/integration/test_gql.py b/tests/integration/test_gql.py index 321073b..aba9407 100644 --- a/tests/integration/test_gql.py +++ b/tests/integration/test_gql.py @@ -30,18 +30,40 @@ def sync_client(credentials: Credentials) -> Iterator[SyncGraphQLClient]: yield client -def test_sync_client_lists_metrics_and_dimensions(sync_client: SyncGraphQLClient) -> None: +def test_sync_client_lists_metrics_dimensions_entities(sync_client: SyncGraphQLClient) -> None: metrics = sync_client.metrics() assert len(metrics) > 0 + dims = sync_client.dimensions(metrics=[metrics[0].name]) assert len(dims) > 0 + assert dims == metrics[0].dimensions + + entities = sync_client.entities(metrics=[metrics[0].name]) + assert len(entities) > 0 + assert entities == metrics[0].entities -async def test_async_client_lists_metrics_and_dimensions(async_client: AsyncGraphQLClient) -> None: +async def test_async_client_lists_metrics_dimensions_entities(async_client: AsyncGraphQLClient) -> None: metrics = await async_client.metrics() assert len(metrics) > 0 + dims = await async_client.dimensions(metrics=[metrics[0].name]) assert len(dims) > 0 + assert dims == metrics[0].dimensions + + entities = await async_client.entities(metrics=[metrics[0].name]) + assert len(entities) > 0 + assert entities == metrics[0].entities + + +def test_sync_client_lists_saved_queries(sync_client: SyncGraphQLClient) -> None: + sqs = sync_client.saved_queries() + assert len(sqs) > 0 + + +async def test_async_client_lists_saved_queries(async_client: AsyncGraphQLClient) -> None: + sqs = await async_client.saved_queries() + assert len(sqs) > 0 def test_sync_client_query_works(sync_client: SyncGraphQLClient) -> None: diff --git a/tests/test_models.py b/tests/test_models.py index dc3653b..6aab82b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,8 +1,10 @@ from dataclasses import dataclass +from typing import List from mashumaro.codecs.basic import decode -from dbtsl.models.base import BaseModel +from dbtsl.api.graphql.util import normalize_query +from dbtsl.models.base import BaseModel, GraphQLFragmentMixin from dbtsl.models.base import snake_case_to_camel_case as stc @@ -30,3 +32,48 @@ class SubModel(BaseModel): codec_model = decode(data, SubModel) assert codec_model.hello_world == "asdf" + + +def test_graphql_fragment_mixin() -> None: + @dataclass + class A(BaseModel, GraphQLFragmentMixin): + foo_bar: str + + @dataclass + class B(BaseModel, GraphQLFragmentMixin): + hello_world: str + baz: str + a: A + many_a: List[A] + + a_fragments = A.gql_fragments() + assert len(a_fragments) == 1 + a_fragment = a_fragments[0] + + a_expect = normalize_query(""" + fragment fragmentA on A { + fooBar + } + """) + assert a_fragment.name == "fragmentA" + assert a_fragment.body == a_expect + + b_fragments = B.gql_fragments() + assert len(b_fragments) == 2 + b_fragment = b_fragments[0] + + b_expect = normalize_query(""" + fragment fragmentB on B { + helloWorld + baz + a { + ...fragmentA + } + manyA { + ...fragmentA + } + } + """) + assert b_fragment.name == "fragmentB" + assert b_fragment.body == b_expect + assert b_fragments[1] == a_fragment