From 9896a1c4857fd298d2a300d50e252536d7c2b54f Mon Sep 17 00:00:00 2001 From: Lucas Valente Date: Wed, 16 Oct 2024 16:09:57 +0200 Subject: [PATCH] fix: `order_by` for `compile_sql` now works as expected (#49) * test: all parameters in `query` and `compile_sql` The tests for `query` and `compile_sql` did not test all allowed parameters, which made the `order_by` bug go unnoticed. This commit fixes that by adding all the remaining parameters to the tests. Note that this commit is in a broken state since the fix hasn't been applied yet. That will come in a future patch. * refactor: better internal repr of query params This commit improves our validation and representation of query parameters, and fixes the bug with `order_by`. We're still in an inconsistent state: the changes have yet to be propagated to, and used by, the other classes. That will be done in the following patch. * refactor: make protocols use the new validations This commit makes the ADBC and GraphQL protocol implementations use the new stricter representation for query params. It updates the tests accordingly. * refactor: typing public interfaces with order by This commit updates the public `.pyi` files to ensure we use the new `order_by` spec. * docs: changelog Added changelog entries related to the `order_by` changes. 
--- .../Breaking Changes-20241001-160946.yaml | 3 + .../unreleased/Features-20241001-155522.yaml | 3 + .../unreleased/Fixes-20241001-155448.yaml | 3 + .../Under the Hood-20241001-155544.yaml | 3 + dbtsl/__init__.py | 4 +- dbtsl/api/adbc/protocol.py | 57 +++++++-- dbtsl/api/graphql/client/asyncio.pyi | 12 +- dbtsl/api/graphql/client/sync.pyi | 12 +- dbtsl/api/graphql/protocol.py | 67 ++++++---- dbtsl/api/shared/query_params.py | 115 +++++++++++++++++- dbtsl/client/asyncio.pyi | 12 +- dbtsl/client/sync.pyi | 12 +- tests/api/adbc/test_protocol.py | 61 +++++++--- tests/integration/test_sl_client.py | 13 ++ tests/test_models.py | 96 +++++++++++++-- 15 files changed, 380 insertions(+), 93 deletions(-) create mode 100644 .changes/unreleased/Breaking Changes-20241001-160946.yaml create mode 100644 .changes/unreleased/Features-20241001-155522.yaml create mode 100644 .changes/unreleased/Fixes-20241001-155448.yaml create mode 100644 .changes/unreleased/Under the Hood-20241001-155544.yaml diff --git a/.changes/unreleased/Breaking Changes-20241001-160946.yaml b/.changes/unreleased/Breaking Changes-20241001-160946.yaml new file mode 100644 index 0000000..367cf2d --- /dev/null +++ b/.changes/unreleased/Breaking Changes-20241001-160946.yaml @@ -0,0 +1,3 @@ +kind: Breaking Changes +body: '`order_by` clause of queries using saved queries no longer supports string inputs and requires explicit `OrderByMetric` or `OrderByGroupBy`' +time: 2024-10-01T16:09:46.752151+02:00 diff --git a/.changes/unreleased/Features-20241001-155522.yaml b/.changes/unreleased/Features-20241001-155522.yaml new file mode 100644 index 0000000..f1545a9 --- /dev/null +++ b/.changes/unreleased/Features-20241001-155522.yaml @@ -0,0 +1,3 @@ +kind: Features +body: Specifying order by with `OrderByMetric` and `OrderByGroupBy` +time: 2024-10-01T15:55:22.041799+02:00 diff --git a/.changes/unreleased/Fixes-20241001-155448.yaml b/.changes/unreleased/Fixes-20241001-155448.yaml new file mode 100644 index 0000000..0f01686 --- 
/dev/null +++ b/.changes/unreleased/Fixes-20241001-155448.yaml @@ -0,0 +1,3 @@ +kind: Fixes +body: Order by for `compile_sql` now works as expected +time: 2024-10-01T15:54:48.740022+02:00 diff --git a/.changes/unreleased/Under the Hood-20241001-155544.yaml b/.changes/unreleased/Under the Hood-20241001-155544.yaml new file mode 100644 index 0000000..55d9b80 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20241001-155544.yaml @@ -0,0 +1,3 @@ +kind: Under the Hood +body: Improved internal representation of query parameters and added better validation +time: 2024-10-01T15:55:44.855697+02:00 diff --git a/dbtsl/__init__.py b/dbtsl/__init__.py index 5587aa0..6ea742f 100644 --- a/dbtsl/__init__.py +++ b/dbtsl/__init__.py @@ -13,4 +13,6 @@ def err_factory(*args, **kwargs) -> None: # noqa: D103 SemanticLayerClient = err_factory -__all__ = ["SemanticLayerClient"] +from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric + +__all__ = ["SemanticLayerClient", "OrderByMetric", "OrderByGroupBy"] diff --git a/dbtsl/api/adbc/protocol.py b/dbtsl/api/adbc/protocol.py index 0a89280..c49d176 100644 --- a/dbtsl/api/adbc/protocol.py +++ b/dbtsl/api/adbc/protocol.py @@ -1,25 +1,55 @@ +import dataclasses import json -from typing import Any, FrozenSet, Mapping +from typing import Any, List, Mapping -from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters, validate_query_parameters +from dbtsl.api.shared.query_params import ( + DimensionValuesQueryParameters, + OrderByGroupBy, + OrderByMetric, + QueryParameters, + validate_query_parameters, +) class ADBCProtocol: """The protocol for the Arrow Flight dataframe API.""" - @staticmethod - def _serialize_params_dict(params: Mapping[str, Any], param_names: FrozenSet[str]) -> str: + @classmethod + def _serialize_val(cls, val: Any) -> str: + if isinstance(val, bool): + return str(val) + + if isinstance(val, list): + list_str = ",".join(cls._serialize_val(list_val) for list_val in val) + return 
f"[{list_str}]" + + if isinstance(val, OrderByMetric): + m = f'Metric("{val.name}")' + if val.descending: + m += ".descending(True)" + return m + + if isinstance(val, OrderByGroupBy): + d = f'Dimension("{val.name}")' + if val.grain: + grain_str = val.grain.name.lower() + d += f'.grain("{grain_str}")' + if val.descending: + d += ".descending(True)" + return d + + return json.dumps(val) + + @classmethod + def _serialize_params_dict(cls, params: Mapping[str, Any], param_names: List[str]) -> str: param_names_sorted = list(param_names) param_names_sorted.sort() def append_param_if_exists(p_str: str, p_name: str) -> str: p_value = params.get(p_name) if p_value is not None: - if isinstance(p_value, bool): - dumped = str(p_value) - else: - dumped = json.dumps(p_value) - p_str += f"{p_name}={dumped}," + serialized = cls._serialize_val(p_value) + p_str += f"{p_name}={serialized}," return p_str serialized_params = "" @@ -33,12 +63,15 @@ def append_param_if_exists(p_str: str, p_name: str) -> str: @classmethod def get_query_sql(cls, params: QueryParameters) -> str: """Get the SQL that will be sent via Arrow Flight to the server based on query parameters.""" - validate_query_parameters(params) - serialized_params = cls._serialize_params_dict(params, QueryParameters.__optional_keys__) + strict_params = validate_query_parameters(params) + params_fields = [f.name for f in dataclasses.fields(strict_params)] + strict_params_dict = {field: getattr(strict_params, field) for field in params_fields} + + serialized_params = cls._serialize_params_dict(strict_params_dict, params_fields) return f"SELECT * FROM {{{{ semantic_layer.query({serialized_params}) }}}}" @classmethod def get_dimension_values_sql(cls, params: DimensionValuesQueryParameters) -> str: """Get the SQL that will be sent via Arrow Flight to the server based on dimension values query parameters.""" - serialized_params = cls._serialize_params_dict(params, DimensionValuesQueryParameters.__optional_keys__) + serialized_params = 
cls._serialize_params_dict(params, list(DimensionValuesQueryParameters.__optional_keys__)) return f"SELECT * FROM {{{{ semantic_layer.dimension_values({serialized_params}) }}}}" diff --git a/dbtsl/api/graphql/client/asyncio.pyi b/dbtsl/api/graphql/client/asyncio.pyi index 36af1a3..7006af2 100644 --- a/dbtsl/api/graphql/client/asyncio.pyi +++ b/dbtsl/api/graphql/client/asyncio.pyi @@ -1,10 +1,10 @@ from contextlib import AbstractAsyncContextManager -from typing import List, Optional, Self +from typing import List, Optional, Self, Union import pyarrow as pa from typing_extensions import AsyncIterator, Unpack, overload -from dbtsl.api.shared.query_params import QueryParameters +from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric, QueryParameters from dbtsl.models import ( Dimension, Entity, @@ -50,7 +50,7 @@ class AsyncGraphQLClient: metrics: List[str], group_by: Optional[List[str]] = None, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> str: ... @@ -59,7 +59,7 @@ class AsyncGraphQLClient: self, saved_query: str, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> str: ... @@ -73,7 +73,7 @@ class AsyncGraphQLClient: metrics: List[str], group_by: Optional[List[str]] = None, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> "pa.Table": ... 
@@ -82,7 +82,7 @@ class AsyncGraphQLClient: self, saved_query: str, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> "pa.Table": ... diff --git a/dbtsl/api/graphql/client/sync.pyi b/dbtsl/api/graphql/client/sync.pyi index b050d62..c424dc0 100644 --- a/dbtsl/api/graphql/client/sync.pyi +++ b/dbtsl/api/graphql/client/sync.pyi @@ -1,10 +1,10 @@ from contextlib import AbstractContextManager -from typing import Iterator, List, Optional +from typing import Iterator, List, Optional, Union import pyarrow as pa from typing_extensions import Self, Unpack, overload -from dbtsl.api.shared.query_params import QueryParameters +from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric, QueryParameters from dbtsl.models import ( Dimension, Entity, @@ -50,7 +50,7 @@ class SyncGraphQLClient: metrics: List[str], group_by: Optional[List[str]] = None, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> str: ... @@ -59,7 +59,7 @@ class SyncGraphQLClient: self, saved_query: str, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> str: ... @@ -73,7 +73,7 @@ class SyncGraphQLClient: metrics: List[str], group_by: Optional[List[str]] = None, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> "pa.Table": ... 
@@ -82,7 +82,7 @@ class SyncGraphQLClient: self, saved_query: str, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> "pa.Table": ... diff --git a/dbtsl/api/graphql/protocol.py b/dbtsl/api/graphql/protocol.py index 417fc6f..b812ba7 100644 --- a/dbtsl/api/graphql/protocol.py +++ b/dbtsl/api/graphql/protocol.py @@ -5,7 +5,12 @@ from typing_extensions import NotRequired, override from dbtsl.api.graphql.util import render_query -from dbtsl.api.shared.query_params import QueryParameters, validate_query_parameters +from dbtsl.api.shared.query_params import ( + AdhocQueryParametersStrict, + OrderByMetric, + QueryParameters, + validate_query_parameters, +) from dbtsl.models import Dimension, Entity, Measure, Metric from dbtsl.models.query import QueryId, QueryResult, QueryStatus from dbtsl.models.saved_query import SavedQuery @@ -192,6 +197,42 @@ def parse_response(self, data: Dict[str, Any]) -> List[SavedQuery]: return decode_to_dataclass(data["savedQueries"], List[SavedQuery]) +def get_query_request_variables(environment_id: int, params: QueryParameters) -> Dict[str, Any]: + """Get the GraphQL request variables for a given set of query parameters.""" + strict_params = validate_query_parameters(params) # type: ignore + + shared_vars = { + "environmentId": environment_id, + "where": [{"sql": sql} for sql in strict_params.where] if strict_params.where is not None else None, + "orderBy": [ + {"metric": {"name": clause.name}, "descending": clause.descending} + if isinstance(clause, OrderByMetric) + else {"groupBy": {"name": clause.name, "grain": clause.grain}, "descending": clause.descending} + for clause in strict_params.order_by + ] + if strict_params.order_by is not None + else None, + "limit": strict_params.limit, + "readCache": strict_params.read_cache, + } + + if isinstance(strict_params, AdhocQueryParametersStrict): 
+ return { + "savedQuery": None, + "metrics": [{"name": m} for m in strict_params.metrics], + "groupBy": [{"name": g} for g in strict_params.group_by] if strict_params.group_by is not None else None, + **shared_vars, + } + + return { + "environmentId": environment_id, + "savedQuery": strict_params.saved_query, + "metrics": None, + "groupBy": None, + **shared_vars, + } + + class CreateQueryOperation(ProtocolOperation[QueryParameters, QueryId]): """Create a query that will be processed asynchronously.""" @@ -227,17 +268,7 @@ def get_request_text(self) -> str: @override def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]: # TODO: fix typing - validate_query_parameters(kwargs) # type: ignore - return { - "environmentId": environment_id, - "savedQuery": kwargs.get("saved_query", None), - "metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None, - "groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None, - "where": [{"sql": sql} for sql in kwargs.get("where", [])], - "orderBy": [{"name": o} for o in kwargs.get("order_by", [])], - "limit": kwargs.get("limit", None), - "readCache": kwargs.get("read_cache", True), - } + return get_query_request_variables(environment_id, kwargs) # type: ignore @override def parse_response(self, data: Dict[str, Any]) -> QueryId: @@ -317,17 +348,7 @@ def get_request_text(self) -> str: @override def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]: # TODO: fix typing - validate_query_parameters(kwargs) # type: ignore - return { - "environmentId": environment_id, - "savedQuery": kwargs.get("saved_query", None), - "metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None, - "groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None, - "where": [{"sql": sql} for sql in kwargs.get("where", [])], - "orderBy": [{"name": o} for o in 
kwargs.get("order_by", [])], - "limit": kwargs.get("limit", None), - "readCache": kwargs.get("read_cache", True), - } + return get_query_request_variables(environment_id, kwargs) # type: ignore @override def parse_response(self, data: Dict[str, Any]) -> str: diff --git a/dbtsl/api/shared/query_params.py b/dbtsl/api/shared/query_params.py index 5a02016..083065b 100644 --- a/dbtsl/api/shared/query_params.py +++ b/dbtsl/api/shared/query_params.py @@ -1,4 +1,30 @@ -from typing import List, TypedDict +from dataclasses import dataclass +from typing import List, Optional, TypedDict, Union + +from dbtsl.models.time import TimeGranularity + + +@dataclass(frozen=True) +class OrderByMetric: + """Spec for ordering by a metric.""" + + name: str + descending: bool = False + + +@dataclass(frozen=True) +class OrderByGroupBy: + """Spec for ordering by a group_by, i.e a dimension or an entity. + + Not specifying a grain will defer the grain choice to the server. + """ + + name: str + grain: Optional[TimeGranularity] + descending: bool = False + + +OrderBySpec = Union[OrderByMetric, OrderByGroupBy] class QueryParameters(TypedDict, total=False): @@ -11,12 +37,63 @@ class QueryParameters(TypedDict, total=False): metrics: List[str] group_by: List[str] limit: int - order_by: List[str] + order_by: List[Union[OrderBySpec, str]] where: List[str] read_cache: bool -def validate_query_parameters(params: QueryParameters) -> None: +@dataclass(frozen=True) +class AdhocQueryParametersStrict: + """The parameters of an adhoc query, strictly validated.""" + + metrics: List[str] + group_by: Optional[List[str]] + limit: Optional[int] + order_by: Optional[List[OrderBySpec]] + where: Optional[List[str]] + read_cache: bool + + +@dataclass(frozen=True) +class SavedQueryQueryParametersStrict: + """The parameters of a query that uses a saved query, strictly validated.""" + + saved_query: str + limit: Optional[int] + order_by: Optional[List[OrderBySpec]] + where: Optional[List[str]] + read_cache: bool + + 
+def validate_order_by( + known_metrics: List[str], known_group_bys: List[str], clause: Union[OrderBySpec, str] +) -> OrderBySpec: + """Validate an order by clause like `-metric_name`.""" + if isinstance(clause, OrderByMetric) or isinstance(clause, OrderByGroupBy): + return clause + + descending = clause.startswith("-") + if descending or clause.startswith("+"): + clause = clause[1:] + + if clause in known_metrics: + return OrderByMetric(name=clause, descending=descending) + + if clause in known_group_bys or clause == "metric_time": + return OrderByGroupBy(name=clause, descending=descending, grain=None) + + # TODO: make this error less strict when server supports order_by type inference. + raise ValueError( + f"Cannot determine if the specified order_by clause ({clause}) is a metric or a dimension/entity. " + "If you're running an adhoc query, make sure the order_by is in `metrics` or `group_by`. " + "If you're using saved queries, please explicitly specify what you want by using " + "`dbtsl.OrderByMetric` or `dbtsl.OrderByGroupBy` instead of a string." + ) + + +def validate_query_parameters( + params: QueryParameters, +) -> Union[AdhocQueryParametersStrict, SavedQueryQueryParametersStrict]: """Validate a dict that should be QueryParameters.""" is_saved_query = "saved_query" in params is_adhoc_query = "metrics" in params or "group_by" in params @@ -27,11 +104,37 @@ def validate_query_parameters(params: QueryParameters) -> None: "metrics and group_by." 
) - if "metrics" in params and len(params["metrics"]) == 0: + if not is_saved_query and not is_adhoc_query: + raise ValueError("You must specify one of: saved_query, metrics/group_by.") + + order_by: Optional[List[OrderBySpec]] = None + if "order_by" in params: + known_metrics = params.get("metrics", []) + known_group_bys = params.get("group_by", []) + + order_by = [validate_order_by(known_metrics, known_group_bys, clause) for clause in params["order_by"]] + + shared_params = { + "limit": params.get("limit"), + "order_by": order_by, + "where": params.get("where"), + "read_cache": params.get("read_cache", True), + } + + if is_saved_query: + return SavedQueryQueryParametersStrict( + saved_query=params["saved_query"], + **shared_params, + ) + + if "metrics" not in params or len(params["metrics"]) == 0: raise ValueError("You need to specify at least one metric.") - if "group_by" in params and len(params["group_by"]) == 0: - raise ValueError("You need to specify at least one dimension to group by.") + return AdhocQueryParametersStrict( + metrics=params["metrics"], + group_by=params.get("group_by"), + **shared_params, + ) class DimensionValuesQueryParameters(TypedDict, total=False): diff --git a/dbtsl/client/asyncio.pyi b/dbtsl/client/asyncio.pyi index cbe9293..08a78f2 100644 --- a/dbtsl/client/asyncio.pyi +++ b/dbtsl/client/asyncio.pyi @@ -1,10 +1,10 @@ from contextlib import AbstractAsyncContextManager -from typing import AsyncIterator, List, Optional +from typing import AsyncIterator, List, Optional, Union import pyarrow as pa from typing_extensions import Self, Unpack, overload -from dbtsl.api.adbc.protocol import QueryParameters +from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric, QueryParameters from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery class AsyncSemanticLayerClient: @@ -20,7 +20,7 @@ class AsyncSemanticLayerClient: metrics: List[str], group_by: Optional[List[str]] = None, limit: Optional[int] = None, - order_by: 
Optional[List[str]] = None, + order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> str: ... @@ -29,7 +29,7 @@ class AsyncSemanticLayerClient: self, saved_query: str, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> str: ... @@ -43,7 +43,7 @@ class AsyncSemanticLayerClient: metrics: List[str], group_by: Optional[List[str]] = None, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> "pa.Table": ... @@ -52,7 +52,7 @@ class AsyncSemanticLayerClient: self, saved_query: str, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> "pa.Table": ... 
diff --git a/dbtsl/client/sync.pyi b/dbtsl/client/sync.pyi index 2546064..1b0f577 100644 --- a/dbtsl/client/sync.pyi +++ b/dbtsl/client/sync.pyi @@ -1,10 +1,10 @@ from contextlib import AbstractContextManager -from typing import Iterator, List, Optional +from typing import Iterator, List, Optional, Union import pyarrow as pa from typing_extensions import Self, Unpack, overload -from dbtsl.api.adbc.protocol import QueryParameters +from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric, QueryParameters from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery class SyncSemanticLayerClient: @@ -20,7 +20,7 @@ class SyncSemanticLayerClient: metrics: List[str], group_by: Optional[List[str]] = None, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> str: ... @@ -29,7 +29,7 @@ class SyncSemanticLayerClient: self, saved_query: str, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> str: ... @@ -43,7 +43,7 @@ class SyncSemanticLayerClient: metrics: List[str], group_by: Optional[List[str]] = None, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> "pa.Table": ... @@ -52,7 +52,7 @@ class SyncSemanticLayerClient: self, saved_query: str, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, + order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None, where: Optional[List[str]] = None, read_cache: bool = True, ) -> "pa.Table": ... 
diff --git a/tests/api/adbc/test_protocol.py b/tests/api/adbc/test_protocol.py index 7495916..fec3e74 100644 --- a/tests/api/adbc/test_protocol.py +++ b/tests/api/adbc/test_protocol.py @@ -1,20 +1,48 @@ from dbtsl.api.adbc.protocol import ADBCProtocol -from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters +from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric +from dbtsl.models.time import TimeGranularity -def test_serialize_query_params_simple_query() -> None: - params = ADBCProtocol._serialize_params_dict({"metrics": ["a", "b"]}, QueryParameters.__optional_keys__) +def test_serialize_val_basic_values() -> None: + assert ADBCProtocol._serialize_val(1) == "1" + assert ADBCProtocol._serialize_val("a") == '"a"' + assert ADBCProtocol._serialize_val(True) == "True" + assert ADBCProtocol._serialize_val(False) == "False" + assert ADBCProtocol._serialize_val(["a", "b"]) == '["a","b"]' - expected = 'metrics=["a", "b"]' - assert params == expected +def test_serialize_val_OrderByMetric() -> None: + assert ADBCProtocol._serialize_val(OrderByMetric(name="m", descending=False)) == 'Metric("m")' + assert ADBCProtocol._serialize_val(OrderByMetric(name="m", descending=True)) == 'Metric("m").descending(True)' -def test_serialize_query_params_dimensions_query() -> None: - params = ADBCProtocol._serialize_params_dict( - {"metrics": ["a", "b"], "group_by": "c"}, DimensionValuesQueryParameters.__optional_keys__ + +def test_serialize_val_OrderByGroupBy() -> None: + assert ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=None, descending=False)) == 'Dimension("m")' + assert ( + ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=None, descending=True)) + == 'Dimension("m").descending(True)' + ) + assert ( + ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=TimeGranularity.DAY, descending=False)) + == 'Dimension("m").grain("day")' ) + assert ( + ADBCProtocol._serialize_val(OrderByGroupBy(name="m", 
grain=TimeGranularity.WEEK, descending=True)) + == 'Dimension("m").grain("week").descending(True)' + ) + + +def test_serialize_query_params_metrics() -> None: + params = ADBCProtocol._serialize_params_dict({"metrics": ["a", "b"]}, ["metrics"]) - expected = 'group_by="c",metrics=["a", "b"]' + expected = 'metrics=["a","b"]' + assert params == expected + + +def test_serialize_query_params_metrics_group_by() -> None: + params = ADBCProtocol._serialize_params_dict({"metrics": ["a", "b"], "group_by": "c"}, ["metrics", "group_by"]) + + expected = 'group_by="c",metrics=["a","b"]' assert params == expected @@ -24,27 +52,30 @@ def test_serialize_query_params_complete_query() -> None: "metrics": ["a", "b"], "group_by": ["dim_c"], "limit": 1, - "order_by": ["dim_c"], + "order_by": [OrderByMetric(name="a"), OrderByGroupBy(name="dim_c", grain=None)], "where": ['{{ Dimension("metric_time").grain("month") }} >= \'2017-03-09\''], "read_cache": False, }, - QueryParameters.__optional_keys__, + ["metrics", "group_by", "limit", "order_by", "where", "read_cache"], ) expected = ( - 'group_by=["dim_c"],limit=1,metrics=["a", "b"],order_by=["dim_c"],read_cache=False,' + 'group_by=["dim_c"],limit=1,metrics=["a","b"],order_by=[Metric("a"),Dimension("dim_c")],read_cache=False,' 'where=["{{ Dimension(\\"metric_time\\").grain(\\"month\\") }} >= \'2017-03-09\'"]' ) assert params == expected def test_get_query_sql_simple_query() -> None: - sql = ADBCProtocol.get_query_sql(params={"metrics": ["a", "b"]}) - expected = 'SELECT * FROM {{ semantic_layer.query(metrics=["a", "b"]) }}' + sql = ADBCProtocol.get_query_sql(params={"metrics": ["a", "b"], "order_by": ["-a"]}) + expected = ( + 'SELECT * FROM {{ semantic_layer.query(metrics=["a","b"],' + 'order_by=[Metric("a").descending(True)],read_cache=True) }}' + ) assert sql == expected def test_get_query_sql_dimension_values_query() -> None: sql = ADBCProtocol.get_dimension_values_sql(params={"metrics": ["a", "b"]}) - expected = 'SELECT * FROM {{ 
semantic_layer.dimension_values(metrics=["a", "b"]) }}' + expected = 'SELECT * FROM {{ semantic_layer.dimension_values(metrics=["a","b"]) }}' assert sql == expected diff --git a/tests/integration/test_sl_client.py b/tests/integration/test_sl_client.py index f67b7c7..343a073 100644 --- a/tests/integration/test_sl_client.py +++ b/tests/integration/test_sl_client.py @@ -3,6 +3,7 @@ import pytest from pytest_subtests import SubTests +from dbtsl import OrderByGroupBy from dbtsl.client.asyncio import AsyncSemanticLayerClient from dbtsl.client.base import ADBC, GRAPHQL from dbtsl.client.sync import SyncSemanticLayerClient @@ -87,7 +88,10 @@ async def test_client_query_adhoc(api: str, client: BothClients) -> None: client.query( metrics=[metrics[0].name], group_by=["metric_time"], + order_by=["metric_time"], + where=["1=1"], limit=1, + read_cache=True, ) ) assert len(table) > 0 @@ -102,7 +106,10 @@ async def test_client_query_saved_query(api: str, client: BothClients) -> None: table = await maybe_await( client.query( saved_query="order_metrics", + order_by=[OrderByGroupBy(name="metric_time", grain=None)], + where=["1=1"], limit=1, + read_cache=True, ) ) assert len(table) > 0 @@ -116,7 +123,10 @@ async def test_client_compile_sql_adhoc_query(client: BothClients) -> None: client.compile_sql( metrics=[metrics[0].name], group_by=[metrics[0].dimensions[0].name], + order_by=[metrics[0].dimensions[0].name], + where=["1=1"], limit=1, + read_cache=True, ) ) assert len(sql) > 0 @@ -130,7 +140,10 @@ async def test_client_compile_sql_saved_query(client: BothClients) -> None: sql = await maybe_await( client.compile_sql( saved_query="order_metrics", + order_by=[OrderByGroupBy(name="metric_time", grain=None)], + where=["1=1"], limit=1, + read_cache=True, ) ) assert len(sql) > 0 diff --git a/tests/test_models.py b/tests/test_models.py index 5f6a492..50f1957 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,7 +5,15 @@ from mashumaro.codecs.basic import decode from 
dbtsl.api.graphql.util import normalize_query -from dbtsl.api.shared.query_params import QueryParameters, validate_query_parameters +from dbtsl.api.shared.query_params import ( + AdhocQueryParametersStrict, + OrderByGroupBy, + OrderByMetric, + QueryParameters, + SavedQueryQueryParametersStrict, + validate_order_by, + validate_query_parameters, +) from dbtsl.models.base import BaseModel, GraphQLFragmentMixin from dbtsl.models.base import snake_case_to_camel_case as stc @@ -81,17 +89,84 @@ class B(BaseModel, GraphQLFragmentMixin): assert b_fragments[1] == a_fragment +def test_validate_order_by_params_passthrough_OrderByMetric() -> None: + i = OrderByMetric(name="asdf", descending=True) + r = validate_order_by([], [], i) + assert r == i + + +def test_validate_order_by_params_passthrough_OrderByGroupBy() -> None: + i = OrderByGroupBy(name="asdf", grain=None, descending=True) + r = validate_order_by([], [], i) + assert r == i + + +def test_validate_order_by_params_ascending() -> None: + r = validate_order_by(["metric"], [], "+metric") + assert r == OrderByMetric(name="metric", descending=False) + + +def test_validate_order_by_params_descending() -> None: + r = validate_order_by(["metric"], [], "-metric") + assert r == OrderByMetric(name="metric", descending=True) + + +def test_validate_order_by_params_metric() -> None: + r = validate_order_by(["a"], ["b"], "a") + assert r == OrderByMetric( + name="a", + descending=False, + ) + + +def test_validate_order_by_params_group_by() -> None: + r = validate_order_by(["a"], ["b"], "b") + assert r == OrderByGroupBy( + name="b", + grain=None, + descending=False, + ) + + +def test_validate_order_by_not_found() -> None: + with pytest.raises(ValueError): + validate_order_by(["a"], ["b"], "c") + + def test_validate_query_params_adhoc_query_valid() -> None: p: QueryParameters = { "metrics": ["a", "b"], - "group_by": ["a", "b"], + "group_by": ["c", "d"], + "order_by": ["a"], + "where": ["1=1"], + "limit": 1, + "read_cache": False, } - 
validate_query_parameters(p) + r = validate_query_parameters(p) + assert isinstance(r, AdhocQueryParametersStrict) + assert r.metrics == ["a", "b"] + assert r.group_by == ["c", "d"] + assert r.order_by == [OrderByMetric(name="a")] + assert r.where == ["1=1"] + assert r.limit == 1 + assert not r.read_cache def test_validate_query_params_saved_query_valid() -> None: - p: QueryParameters = {"saved_query": "a"} - validate_query_parameters(p) + p: QueryParameters = { + "saved_query": "a", + "order_by": [OrderByMetric(name="b")], + "where": ["1=1"], + "limit": 1, + "read_cache": False, + } + r = validate_query_parameters(p) + assert isinstance(r, SavedQueryQueryParametersStrict) + assert r.saved_query == "a" + assert r.order_by == [OrderByMetric(name="b")] + assert r.where == ["1=1"] + assert r.limit == 1 + assert not r.read_cache def test_validate_query_params_adhoc_query_no_metrics() -> None: @@ -103,16 +178,13 @@ def test_validate_query_params_adhoc_query_no_metrics() -> None: validate_query_parameters(p) -def test_validate_query_params_adhoc_query_no_group_by() -> None: - p: QueryParameters = { - "metrics": ["a", "b"], - "group_by": [], - } +def test_validate_query_params_adhoc_and_saved_query() -> None: + p: QueryParameters = {"metrics": ["a", "b"], "group_by": ["a", "b"], "saved_query": "a"} with pytest.raises(ValueError): validate_query_parameters(p) -def test_validate_query_params_adhoc_and_saved_query() -> None: - p: QueryParameters = {"metrics": ["a", "b"], "group_by": ["a", "b"], "saved_query": "a"} +def test_validate_query_params_no_query() -> None: + p: QueryParameters = {"limit": 1, "where": ["1=1"], "order_by": ["a"], "read_cache": False} with pytest.raises(ValueError): validate_query_parameters(p)