diff --git a/.changes/unreleased/Features-20240920-201139.yaml b/.changes/unreleased/Features-20240920-201139.yaml new file mode 100644 index 0000000..265a176 --- /dev/null +++ b/.changes/unreleased/Features-20240920-201139.yaml @@ -0,0 +1,3 @@ +kind: Features +body: Allow saved queries in `query` and `compile_sql` +time: 2024-09-20T20:11:39.216931+02:00 diff --git a/.changes/unreleased/Under the Hood-20240920-201151.yaml b/.changes/unreleased/Under the Hood-20240920-201151.yaml new file mode 100644 index 0000000..7bc5432 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240920-201151.yaml @@ -0,0 +1,3 @@ +kind: Under the Hood +body: Client-side validation of query parameters +time: 2024-09-20T20:11:51.575942+02:00 diff --git a/dbtsl/api/adbc/protocol.py b/dbtsl/api/adbc/protocol.py index 33dbcde..0a89280 100644 --- a/dbtsl/api/adbc/protocol.py +++ b/dbtsl/api/adbc/protocol.py @@ -1,10 +1,7 @@ import json from typing import Any, FrozenSet, Mapping -from dbtsl.api.shared.query_params import ( - DimensionValuesQueryParameters, - QueryParameters, -) +from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters, validate_query_parameters class ADBCProtocol: @@ -36,6 +33,7 @@ def append_param_if_exists(p_str: str, p_name: str) -> str: @classmethod def get_query_sql(cls, params: QueryParameters) -> str: """Get the SQL that will be sent via Arrow Flight to the server based on query parameters.""" + validate_query_parameters(params) serialized_params = cls._serialize_params_dict(params, QueryParameters.__optional_keys__) return f"SELECT * FROM {{{{ semantic_layer.query({serialized_params}) }}}}" diff --git a/dbtsl/api/graphql/client/asyncio.pyi b/dbtsl/api/graphql/client/asyncio.pyi index ab9b2e4..36af1a3 100644 --- a/dbtsl/api/graphql/client/asyncio.pyi +++ b/dbtsl/api/graphql/client/asyncio.pyi @@ -2,7 +2,7 @@ from contextlib import AbstractAsyncContextManager from typing import List, Optional, Self import pyarrow as pa -from typing_extensions import AsyncIterator, Unpack +from typing_extensions import AsyncIterator, Unpack, overload from dbtsl.api.shared.query_params import QueryParameters from dbtsl.models import ( @@ -44,10 +44,48 @@ class AsyncGraphQLClient: """Get a list of all available saved queries.""" ... - async def compile_sql(self, **params: Unpack[QueryParameters]) -> str: + @overload + async def compile_sql( + self, + metrics: List[str], + group_by: Optional[List[str]] = None, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> str: ... + @overload + async def compile_sql( + self, + saved_query: str, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> str: ... + async def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str: """Get the compiled SQL that would be sent to the warehouse by a query.""" ... + @overload + async def query( + self, + metrics: List[str], + group_by: Optional[List[str]] = None, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> "pa.Table": ... + @overload + async def query( + self, + saved_query: str, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> "pa.Table": ... 
async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": """Query the Semantic Layer.""" ... diff --git a/dbtsl/api/graphql/client/sync.pyi b/dbtsl/api/graphql/client/sync.pyi index 062470c..b050d62 100644 --- a/dbtsl/api/graphql/client/sync.pyi +++ b/dbtsl/api/graphql/client/sync.pyi @@ -2,7 +2,7 @@ from contextlib import AbstractContextManager from typing import Iterator, List, Optional import pyarrow as pa -from typing_extensions import Self, Unpack +from typing_extensions import Self, Unpack, overload from dbtsl.api.shared.query_params import QueryParameters from dbtsl.models import ( @@ -44,10 +44,48 @@ class SyncGraphQLClient: """Get a list of all available saved queries.""" ... - def compile_sql(self, **params: Unpack[QueryParameters]) -> str: + @overload + def compile_sql( + self, + metrics: List[str], + group_by: Optional[List[str]] = None, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> str: ... + @overload + def compile_sql( + self, + saved_query: str, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> str: ... + def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str: """Get the compiled SQL that would be sent to the warehouse by a query.""" ... - def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table": + @overload + def query( + self, + metrics: List[str], + group_by: Optional[List[str]] = None, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> "pa.Table": ... + @overload + def query( + self, + saved_query: str, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> "pa.Table": ... + def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": """Query the Semantic Layer.""" ...
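Note: below is a minimal usage sketch of the two call shapes these overloads are meant to accept, written against the sync GraphQL client. It assumes an already-configured client (construction and credentials are out of scope here); the metric, dimension, and saved-query names are placeholders.

    # Sketch only: exercises both call shapes on an already-configured client.
    import pyarrow as pa

    from dbtsl.api.graphql.client.sync import SyncGraphQLClient


    def compile_and_query(client: SyncGraphQLClient) -> "pa.Table":
        with client.session():
            # Ad-hoc shape: metrics drive the request; group_by and the rest are optional.
            sql = client.compile_sql(
                metrics=["order_total"],   # placeholder metric name
                group_by=["metric_time"],  # placeholder dimension name
                limit=10,
            )
            print(sql)
            # Saved-query shape: saved_query replaces metrics/group_by entirely.
            return client.query(saved_query="order_metrics", limit=10)
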
diff --git a/dbtsl/api/graphql/protocol.py b/dbtsl/api/graphql/protocol.py index a7fea2c..417fc6f 100644 --- a/dbtsl/api/graphql/protocol.py +++ b/dbtsl/api/graphql/protocol.py @@ -5,7 +5,7 @@ from typing_extensions import NotRequired, override from dbtsl.api.graphql.util import render_query -from dbtsl.api.shared.query_params import QueryParameters +from dbtsl.api.shared.query_params import QueryParameters, validate_query_parameters from dbtsl.models import Dimension, Entity, Measure, Metric from dbtsl.models.query import QueryId, QueryResult, QueryStatus from dbtsl.models.saved_query import SavedQuery @@ -200,8 +200,9 @@ def get_request_text(self) -> str: query = """ mutation createQuery( $environmentId: BigInt!, - $metrics: [MetricInput!]!, - $groupBy: [GroupByInput!]!, + $savedQuery: String, + $metrics: [MetricInput!], + $groupBy: [GroupByInput!], $where: [WhereInput!]!, $orderBy: [OrderByInput!]!, $limit: Int, @@ -209,6 +210,7 @@ def get_request_text(self) -> str: ) { createQuery( environmentId: $environmentId, + savedQuery: $savedQuery, metrics: $metrics, groupBy: $groupBy, where: $where, @@ -224,10 +226,13 @@ def get_request_text(self) -> str: @override def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]: + # TODO: fix typing + validate_query_parameters(kwargs) # type: ignore return { "environmentId": environment_id, - "metrics": [{"name": m} for m in kwargs.get("metrics", [])], - "groupBy": [{"name": g} for g in kwargs.get("group_by", [])], + "savedQuery": kwargs.get("saved_query", None), + "metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None, + "groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None, "where": [{"sql": sql} for sql in kwargs.get("where", [])], "orderBy": [{"name": o} for o in kwargs.get("order_by", [])], "limit": kwargs.get("limit", None), @@ -285,8 +290,9 @@ def get_request_text(self) -> str: query = """ mutation compileSql( $environmentId: BigInt!, - $metrics: [MetricInput!]!, - $groupBy: [GroupByInput!]!, + $savedQuery: String, + $metrics: [MetricInput!], + $groupBy: [GroupByInput!], $where: [WhereInput!]!, $orderBy: [OrderByInput!]!, $limit: Int, @@ -294,6 +300,7 @@ def get_request_text(self) -> str: ) { compileSql( environmentId: $environmentId, + savedQuery: $savedQuery, metrics: $metrics, groupBy: $groupBy, where: $where, @@ -309,10 +316,13 @@ def get_request_text(self) -> str: @override def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]: + # TODO: fix typing + validate_query_parameters(kwargs) # type: ignore return { "environmentId": environment_id, - "metrics": [{"name": m} for m in kwargs.get("metrics", [])], - "groupBy": [{"name": g} for g in kwargs.get("group_by", [])], + "savedQuery": kwargs.get("saved_query", None), + "metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None, + "groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None, "where": [{"sql": sql} for sql in kwargs.get("where", [])], "orderBy": [{"name": o} for o in kwargs.get("order_by", [])], "limit": kwargs.get("limit", None), diff --git a/dbtsl/api/shared/query_params.py b/dbtsl/api/shared/query_params.py index 62c4293..5a02016 100644 --- a/dbtsl/api/shared/query_params.py +++ b/dbtsl/api/shared/query_params.py @@ -2,8 +2,12 @@ class QueryParameters(TypedDict, total=False): - """The parameters of `semantic_layer.query`.""" + """The parameters of `semantic_layer.query`. 
+ metrics/group_by and saved_query are mutually exclusive. + """ + + saved_query: str metrics: List[str] group_by: List[str] limit: int @@ -12,6 +16,24 @@ class QueryParameters(TypedDict, total=False): read_cache: bool +def validate_query_parameters(params: QueryParameters) -> None: + """Validate a dict that should be QueryParameters.""" + is_saved_query = "saved_query" in params + is_adhoc_query = "metrics" in params or "group_by" in params + if is_saved_query and is_adhoc_query: + raise ValueError( + "metrics/group_by and saved_query are mutually exclusive, " + "since, by definition, saved queries already include " + "metrics and group_by." + ) + + if "metrics" in params and len(params["metrics"]) == 0: + raise ValueError("You need to specify at least one metric.") + + if "group_by" in params and len(params["group_by"]) == 0: + raise ValueError("You need to specify at least one dimension to group by.") + + class DimensionValuesQueryParameters(TypedDict, total=False): """The parameters of `semantic_layer.dimension_values`.""" diff --git a/dbtsl/client/asyncio.pyi b/dbtsl/client/asyncio.pyi index 330a0cc..cbe9293 100644 --- a/dbtsl/client/asyncio.pyi +++ b/dbtsl/client/asyncio.pyi @@ -1,8 +1,8 @@ from contextlib import AbstractAsyncContextManager -from typing import AsyncIterator, List +from typing import AsyncIterator, List, Optional import pyarrow as pa -from typing_extensions import Self, Unpack +from typing_extensions import Self, Unpack, overload from dbtsl.api.adbc.protocol import QueryParameters from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery @@ -14,12 +14,50 @@ class AsyncSemanticLayerClient: auth_token: str, host: str, ) -> None: ... + @overload + async def compile_sql( + self, + metrics: List[str], + group_by: Optional[List[str]] = None, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> str: ... + @overload + async def compile_sql( + self, + saved_query: str, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> str: ... async def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str: """Get the compiled SQL that would be sent to the warehouse by a query.""" ... - async def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table": - """Query the Semantic Layer for a metric data.""" + @overload + async def query( + self, + metrics: List[str], + group_by: Optional[List[str]] = None, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> "pa.Table": ... + @overload + async def query( + self, + saved_query: str, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> "pa.Table": ... + async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": + """Query the Semantic Layer.""" ... 
async def metrics(self) -> List[Metric]: diff --git a/dbtsl/client/sync.pyi b/dbtsl/client/sync.pyi index ec1c511..bda4e6d 100644 --- a/dbtsl/client/sync.pyi +++ b/dbtsl/client/sync.pyi @@ -1,8 +1,8 @@ from contextlib import AbstractContextManager -from typing import Iterator, List +from typing import Iterator, List, Optional import pyarrow as pa -from typing_extensions import Self, Unpack +from typing_extensions import Self, Unpack, overload from dbtsl.api.adbc.protocol import QueryParameters from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery @@ -14,12 +14,50 @@ class SyncSemanticLayerClient: auth_token: str, host: str, ) -> None: ... + @overload + def compile_sql( + self, + metrics: List[str], + group_by: Optional[List[str]] = None, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> str: ... + @overload + def compile_sql( + self, + saved_query: str, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> str: ... def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str: """Get the compiled SQL that would be sent to the warehouse by a query.""" ... - def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table": - """Query the Semantic Layer for a metric data.""" + @overload + def query( + self, + metrics: List[str], + group_by: Optional[List[str]] = None, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> "pa.Table": ... + @overload + def query( + self, + saved_query: str, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + where: Optional[List[str]] = None, + read_cache: bool = True, + ) -> "pa.Table": ... + def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": + """Query the Semantic Layer.""" ...
def metrics(self) -> List[Metric]: diff --git a/tests/api/graphql/test_client.py b/tests/api/graphql/test_client.py index d3002b4..7be77eb 100644 --- a/tests/api/graphql/test_client.py +++ b/tests/api/graphql/test_client.py @@ -9,7 +9,6 @@ from dbtsl.api.graphql.client.asyncio import AsyncGraphQLClient from dbtsl.api.graphql.client.sync import SyncGraphQLClient from dbtsl.api.graphql.protocol import GraphQLProtocol, ProtocolOperation -from dbtsl.api.shared.query_params import QueryParameters from dbtsl.models.query import QueryId, QueryResult, QueryStatus # The following 2 tests are copies of each other since testing the same sync/async functionality is @@ -58,7 +57,7 @@ async def run_behavior(op: ProtocolOperation, query_id: QueryId, page_num: int) mocker.patch.object(gql_mock, "__aenter__", new_callable=AsyncMock) mocker.patch("dbtsl.api.graphql.client.asyncio.isinstance", return_value=True) - kwargs: QueryParameters = {"metrics": ["m1", "m2"], "group_by": ["gb"], "limit": 1} + kwargs = {"metrics": ["m1", "m2"], "group_by": ["gb"], "limit": 1} async with client.session(): result_table = await client.query(**kwargs) @@ -123,7 +122,7 @@ def run_behavior(op: ProtocolOperation, query_id: QueryId, page_num: int) -> Que mocker.patch.object(gql_mock, "__aenter__") mocker.patch("dbtsl.api.graphql.client.sync.isinstance", return_value=True) - kwargs: QueryParameters = {"metrics": ["m1", "m2"], "group_by": ["gb"], "limit": 1} + kwargs = {"metrics": ["m1", "m2"], "group_by": ["gb"], "limit": 1} with client.session(): result_table = client.query(**kwargs) diff --git a/tests/api/graphql/test_protocol.py b/tests/api/graphql/test_protocol.py index 36ec5a0..cdcfc2d 100644 --- a/tests/api/graphql/test_protocol.py +++ b/tests/api/graphql/test_protocol.py @@ -8,7 +8,7 @@ def test_queries_are_valid(subtests: SubTests, validate_query: QueryValidator) -> None: """Test all GraphQL queries in `GraphQLProtocol` are valid against the server schema. - This test dynamically iterates over `GraphQLProtocol` sowhenever a new method is + This test dynamically iterates over `GraphQLProtocol` so whenever a new method is added it will get tested automatically. 
""" prop_names = dir(GraphQLProtocol) diff --git a/tests/integration/test_sl_client.py b/tests/integration/test_sl_client.py index ce76148..985c641 100644 --- a/tests/integration/test_sl_client.py +++ b/tests/integration/test_sl_client.py @@ -50,7 +50,7 @@ async def client( # NOTE: grouping all these tests in one because they depend on each other, i.e # dimensions depends on metrics etc -async def test_client_works_multiple(subtests: SubTests, client: BothClients) -> None: +async def test_client_metadata(subtests: SubTests, client: BothClients) -> None: with subtests.test("metrics"): metrics = await maybe_await(client.metrics()) assert len(metrics) > 0 @@ -79,7 +79,7 @@ async def test_client_lists_saved_queries(client: BothClients) -> None: @pytest.mark.parametrize("api", [ADBC, GRAPHQL]) -async def test_client_query_works(api: str, client: BothClients) -> None: +async def test_client_query_adhoc(api: str, client: BothClients) -> None: client._method_map["query"] = api # type: ignore metrics = await maybe_await(client.metrics()) @@ -94,7 +94,22 @@ async def test_client_query_works(api: str, client: BothClients) -> None: assert len(table) > 0 -async def test_client_compile_sql_works(client: BothClients) -> None: +@pytest.mark.parametrize("api", [ADBC, GRAPHQL]) +async def test_client_query_saved_query(api: str, client: BothClients) -> None: + client._method_map["query"] = api # type: ignore + + metrics = await maybe_await(client.metrics()) + assert len(metrics) > 0 + table = await maybe_await( + client.query( + saved_query="order_metrics", + limit=1, + ) + ) + assert len(table) > 0 + + +async def test_client_compile_sql_adhoc_query(client: BothClients) -> None: metrics = await maybe_await(client.metrics()) assert len(metrics) > 0 @@ -107,3 +122,17 @@ async def test_client_compile_sql_works(client: BothClients) -> None: ) assert len(sql) > 0 assert "SELECT" in sql + + +async def test_client_compile_sql_saved_query(client: BothClients) -> None: + metrics = await maybe_await(client.metrics()) + assert len(metrics) > 0 + + sql = await maybe_await( + client.compile_sql( + saved_query="order_metrics", + limit=1, + ) + ) + assert len(sql) > 0 + assert "SELECT" in sql diff --git a/tests/test_models.py b/tests/test_models.py index 6aab82b..5f6a492 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,9 +1,11 @@ from dataclasses import dataclass from typing import List +import pytest from mashumaro.codecs.basic import decode from dbtsl.api.graphql.util import normalize_query +from dbtsl.api.shared.query_params import QueryParameters, validate_query_parameters from dbtsl.models.base import BaseModel, GraphQLFragmentMixin from dbtsl.models.base import snake_case_to_camel_case as stc @@ -77,3 +79,40 @@ class B(BaseModel, GraphQLFragmentMixin): assert b_fragment.name == "fragmentB" assert b_fragment.body == b_expect assert b_fragments[1] == a_fragment + + +def test_validate_query_params_adhoc_query_valid() -> None: + p: QueryParameters = { + "metrics": ["a", "b"], + "group_by": ["a", "b"], + } + validate_query_parameters(p) + + +def test_validate_query_params_saved_query_valid() -> None: + p: QueryParameters = {"saved_query": "a"} + validate_query_parameters(p) + + +def test_validate_query_params_adhoc_query_no_metrics() -> None: + p: QueryParameters = { + "metrics": [], + "group_by": ["a", "b"], + } + with pytest.raises(ValueError): + validate_query_parameters(p) + + +def test_validate_query_params_adhoc_query_no_group_by() -> None: + p: QueryParameters = { + "metrics": ["a", "b"], + 
"group_by": [], + } + with pytest.raises(ValueError): + validate_query_parameters(p) + + +def test_validate_query_params_adhoc_and_saved_query() -> None: + p: QueryParameters = {"metrics": ["a", "b"], "group_by": ["a", "b"], "saved_query": "a"} + with pytest.raises(ValueError): + validate_query_parameters(p)