refactor: make protocols use the new validations
This commit makes the ADBC and GraphQL protocol implementations use the
new, stricter representation of query params. It updates the tests
accordingly.
serramatutu committed Oct 16, 2024
1 parent 190386a commit bd500f0
Showing 3 changed files with 135 additions and 50 deletions.
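For context, a minimal sketch, not part of this diff, of the stricter params flow that both protocols switch to in this commit; the metric and dimension names are hypothetical.

from dbtsl.api.shared.query_params import OrderByMetric, QueryParameters, validate_query_parameters

# Hypothetical ad-hoc query; "revenue" and "customer__region" are illustrative names.
params: QueryParameters = {
    "metrics": ["revenue"],
    "group_by": ["customer__region"],
    "order_by": [OrderByMetric(name="revenue", descending=True)],
    "limit": 10,
}

# Returns the strict, validated representation (an AdhocQueryParametersStrict here),
# which the ADBC and GraphQL protocols below serialize instead of the raw dict.
strict = validate_query_parameters(params)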
57 changes: 45 additions & 12 deletions dbtsl/api/adbc/protocol.py
@@ -1,25 +1,55 @@
import dataclasses
import json
from typing import Any, FrozenSet, Mapping
from typing import Any, List, Mapping

from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters, validate_query_parameters
from dbtsl.api.shared.query_params import (
    DimensionValuesQueryParameters,
    OrderByGroupBy,
    OrderByMetric,
    QueryParameters,
    validate_query_parameters,
)


class ADBCProtocol:
    """The protocol for the Arrow Flight dataframe API."""

    @staticmethod
    def _serialize_params_dict(params: Mapping[str, Any], param_names: FrozenSet[str]) -> str:
    @classmethod
    def _serialize_val(cls, val: Any) -> str:
        if isinstance(val, bool):
            return str(val)

        if isinstance(val, list):
            list_str = ",".join(cls._serialize_val(list_val) for list_val in val)
            return f"[{list_str}]"

        if isinstance(val, OrderByMetric):
            m = f'Metric("{val.name}")'
            if val.descending:
                m += ".descending(True)"
            return m

        if isinstance(val, OrderByGroupBy):
            d = f'Dimension("{val.name}")'
            if val.grain:
                grain_str = val.grain.name.lower()
                d += f'.grain("{grain_str}")'
            if val.descending:
                d += ".descending(True)"
            return d

        return json.dumps(val)

    @classmethod
    def _serialize_params_dict(cls, params: Mapping[str, Any], param_names: List[str]) -> str:
        param_names_sorted = list(param_names)
        param_names_sorted.sort()

        def append_param_if_exists(p_str: str, p_name: str) -> str:
            p_value = params.get(p_name)
            if p_value is not None:
                if isinstance(p_value, bool):
                    dumped = str(p_value)
                else:
                    dumped = json.dumps(p_value)
                p_str += f"{p_name}={dumped},"
                serialized = cls._serialize_val(p_value)
                p_str += f"{p_name}={serialized},"
            return p_str

        serialized_params = ""
@@ -33,12 +63,15 @@ def append_param_if_exists(p_str: str, p_name: str) -> str:
    @classmethod
    def get_query_sql(cls, params: QueryParameters) -> str:
        """Get the SQL that will be sent via Arrow Flight to the server based on query parameters."""
        validate_query_parameters(params)
        serialized_params = cls._serialize_params_dict(params, QueryParameters.__optional_keys__)
        strict_params = validate_query_parameters(params)
        params_fields = [f.name for f in dataclasses.fields(strict_params)]
        strict_params_dict = {field: getattr(strict_params, field) for field in params_fields}

        serialized_params = cls._serialize_params_dict(strict_params_dict, params_fields)
        return f"SELECT * FROM {{{{ semantic_layer.query({serialized_params}) }}}}"

    @classmethod
    def get_dimension_values_sql(cls, params: DimensionValuesQueryParameters) -> str:
        """Get the SQL that will be sent via Arrow Flight to the server based on dimension values query parameters."""
        serialized_params = cls._serialize_params_dict(params, DimensionValuesQueryParameters.__optional_keys__)
        serialized_params = cls._serialize_params_dict(params, list(DimensionValuesQueryParameters.__optional_keys__))
        return f"SELECT * FROM {{{{ semantic_layer.dimension_values({serialized_params}) }}}}"
67 changes: 44 additions & 23 deletions dbtsl/api/graphql/protocol.py
@@ -5,7 +5,12 @@
from typing_extensions import NotRequired, override

from dbtsl.api.graphql.util import render_query
from dbtsl.api.shared.query_params import QueryParameters, validate_query_parameters
from dbtsl.api.shared.query_params import (
    AdhocQueryParametersStrict,
    OrderByMetric,
    QueryParameters,
    validate_query_parameters,
)
from dbtsl.models import Dimension, Entity, Measure, Metric
from dbtsl.models.query import QueryId, QueryResult, QueryStatus
from dbtsl.models.saved_query import SavedQuery
@@ -192,6 +197,42 @@ def parse_response(self, data: Dict[str, Any]) -> List[SavedQuery]:
        return decode_to_dataclass(data["savedQueries"], List[SavedQuery])


def get_query_request_variables(environment_id: int, params: QueryParameters) -> Dict[str, Any]:
    """Get the GraphQL request variables for a given set of query parameters."""
    strict_params = validate_query_parameters(params)  # type: ignore

    shared_vars = {
        "environmentId": environment_id,
        "where": [{"sql": sql} for sql in strict_params.where] if strict_params.where is not None else None,
        "orderBy": [
            {"metric": {"name": clause.name}, "descending": clause.descending}
            if isinstance(clause, OrderByMetric)
            else {"groupBy": {"name": clause.name, "grain": clause.grain}, "descending": clause.descending}
            for clause in strict_params.order_by
        ]
        if strict_params.order_by is not None
        else None,
        "limit": strict_params.limit,
        "readCache": strict_params.read_cache,
    }

    if isinstance(strict_params, AdhocQueryParametersStrict):
        return {
            "savedQuery": None,
            "metrics": [{"name": m} for m in strict_params.metrics],
            "groupBy": [{"name": g} for g in strict_params.group_by] if strict_params.group_by is not None else None,
            **shared_vars,
        }

    return {
        "environmentId": environment_id,
        "savedQuery": strict_params.saved_query,
        "metrics": None,
        "groupBy": None,
        **shared_vars,
    }


class CreateQueryOperation(ProtocolOperation[QueryParameters, QueryId]):
    """Create a query that will be processed asynchronously."""

@@ -227,17 +268,7 @@ def get_request_text(self) -> str:
    @override
    def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]:
        # TODO: fix typing
        validate_query_parameters(kwargs)  # type: ignore
        return {
            "environmentId": environment_id,
            "savedQuery": kwargs.get("saved_query", None),
            "metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None,
            "groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None,
            "where": [{"sql": sql} for sql in kwargs.get("where", [])],
            "orderBy": [{"name": o} for o in kwargs.get("order_by", [])],
            "limit": kwargs.get("limit", None),
            "readCache": kwargs.get("read_cache", True),
        }
        return get_query_request_variables(environment_id, kwargs)  # type: ignore

    @override
    def parse_response(self, data: Dict[str, Any]) -> QueryId:
@@ -317,17 +348,7 @@ def get_request_text(self) -> str:
    @override
    def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]:
        # TODO: fix typing
        validate_query_parameters(kwargs)  # type: ignore
        return {
            "environmentId": environment_id,
            "savedQuery": kwargs.get("saved_query", None),
            "metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None,
            "groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None,
            "where": [{"sql": sql} for sql in kwargs.get("where", [])],
            "orderBy": [{"name": o} for o in kwargs.get("order_by", [])],
            "limit": kwargs.get("limit", None),
            "readCache": kwargs.get("read_cache", True),
        }
        return get_query_request_variables(environment_id, kwargs)  # type: ignore

    @override
    def parse_response(self, data: Dict[str, Any]) -> str:
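A hedged sketch of what the new get_query_request_variables helper above would return for a hypothetical ad-hoc query; the names and the commented dict are illustrative, and it assumes read_cache defaults to True on the strict params.

from dbtsl.api.graphql.protocol import get_query_request_variables

# "revenue" and "customer__region" are made-up names for illustration only.
variables = get_query_request_variables(
    environment_id=123,
    params={"metrics": ["revenue"], "group_by": ["customer__region"], "limit": 5},
)
# Approximately:
# {
#     "savedQuery": None,
#     "metrics": [{"name": "revenue"}],
#     "groupBy": [{"name": "customer__region"}],
#     "environmentId": 123,
#     "where": None,
#     "orderBy": None,
#     "limit": 5,
#     "readCache": True,
# }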
61 changes: 46 additions & 15 deletions tests/api/adbc/test_protocol.py
@@ -1,20 +1,48 @@
from dbtsl.api.adbc.protocol import ADBCProtocol
from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters
from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric
from dbtsl.models.time import TimeGranularity


def test_serialize_query_params_simple_query() -> None:
    params = ADBCProtocol._serialize_params_dict({"metrics": ["a", "b"]}, QueryParameters.__optional_keys__)
def test_serialize_val_basic_values() -> None:
    assert ADBCProtocol._serialize_val(1) == "1"
    assert ADBCProtocol._serialize_val("a") == '"a"'
    assert ADBCProtocol._serialize_val(True) == "True"
    assert ADBCProtocol._serialize_val(False) == "False"
    assert ADBCProtocol._serialize_val(["a", "b"]) == '["a","b"]'

    expected = 'metrics=["a", "b"]'
    assert params == expected

def test_serialize_val_OrderByMetric() -> None:
    assert ADBCProtocol._serialize_val(OrderByMetric(name="m", descending=False)) == 'Metric("m")'
    assert ADBCProtocol._serialize_val(OrderByMetric(name="m", descending=True)) == 'Metric("m").descending(True)'

def test_serialize_query_params_dimensions_query() -> None:
    params = ADBCProtocol._serialize_params_dict(
        {"metrics": ["a", "b"], "group_by": "c"}, DimensionValuesQueryParameters.__optional_keys__

def test_serialize_val_OrderByGroupBy() -> None:
    assert ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=None, descending=False)) == 'Dimension("m")'
    assert (
        ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=None, descending=True))
        == 'Dimension("m").descending(True)'
    )
    assert (
        ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=TimeGranularity.DAY, descending=False))
        == 'Dimension("m").grain("day")'
    )
    assert (
        ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=TimeGranularity.WEEK, descending=True))
        == 'Dimension("m").grain("week").descending(True)'
    )


def test_serialize_query_params_metrics() -> None:
    params = ADBCProtocol._serialize_params_dict({"metrics": ["a", "b"]}, ["metrics"])

    expected = 'group_by="c",metrics=["a", "b"]'
    expected = 'metrics=["a","b"]'
    assert params == expected


def test_serialize_query_params_metrics_group_by() -> None:
    params = ADBCProtocol._serialize_params_dict({"metrics": ["a", "b"], "group_by": "c"}, ["metrics", "group_by"])

    expected = 'group_by="c",metrics=["a","b"]'
    assert params == expected


@@ -24,27 +52,30 @@ def test_serialize_query_params_complete_query() -> None:
"metrics": ["a", "b"],
"group_by": ["dim_c"],
"limit": 1,
"order_by": ["dim_c"],
"order_by": [OrderByMetric(name="a"), OrderByGroupBy(name="dim_c", grain=None)],
"where": ['{{ Dimension("metric_time").grain("month") }} >= \'2017-03-09\''],
"read_cache": False,
},
QueryParameters.__optional_keys__,
["metrics", "group_by", "limit", "order_by", "where", "read_cache"],
)

expected = (
'group_by=["dim_c"],limit=1,metrics=["a", "b"],order_by=["dim_c"],read_cache=False,'
'group_by=["dim_c"],limit=1,metrics=["a","b"],order_by=[Metric("a"),Dimension("dim_c")],read_cache=False,'
'where=["{{ Dimension(\\"metric_time\\").grain(\\"month\\") }} >= \'2017-03-09\'"]'
)
assert params == expected


def test_get_query_sql_simple_query() -> None:
sql = ADBCProtocol.get_query_sql(params={"metrics": ["a", "b"]})
expected = 'SELECT * FROM {{ semantic_layer.query(metrics=["a", "b"]) }}'
sql = ADBCProtocol.get_query_sql(params={"metrics": ["a", "b"], "order_by": ["-a"]})
expected = (
'SELECT * FROM {{ semantic_layer.query(metrics=["a","b"],'
'order_by=[Metric("a").descending(True)],read_cache=True) }}'
)
assert sql == expected


def test_get_query_sql_dimension_values_query() -> None:
sql = ADBCProtocol.get_dimension_values_sql(params={"metrics": ["a", "b"]})
expected = 'SELECT * FROM {{ semantic_layer.dimension_values(metrics=["a", "b"]) }}'
expected = 'SELECT * FROM {{ semantic_layer.dimension_values(metrics=["a","b"]) }}'
assert sql == expected
