Skip to content

Commit

Permalink
fix: serialization issues in GraphQL orderBy and where (#57)
Browse files Browse the repository at this point in the history
* fix: serialization issues in GraphQL `orderBy` and `where`

This commit fixes some serialization issues that were showing up with
`orderBy` and `where` for the GraphQL protocol. This was affecting
`compile_sql()` mainly since queries go via ADBC.

To avoid regressions, I added checks for variable values in the GraphQL
tests. It will try to serialize those variable values and bind them to
the query using `graphql` utilities. If the query is not valid (i.e. the
query + variables do not respect the server schema), it will fail. This
requires no connection to the server other than fetching the schema.

I also added integration tests that will connect to the server. These
will only run after the unit tests, though, so these issues should be
caught before the integration tests run.

* docs: add changelog entry

* fixup! fix: serialization issues in GraphQL `orderBy` and `where`

* fixup! fix: serialization issues in GraphQL `orderBy` and `where`
  • Loading branch information
serramatutu authored Nov 12, 2024
1 parent c6c0866 commit 6ed8f0b
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 86 deletions.
3 changes: 3 additions & 0 deletions .changes/unreleased/Fixes-20241111-182559.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Fixes
body: Fixes issue with `orderBy` and `where` when going via GraphQL
time: 2024-11-11T18:25:59.874757+01:00
10 changes: 5 additions & 5 deletions dbtsl/api/graphql/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def get_query_request_variables(environment_id: int, params: QueryParameters) ->
"orderBy": [
{"metric": {"name": clause.name}, "descending": clause.descending}
if isinstance(clause, OrderByMetric)
else {"groupBy": {"name": clause.name, "grain": clause.grain}, "descending": clause.descending}
else {"groupBy": {"name": clause.name, "timeGranularity": clause.grain}, "descending": clause.descending}
for clause in strict_params.order_by
]
if strict_params.order_by is not None
Expand Down Expand Up @@ -244,8 +244,8 @@ def get_request_text(self) -> str:
$savedQuery: String,
$metrics: [MetricInput!],
$groupBy: [GroupByInput!],
$where: [WhereInput!]!,
$orderBy: [OrderByInput!]!,
$where: [WhereInput!],
$orderBy: [OrderByInput!],
$limit: Int,
$readCache: Boolean,
) {
Expand Down Expand Up @@ -324,8 +324,8 @@ def get_request_text(self) -> str:
$savedQuery: String,
$metrics: [MetricInput!],
$groupBy: [GroupByInput!],
$where: [WhereInput!]!,
$orderBy: [OrderByInput!]!,
$where: [WhereInput!],
$orderBy: [OrderByInput!],
$limit: Int,
$readCache: Boolean,
) {
Expand Down
54 changes: 42 additions & 12 deletions tests/api/graphql/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,52 @@
from pytest_subtests import SubTests
from typing import Any, Dict, List, Tuple

import pytest

from dbtsl.api.graphql.protocol import GraphQLProtocol

from ...conftest import QueryValidator
from ...query_test_cases import TEST_QUERIES

VARIABLES = {
"metrics": [{}],
"dimensions": [{"metrics": ["m"]}],
"measures": [{"metrics": ["m"]}],
"entities": [{"metrics": ["m"]}],
"saved_queries": [{}],
"get_query_result": [{"query_id": 1}],
"create_query": TEST_QUERIES,
"compile_sql": TEST_QUERIES,
}

TestCase = Tuple[str, Dict[str, Any]]
TEST_CASES: List[TestCase] = []

for op_name in dir(GraphQLProtocol):
if op_name.startswith("__"):
continue

tested_vars = VARIABLES.get(op_name)
assert tested_vars is not None, f"No test vars to use for testing GraphQLProtocol.{op_name}"

for variables in tested_vars:
TEST_CASES.append((op_name, variables))

def test_queries_are_valid(subtests: SubTests, validate_query: QueryValidator) -> None:

def get_test_id(test_case: TestCase) -> str:
return test_case[0]


@pytest.mark.parametrize("test_case", TEST_CASES, ids=get_test_id)
def test_queries_are_valid(test_case: TestCase, validate_query: QueryValidator) -> None:
"""Test all GraphQL queries in `GraphQLProtocol` are valid against the server schema.
This test dynamically iterates over `GraphQLProtocol` so whenever a new method is
added it will get tested automatically.
added it will get tested automatically. The test will fail if there's no entry in VARIABLES
for it.
"""
prop_names = dir(GraphQLProtocol)
for prop_name in prop_names:
if prop_name.startswith("__"):
continue

prop_val = getattr(GraphQLProtocol, prop_name)
with subtests.test(msg=f"GraphQLProtocol.{prop_name}"):
query = prop_val.get_request_text()
validate_query(query)
op_name, raw_variables = test_case

op = getattr(GraphQLProtocol, op_name)
query = op.get_request_text()
variable_values = op.get_request_variables(environment_id=123, **raw_variables)
validate_query(query, variable_values)
16 changes: 12 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from dataclasses import dataclass
from typing import Callable, Union, cast
from typing import Any, Callable, Dict, Union, cast

import pytest
from gql import Client, gql
from gql.utilities.serialize_variable_values import serialize_variable_values


def pytest_addoption(parser: pytest.Parser) -> None:
Expand All @@ -29,17 +30,24 @@ def server_schema(server_schema_path: str) -> str:
return schema_str


QueryValidator = Callable[[str], None]
# Signature of the callable returned by the `validate_query` fixture:
# it takes the GraphQL query text and a mapping of variable values, and returns
# nothing (a failed validation surfaces as an exception rather than a return value).
# The variable values are arbitrary Python objects, so the mapping's value type
# must be `Any` -- `Dict[str, None]` would only admit dicts whose values are all None
# and would disagree with the inner `validator(query_str, variables: Dict[str, Any])`.
QueryValidator = Callable[[str, Dict[str, Any]], None]


@pytest.fixture(scope="session")
def validate_query(server_schema: str) -> QueryValidator:
"""Returns a validator function which ensures the query is valid against the server schema."""
"""Returns a validator function which ensures the query and its variables are valid against the server schema."""
gql_client = Client(schema=server_schema)

def validator(query_str: str) -> None:
def validator(query_str: str, variables: Dict[str, Any]) -> None:
assert gql_client.schema is not None

query_doc = gql(query_str)
gql_client.validate(document=query_doc)
serialize_variable_values(
gql_client.schema,
query_doc,
variables,
)

return validator

Expand Down
74 changes: 10 additions & 64 deletions tests/integration/test_sl_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import pytest
from pytest_subtests import SubTests

from dbtsl import OrderByGroupBy
from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.client.asyncio import AsyncSemanticLayerClient
from dbtsl.client.base import ADBC, GRAPHQL
from dbtsl.client.sync import SyncSemanticLayerClient

from ..conftest import Credentials
from ..query_test_cases import TEST_QUERIES
from ..util import maybe_await

BothClients = Union[SyncSemanticLayerClient, AsyncSemanticLayerClient]
Expand Down Expand Up @@ -79,72 +80,17 @@ async def test_client_metadata(subtests: SubTests, client: BothClients) -> None:


@pytest.mark.parametrize("api", [ADBC, GRAPHQL])
async def test_client_query_adhoc(api: str, client: BothClients) -> None:
@pytest.mark.parametrize("query", TEST_QUERIES)
async def test_client_query(api: str, query: QueryParameters, client: BothClients) -> None:
client._method_map["query"] = api # type: ignore

metrics = await maybe_await(client.metrics())
assert len(metrics) > 0
table = await maybe_await(
client.query(
metrics=[metrics[0].name],
group_by=["metric_time"],
order_by=["metric_time"],
where=["1=1"],
limit=1,
read_cache=True,
)
)
assert len(table) > 0


@pytest.mark.parametrize("api", [ADBC, GRAPHQL])
async def test_client_query_saved_query(api: str, client: BothClients) -> None:
client._method_map["query"] = api # type: ignore

metrics = await maybe_await(client.metrics())
assert len(metrics) > 0
table = await maybe_await(
client.query(
saved_query="order_metrics",
order_by=[OrderByGroupBy(name="metric_time", grain=None)],
where=["1=1"],
limit=1,
read_cache=True,
)
)
# TODO: fix typing on client.query
table = await maybe_await(client.query(**query)) # type: ignore
assert len(table) > 0


async def test_client_compile_sql_adhoc_query(client: BothClients) -> None:
metrics = await maybe_await(client.metrics())
assert len(metrics) > 0

sql = await maybe_await(
client.compile_sql(
metrics=[metrics[0].name],
group_by=[metrics[0].dimensions[0].name],
order_by=[metrics[0].dimensions[0].name],
where=["1=1"],
limit=1,
read_cache=True,
)
)
assert len(sql) > 0
assert "SELECT" in sql


async def test_client_compile_sql_saved_query(client: BothClients) -> None:
metrics = await maybe_await(client.metrics())
assert len(metrics) > 0

sql = await maybe_await(
client.compile_sql(
saved_query="order_metrics",
order_by=[OrderByGroupBy(name="metric_time", grain=None)],
where=["1=1"],
limit=1,
read_cache=True,
)
)
@pytest.mark.parametrize("query", TEST_QUERIES)
async def test_client_compile_sql_adhoc_query(query: QueryParameters, client: BothClients) -> None:
# TODO: fix typing on client.compile_sql
sql = await maybe_await(client.compile_sql(**query)) # type: ignore
assert len(sql) > 0
assert "SELECT" in sql
37 changes: 37 additions & 0 deletions tests/query_test_cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import List

from dbtsl import OrderByGroupBy
from dbtsl.api.shared.query_params import QueryParameters

# Shared query-parameter sets reused across the test suite.
# Each entry is a `QueryParameters` dict that tests unpack as keyword arguments
# (e.g. `client.query(**query)`, `client.compile_sql(**query)` and
# `op.get_request_variables(environment_id=..., **raw_variables)`), so the same
# cases exercise both the GraphQL variable serialization checks and the
# integration tests.
TEST_QUERIES: List[QueryParameters] = [
# ad hoc query, all parameters
{
"metrics": ["order_total"],
"group_by": ["customer__customer_type"],
"order_by": ["customer__customer_type"],
"where": ["1=1"],
"limit": 1,
"read_cache": True,
},
# ad hoc query, only metric
{
"metrics": ["order_total"],
},
# ad hoc query, metric and group by
{
"metrics": ["order_total"],
"group_by": ["customer__customer_type"],
},
# saved query, all parameters
{
"saved_query": "order_metrics",
"order_by": [OrderByGroupBy(name="metric_time", grain="day")],
"where": ["1=1"],
"limit": 1,
"read_cache": True,
},
# saved query, no parameters
{
"saved_query": "order_metrics",
},
]
11 changes: 10 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,22 @@ def test_validate_query_params_adhoc_query_no_metrics() -> None:
validate_query_parameters(p)


def test_validate_query_params_saved_query_group_by() -> None:
p: QueryParameters = {
"saved_query": "sq",
"group_by": ["a", "b"],
}
with pytest.raises(ValueError):
validate_query_parameters(p)


def test_validate_query_params_adhoc_and_saved_query() -> None:
p: QueryParameters = {"metrics": ["a", "b"], "group_by": ["a", "b"], "saved_query": "a"}
with pytest.raises(ValueError):
validate_query_parameters(p)


def test_validate_query_params_no_query() -> None:
p: QueryParameters = {"limit": 1, "where": ["1=1"], "order_by": ["a"], "read_cache": False}
p: QueryParameters = {"group_by": ["gb"], "limit": 1, "where": ["1=1"], "order_by": ["a"], "read_cache": False}
with pytest.raises(ValueError):
validate_query_parameters(p)

0 comments on commit 6ed8f0b

Please sign in to comment.