diff --git a/dbtsl/api/adbc/protocol.py b/dbtsl/api/adbc/protocol.py index 0a89280..c49d176 100644 --- a/dbtsl/api/adbc/protocol.py +++ b/dbtsl/api/adbc/protocol.py @@ -1,25 +1,55 @@ +import dataclasses import json -from typing import Any, FrozenSet, Mapping +from typing import Any, List, Mapping -from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters, validate_query_parameters +from dbtsl.api.shared.query_params import ( + DimensionValuesQueryParameters, + OrderByGroupBy, + OrderByMetric, + QueryParameters, + validate_query_parameters, +) class ADBCProtocol: """The protocol for the Arrow Flight dataframe API.""" - @staticmethod - def _serialize_params_dict(params: Mapping[str, Any], param_names: FrozenSet[str]) -> str: + @classmethod + def _serialize_val(cls, val: Any) -> str: + if isinstance(val, bool): + return str(val) + + if isinstance(val, list): + list_str = ",".join(cls._serialize_val(list_val) for list_val in val) + return f"[{list_str}]" + + if isinstance(val, OrderByMetric): + m = f'Metric("{val.name}")' + if val.descending: + m += ".descending(True)" + return m + + if isinstance(val, OrderByGroupBy): + d = f'Dimension("{val.name}")' + if val.grain: + grain_str = val.grain.name.lower() + d += f'.grain("{grain_str}")' + if val.descending: + d += ".descending(True)" + return d + + return json.dumps(val) + + @classmethod + def _serialize_params_dict(cls, params: Mapping[str, Any], param_names: List[str]) -> str: param_names_sorted = list(param_names) param_names_sorted.sort() def append_param_if_exists(p_str: str, p_name: str) -> str: p_value = params.get(p_name) if p_value is not None: - if isinstance(p_value, bool): - dumped = str(p_value) - else: - dumped = json.dumps(p_value) - p_str += f"{p_name}={dumped}," + serialized = cls._serialize_val(p_value) + p_str += f"{p_name}={serialized}," return p_str serialized_params = "" @@ -33,12 +63,15 @@ def append_param_if_exists(p_str: str, p_name: str) -> str: @classmethod def get_query_sql(cls, params: QueryParameters) -> str: """Get the SQL that will be sent via Arrow Flight to the server based on query parameters.""" - validate_query_parameters(params) - serialized_params = cls._serialize_params_dict(params, QueryParameters.__optional_keys__) + strict_params = validate_query_parameters(params) + params_fields = [f.name for f in dataclasses.fields(strict_params)] + strict_params_dict = {field: getattr(strict_params, field) for field in params_fields} + + serialized_params = cls._serialize_params_dict(strict_params_dict, params_fields) return f"SELECT * FROM {{{{ semantic_layer.query({serialized_params}) }}}}" @classmethod def get_dimension_values_sql(cls, params: DimensionValuesQueryParameters) -> str: """Get the SQL that will be sent via Arrow Flight to the server based on dimension values query parameters.""" - serialized_params = cls._serialize_params_dict(params, DimensionValuesQueryParameters.__optional_keys__) + serialized_params = cls._serialize_params_dict(params, list(DimensionValuesQueryParameters.__optional_keys__)) return f"SELECT * FROM {{{{ semantic_layer.dimension_values({serialized_params}) }}}}" diff --git a/dbtsl/api/graphql/protocol.py b/dbtsl/api/graphql/protocol.py index 417fc6f..b812ba7 100644 --- a/dbtsl/api/graphql/protocol.py +++ b/dbtsl/api/graphql/protocol.py @@ -5,7 +5,12 @@ from typing_extensions import NotRequired, override from dbtsl.api.graphql.util import render_query -from dbtsl.api.shared.query_params import QueryParameters, validate_query_parameters +from dbtsl.api.shared.query_params import ( + AdhocQueryParametersStrict, + OrderByMetric, + QueryParameters, + validate_query_parameters, +) from dbtsl.models import Dimension, Entity, Measure, Metric from dbtsl.models.query import QueryId, QueryResult, QueryStatus from dbtsl.models.saved_query import SavedQuery @@ -192,6 +197,42 @@ def parse_response(self, data: Dict[str, Any]) -> List[SavedQuery]: return decode_to_dataclass(data["savedQueries"], List[SavedQuery]) +def get_query_request_variables(environment_id: int, params: QueryParameters) -> Dict[str, Any]: + """Get the GraphQL request variables for a given set of query parameters.""" + strict_params = validate_query_parameters(params) # type: ignore + + shared_vars = { + "environmentId": environment_id, + "where": [{"sql": sql} for sql in strict_params.where] if strict_params.where is not None else None, + "orderBy": [ + {"metric": {"name": clause.name}, "descending": clause.descending} + if isinstance(clause, OrderByMetric) + else {"groupBy": {"name": clause.name, "grain": clause.grain}, "descending": clause.descending} + for clause in strict_params.order_by + ] + if strict_params.order_by is not None + else None, + "limit": strict_params.limit, + "readCache": strict_params.read_cache, + } + + if isinstance(strict_params, AdhocQueryParametersStrict): + return { + "savedQuery": None, + "metrics": [{"name": m} for m in strict_params.metrics], + "groupBy": [{"name": g} for g in strict_params.group_by] if strict_params.group_by is not None else None, + **shared_vars, + } + + return { + "environmentId": environment_id, + "savedQuery": strict_params.saved_query, + "metrics": None, + "groupBy": None, + **shared_vars, + } + + class CreateQueryOperation(ProtocolOperation[QueryParameters, QueryId]): """Create a query that will be processed asynchronously.""" @@ -227,17 +268,7 @@ def get_request_text(self) -> str: @override def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]: # TODO: fix typing - validate_query_parameters(kwargs) # type: ignore - return { - "environmentId": environment_id, - "savedQuery": kwargs.get("saved_query", None), - "metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None, - "groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None, - "where": [{"sql": sql} for sql in kwargs.get("where", [])], - "orderBy": [{"name": o} for o in kwargs.get("order_by", [])], - "limit": kwargs.get("limit", None), - "readCache": kwargs.get("read_cache", True), - } + return get_query_request_variables(environment_id, kwargs) # type: ignore @override def parse_response(self, data: Dict[str, Any]) -> QueryId: @@ -317,17 +348,7 @@ def get_request_text(self) -> str: @override def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]: # TODO: fix typing - validate_query_parameters(kwargs) # type: ignore - return { - "environmentId": environment_id, - "savedQuery": kwargs.get("saved_query", None), - "metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None, - "groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None, - "where": [{"sql": sql} for sql in kwargs.get("where", [])], - "orderBy": [{"name": o} for o in kwargs.get("order_by", [])], - "limit": kwargs.get("limit", None), - "readCache": kwargs.get("read_cache", True), - } + return get_query_request_variables(environment_id, kwargs) # type: ignore @override def parse_response(self, data: Dict[str, Any]) -> str: diff --git a/tests/api/adbc/test_protocol.py b/tests/api/adbc/test_protocol.py index 7495916..fec3e74 100644 --- a/tests/api/adbc/test_protocol.py +++ b/tests/api/adbc/test_protocol.py @@ -1,20 +1,48 @@ from dbtsl.api.adbc.protocol import ADBCProtocol -from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters +from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric +from dbtsl.models.time import TimeGranularity -def test_serialize_query_params_simple_query() -> None: - params = ADBCProtocol._serialize_params_dict({"metrics": ["a", "b"]}, QueryParameters.__optional_keys__) +def test_serialize_val_basic_values() -> None: + assert ADBCProtocol._serialize_val(1) == "1" + assert ADBCProtocol._serialize_val("a") == '"a"' + assert ADBCProtocol._serialize_val(True) == "True" + assert ADBCProtocol._serialize_val(False) == "False" + assert ADBCProtocol._serialize_val(["a", "b"]) == '["a","b"]' - expected = 'metrics=["a", "b"]' - assert params == expected +def test_serialize_val_OrderByMetric() -> None: + assert ADBCProtocol._serialize_val(OrderByMetric(name="m", descending=False)) == 'Metric("m")' + assert ADBCProtocol._serialize_val(OrderByMetric(name="m", descending=True)) == 'Metric("m").descending(True)' -def test_serialize_query_params_dimensions_query() -> None: - params = ADBCProtocol._serialize_params_dict( - {"metrics": ["a", "b"], "group_by": "c"}, DimensionValuesQueryParameters.__optional_keys__ + +def test_serialize_val_OrderByGroupBy() -> None: + assert ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=None, descending=False)) == 'Dimension("m")' + assert ( + ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=None, descending=True)) + == 'Dimension("m").descending(True)' + ) + assert ( + ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=TimeGranularity.DAY, descending=False)) + == 'Dimension("m").grain("day")' ) + assert ( + ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=TimeGranularity.WEEK, descending=True)) + == 'Dimension("m").grain("week").descending(True)' + ) + + +def test_serialize_query_params_metrics() -> None: + params = ADBCProtocol._serialize_params_dict({"metrics": ["a", "b"]}, ["metrics"]) - expected = 'group_by="c",metrics=["a", "b"]' + expected = 'metrics=["a","b"]' + assert params == expected + + +def test_serialize_query_params_metrics_group_by() -> None: + params = ADBCProtocol._serialize_params_dict({"metrics": ["a", "b"], "group_by": "c"}, ["metrics", "group_by"]) + + expected = 'group_by="c",metrics=["a","b"]' assert params == expected @@ -24,27 +52,30 @@ def test_serialize_query_params_complete_query() -> None: "metrics": ["a", "b"], "group_by": ["dim_c"], "limit": 1, - "order_by": ["dim_c"], + "order_by": [OrderByMetric(name="a"), OrderByGroupBy(name="dim_c", grain=None)], "where": ['{{ Dimension("metric_time").grain("month") }} >= \'2017-03-09\''], "read_cache": False, }, - QueryParameters.__optional_keys__, + ["metrics", "group_by", "limit", "order_by", "where", "read_cache"], ) expected = ( - 'group_by=["dim_c"],limit=1,metrics=["a", "b"],order_by=["dim_c"],read_cache=False,' + 'group_by=["dim_c"],limit=1,metrics=["a","b"],order_by=[Metric("a"),Dimension("dim_c")],read_cache=False,' 'where=["{{ Dimension(\\"metric_time\\").grain(\\"month\\") }} >= \'2017-03-09\'"]' ) assert params == expected def test_get_query_sql_simple_query() -> None: - sql = ADBCProtocol.get_query_sql(params={"metrics": ["a", "b"]}) - expected = 'SELECT * FROM {{ semantic_layer.query(metrics=["a", "b"]) }}' + sql = ADBCProtocol.get_query_sql(params={"metrics": ["a", "b"], "order_by": ["-a"]}) + expected = ( + 'SELECT * FROM {{ semantic_layer.query(metrics=["a","b"],' + 'order_by=[Metric("a").descending(True)],read_cache=True) }}' + ) assert sql == expected def test_get_query_sql_dimension_values_query() -> None: sql = ADBCProtocol.get_dimension_values_sql(params={"metrics": ["a", "b"]}) - expected = 'SELECT * FROM {{ semantic_layer.dimension_values(metrics=["a", "b"]) }}' + expected = 'SELECT * FROM {{ semantic_layer.dimension_values(metrics=["a","b"]) }}' assert sql == expected