Skip to content

Commit

Permalink
fix: order_by for compile_sql now works as expected (#49)
Browse files Browse the repository at this point in the history
* test: all parameters in `query` and `compile_sql`

The tests for `query` and `compile_sql` did not test all allowed
parameters, which made the `order_by` bug go unnoticed. This commit
fixes that by adding all the remaining parameters to the tests.

Note that this commit is in a broken state since the fix hasn't been
applied yet. That will come in a future patch.

* refactor: better internal repr of query params

This commit improves our validation and representation of query
parameters, and fixes the bug with `order_by`.

We're still in an inconsistent state: we still need to propagate the changes
and use them in other classes. That will be done in the following patch.

* refactor: make protocols use the new validations

This commit makes the ADBC and GraphQL protocol implementations use the
new stricter representation for query params. It updates the tests
accordingly.

* refactor: typing public interfaces with order by

This commit updates the public `.pyi` files to ensure we use the new
order-by spec.

* docs: changelog

Added changelog entries related to the order by changes.
  • Loading branch information
serramatutu authored Oct 16, 2024
1 parent 992eb90 commit 9896a1c
Show file tree
Hide file tree
Showing 15 changed files with 380 additions and 93 deletions.
3 changes: 3 additions & 0 deletions .changes/unreleased/Breaking Changes-20241001-160946.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Breaking Changes
body: 'The `order_by` clause of queries using saved queries no longer supports string inputs and requires explicit `OrderByMetric` or `OrderByGroupBy`'
time: 2024-10-01T16:09:46.752151+02:00
3 changes: 3 additions & 0 deletions .changes/unreleased/Features-20241001-155522.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Features
body: Specifying order by with `OrderByMetric` and `OrderByGroupBy`
time: 2024-10-01T15:55:22.041799+02:00
3 changes: 3 additions & 0 deletions .changes/unreleased/Fixes-20241001-155448.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Fixes
body: Order by for `compile_sql` now works as expected
time: 2024-10-01T15:54:48.740022+02:00
3 changes: 3 additions & 0 deletions .changes/unreleased/Under the Hood-20241001-155544.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Under the Hood
body: Improved internal representation of query parameters and added better validation
time: 2024-10-01T15:55:44.855697+02:00
4 changes: 3 additions & 1 deletion dbtsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ def err_factory(*args, **kwargs) -> None: # noqa: D103

SemanticLayerClient = err_factory

__all__ = ["SemanticLayerClient"]
from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric

__all__ = ["SemanticLayerClient", "OrderByMetric", "OrderByGroupBy"]
57 changes: 45 additions & 12 deletions dbtsl/api/adbc/protocol.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,55 @@
import dataclasses
import json
from typing import Any, FrozenSet, Mapping
from typing import Any, List, Mapping

from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters, validate_query_parameters
from dbtsl.api.shared.query_params import (
DimensionValuesQueryParameters,
OrderByGroupBy,
OrderByMetric,
QueryParameters,
validate_query_parameters,
)


class ADBCProtocol:
"""The protocol for the Arrow Flight dataframe API."""

@staticmethod
def _serialize_params_dict(params: Mapping[str, Any], param_names: FrozenSet[str]) -> str:
@classmethod
def _serialize_val(cls, val: Any) -> str:
if isinstance(val, bool):
return str(val)

if isinstance(val, list):
list_str = ",".join(cls._serialize_val(list_val) for list_val in val)
return f"[{list_str}]"

if isinstance(val, OrderByMetric):
m = f'Metric("{val.name}")'
if val.descending:
m += ".descending(True)"
return m

if isinstance(val, OrderByGroupBy):
d = f'Dimension("{val.name}")'
if val.grain:
grain_str = val.grain.name.lower()
d += f'.grain("{grain_str}")'
if val.descending:
d += ".descending(True)"
return d

return json.dumps(val)

@classmethod
def _serialize_params_dict(cls, params: Mapping[str, Any], param_names: List[str]) -> str:
param_names_sorted = list(param_names)
param_names_sorted.sort()

def append_param_if_exists(p_str: str, p_name: str) -> str:
p_value = params.get(p_name)
if p_value is not None:
if isinstance(p_value, bool):
dumped = str(p_value)
else:
dumped = json.dumps(p_value)
p_str += f"{p_name}={dumped},"
serialized = cls._serialize_val(p_value)
p_str += f"{p_name}={serialized},"
return p_str

serialized_params = ""
Expand All @@ -33,12 +63,15 @@ def append_param_if_exists(p_str: str, p_name: str) -> str:
@classmethod
def get_query_sql(cls, params: QueryParameters) -> str:
"""Get the SQL that will be sent via Arrow Flight to the server based on query parameters."""
validate_query_parameters(params)
serialized_params = cls._serialize_params_dict(params, QueryParameters.__optional_keys__)
strict_params = validate_query_parameters(params)
params_fields = [f.name for f in dataclasses.fields(strict_params)]
strict_params_dict = {field: getattr(strict_params, field) for field in params_fields}

serialized_params = cls._serialize_params_dict(strict_params_dict, params_fields)
return f"SELECT * FROM {{{{ semantic_layer.query({serialized_params}) }}}}"

@classmethod
def get_dimension_values_sql(cls, params: DimensionValuesQueryParameters) -> str:
"""Get the SQL that will be sent via Arrow Flight to the server based on dimension values query parameters."""
serialized_params = cls._serialize_params_dict(params, DimensionValuesQueryParameters.__optional_keys__)
serialized_params = cls._serialize_params_dict(params, list(DimensionValuesQueryParameters.__optional_keys__))
return f"SELECT * FROM {{{{ semantic_layer.dimension_values({serialized_params}) }}}}"
12 changes: 6 additions & 6 deletions dbtsl/api/graphql/client/asyncio.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from contextlib import AbstractAsyncContextManager
from typing import List, Optional, Self
from typing import List, Optional, Self, Union

import pyarrow as pa
from typing_extensions import AsyncIterator, Unpack, overload

from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric, QueryParameters
from dbtsl.models import (
Dimension,
Entity,
Expand Down Expand Up @@ -50,7 +50,7 @@ class AsyncGraphQLClient:
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
Expand All @@ -59,7 +59,7 @@ class AsyncGraphQLClient:
self,
saved_query: str,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
Expand All @@ -73,7 +73,7 @@ class AsyncGraphQLClient:
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
Expand All @@ -82,7 +82,7 @@ class AsyncGraphQLClient:
self,
saved_query: str,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
Expand Down
12 changes: 6 additions & 6 deletions dbtsl/api/graphql/client/sync.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from contextlib import AbstractContextManager
from typing import Iterator, List, Optional
from typing import Iterator, List, Optional, Union

import pyarrow as pa
from typing_extensions import Self, Unpack, overload

from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric, QueryParameters
from dbtsl.models import (
Dimension,
Entity,
Expand Down Expand Up @@ -50,7 +50,7 @@ class SyncGraphQLClient:
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
Expand All @@ -59,7 +59,7 @@ class SyncGraphQLClient:
self,
saved_query: str,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
Expand All @@ -73,7 +73,7 @@ class SyncGraphQLClient:
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
order_by: Optional[List[Union[str, OrderByGroupBy, OrderByMetric]]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
Expand All @@ -82,7 +82,7 @@ class SyncGraphQLClient:
self,
saved_query: str,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
order_by: Optional[List[Union[OrderByGroupBy, OrderByMetric]]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
Expand Down
67 changes: 44 additions & 23 deletions dbtsl/api/graphql/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from typing_extensions import NotRequired, override

from dbtsl.api.graphql.util import render_query
from dbtsl.api.shared.query_params import QueryParameters, validate_query_parameters
from dbtsl.api.shared.query_params import (
AdhocQueryParametersStrict,
OrderByMetric,
QueryParameters,
validate_query_parameters,
)
from dbtsl.models import Dimension, Entity, Measure, Metric
from dbtsl.models.query import QueryId, QueryResult, QueryStatus
from dbtsl.models.saved_query import SavedQuery
Expand Down Expand Up @@ -192,6 +197,42 @@ def parse_response(self, data: Dict[str, Any]) -> List[SavedQuery]:
return decode_to_dataclass(data["savedQueries"], List[SavedQuery])


def get_query_request_variables(environment_id: int, params: QueryParameters) -> Dict[str, Any]:
    """Get the GraphQL request variables for a given set of query parameters.

    Normalizes the loose `params` mapping via `validate_query_parameters`, then
    builds the variables dict. Order-by clauses are rendered as `metric` or
    `groupBy` objects depending on their type, which is what fixed the
    `order_by` bug for both adhoc and saved queries.
    """
    strict_params = validate_query_parameters(params)  # type: ignore

    # Variables common to both adhoc and saved-query requests.
    shared_vars = {
        "environmentId": environment_id,
        "where": [{"sql": sql} for sql in strict_params.where] if strict_params.where is not None else None,
        "orderBy": [
            {"metric": {"name": clause.name}, "descending": clause.descending}
            if isinstance(clause, OrderByMetric)
            else {"groupBy": {"name": clause.name, "grain": clause.grain}, "descending": clause.descending}
            for clause in strict_params.order_by
        ]
        if strict_params.order_by is not None
        else None,
        "limit": strict_params.limit,
        "readCache": strict_params.read_cache,
    }

    if isinstance(strict_params, AdhocQueryParametersStrict):
        return {
            "savedQuery": None,
            "metrics": [{"name": m} for m in strict_params.metrics],
            "groupBy": [{"name": g} for g in strict_params.group_by] if strict_params.group_by is not None else None,
            **shared_vars,
        }

    # "environmentId" already comes from shared_vars; no need to repeat it here.
    return {
        "savedQuery": strict_params.saved_query,
        "metrics": None,
        "groupBy": None,
        **shared_vars,
    }


class CreateQueryOperation(ProtocolOperation[QueryParameters, QueryId]):
"""Create a query that will be processed asynchronously."""

Expand Down Expand Up @@ -227,17 +268,7 @@ def get_request_text(self) -> str:
@override
def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]:
# TODO: fix typing
validate_query_parameters(kwargs) # type: ignore
return {
"environmentId": environment_id,
"savedQuery": kwargs.get("saved_query", None),
"metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None,
"groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None,
"where": [{"sql": sql} for sql in kwargs.get("where", [])],
"orderBy": [{"name": o} for o in kwargs.get("order_by", [])],
"limit": kwargs.get("limit", None),
"readCache": kwargs.get("read_cache", True),
}
return get_query_request_variables(environment_id, kwargs) # type: ignore

@override
def parse_response(self, data: Dict[str, Any]) -> QueryId:
Expand Down Expand Up @@ -317,17 +348,7 @@ def get_request_text(self) -> str:
@override
def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]:
# TODO: fix typing
validate_query_parameters(kwargs) # type: ignore
return {
"environmentId": environment_id,
"savedQuery": kwargs.get("saved_query", None),
"metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None,
"groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None,
"where": [{"sql": sql} for sql in kwargs.get("where", [])],
"orderBy": [{"name": o} for o in kwargs.get("order_by", [])],
"limit": kwargs.get("limit", None),
"readCache": kwargs.get("read_cache", True),
}
return get_query_request_variables(environment_id, kwargs) # type: ignore

@override
def parse_response(self, data: Dict[str, Any]) -> str:
Expand Down
Loading

0 comments on commit 9896a1c

Please sign in to comment.