Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: saved queries in query and compile_sql #45

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changes/unreleased/Features-20240920-201139.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Features
body: Allow saved queries in `query` and `compile_sql`
time: 2024-09-20T20:11:39.216931+02:00
3 changes: 3 additions & 0 deletions .changes/unreleased/Under the Hood-20240920-201151.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Under the Hood
body: Client-side validation of query parameters
time: 2024-09-20T20:11:51.575942+02:00
6 changes: 2 additions & 4 deletions dbtsl/api/adbc/protocol.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import json
from typing import Any, FrozenSet, Mapping

from dbtsl.api.shared.query_params import (
DimensionValuesQueryParameters,
QueryParameters,
)
from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters, validate_query_parameters


class ADBCProtocol:
Expand Down Expand Up @@ -36,6 +33,7 @@ def append_param_if_exists(p_str: str, p_name: str) -> str:
@classmethod
def get_query_sql(cls, params: QueryParameters) -> str:
    """Build the SQL statement sent to the server over Arrow Flight for a query.

    The parameters are validated before anything is serialized, so an
    invalid combination fails here instead of producing a malformed
    statement.
    """
    validate_query_parameters(params)
    # Every key of QueryParameters is optional (the TypedDict is total=False),
    # so the optional-key set tells the serializer which keys to look for.
    optional_keys = QueryParameters.__optional_keys__
    serialized_params = cls._serialize_params_dict(params, optional_keys)
    return f"SELECT * FROM {{{{ semantic_layer.query({serialized_params}) }}}}"

Expand Down
44 changes: 42 additions & 2 deletions dbtsl/api/graphql/client/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from contextlib import AbstractAsyncContextManager
from typing import List, Optional, Self

import pyarrow as pa
from typing_extensions import AsyncIterator, Unpack
from typing_extensions import AsyncIterator, Unpack, overload

from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.models import (
Expand Down Expand Up @@ -44,10 +44,50 @@ class AsyncGraphQLClient:
"""Get a list of all available saved queries."""
...

async def compile_sql(self, **params: Unpack[QueryParameters]) -> str:
@overload
async def compile_sql(
self,
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
@overload
async def compile_sql(
self,
saved_query: str,
group_by: Optional[List[str]] = None,
serramatutu marked this conversation as resolved.
Show resolved Hide resolved
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
async def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
"""Get the compiled SQL that would be sent to the warehouse by a query."""
...

@overload
async def query(
self,
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
@overload
async def query(
self,
saved_query: str,
group_by: Optional[List[str]] = None,
serramatutu marked this conversation as resolved.
Show resolved Hide resolved
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
"""Query the Semantic Layer."""
...
44 changes: 41 additions & 3 deletions dbtsl/api/graphql/client/sync.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from contextlib import AbstractContextManager
from typing import Iterator, List, Optional

import pyarrow as pa
from typing_extensions import Self, Unpack
from typing_extensions import Self, Unpack, overload

from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.models import (
Expand Down Expand Up @@ -44,10 +44,48 @@ class SyncGraphQLClient:
"""Get a list of all available saved queries."""
...

def compile_sql(self, **params: Unpack[QueryParameters]) -> str:
@overload
def compile_sql(
    self,
    metrics: List[str],
    group_by: Optional[List[str]] = None,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> str:
    """Compile the SQL of an ad-hoc query described by `metrics`."""

@overload
def compile_sql(
    self,
    saved_query: str,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> str:
    """Compile the SQL of the saved query named `saved_query`."""

def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
    """Return the compiled SQL that a query would send to the warehouse."""
    ...

def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
# Ad-hoc form: the caller lists the metrics (and optionally group_by) inline.
@overload
def query(
    self,
    metrics: List[str],
    group_by: Optional[List[str]] = None,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> "pa.Table": ...
# Saved-query form: metrics and group_by come from the saved query itself.
@overload
def query(
    self,
    saved_query: str,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> "pa.Table": ...
# This is the synchronous client, so the implementation signature must be a
# plain `def` like its overloads (an `async def` here would type the call as
# returning a coroutine instead of a table).
def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer.

    Accepts either an ad-hoc query (`metrics`/`group_by`) or a `saved_query`,
    as described by the overloads above.
    """
    ...
28 changes: 19 additions & 9 deletions dbtsl/api/graphql/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import NotRequired, override

from dbtsl.api.graphql.util import render_query
from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.api.shared.query_params import QueryParameters, validate_query_parameters
from dbtsl.models import Dimension, Entity, Measure, Metric
from dbtsl.models.query import QueryId, QueryResult, QueryStatus
from dbtsl.models.saved_query import SavedQuery
Expand Down Expand Up @@ -200,15 +200,17 @@ def get_request_text(self) -> str:
query = """
mutation createQuery(
$environmentId: BigInt!,
$metrics: [MetricInput!]!,
$groupBy: [GroupByInput!]!,
$savedQuery: String,
$metrics: [MetricInput!],
$groupBy: [GroupByInput!],
$where: [WhereInput!]!,
$orderBy: [OrderByInput!]!,
$limit: Int,
$readCache: Boolean,
) {
createQuery(
environmentId: $environmentId,
savedQuery: $savedQuery,
metrics: $metrics,
groupBy: $groupBy,
where: $where,
Expand All @@ -224,10 +226,13 @@ def get_request_text(self) -> str:

@override
def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]:
# TODO: fix typing
validate_query_parameters(kwargs) # type: ignore
return {
"environmentId": environment_id,
"metrics": [{"name": m} for m in kwargs.get("metrics", [])],
"groupBy": [{"name": g} for g in kwargs.get("group_by", [])],
"savedQuery": kwargs.get("saved_query", None),
"metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None,
"groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None,
"where": [{"sql": sql} for sql in kwargs.get("where", [])],
"orderBy": [{"name": o} for o in kwargs.get("order_by", [])],
"limit": kwargs.get("limit", None),
Expand Down Expand Up @@ -285,15 +290,17 @@ def get_request_text(self) -> str:
query = """
mutation compileSql(
$environmentId: BigInt!,
$metrics: [MetricInput!]!,
$groupBy: [GroupByInput!]!,
$savedQuery: String,
$metrics: [MetricInput!],
$groupBy: [GroupByInput!],
$where: [WhereInput!]!,
$orderBy: [OrderByInput!]!,
$limit: Int,
$readCache: Boolean,
) {
compileSql(
environmentId: $environmentId,
savedQuery: $savedQuery,
metrics: $metrics,
groupBy: $groupBy,
where: $where,
Expand All @@ -309,10 +316,13 @@ def get_request_text(self) -> str:

@override
def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]:
# TODO: fix typing
validate_query_parameters(kwargs) # type: ignore
return {
"environmentId": environment_id,
"metrics": [{"name": m} for m in kwargs.get("metrics", [])],
"groupBy": [{"name": g} for g in kwargs.get("group_by", [])],
"savedQuery": kwargs.get("saved_query", None),
"metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None,
"groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None,
"where": [{"sql": sql} for sql in kwargs.get("where", [])],
"orderBy": [{"name": o} for o in kwargs.get("order_by", [])],
"limit": kwargs.get("limit", None),
Expand Down
24 changes: 23 additions & 1 deletion dbtsl/api/shared/query_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@


class QueryParameters(TypedDict, total=False):
"""The parameters of `semantic_layer.query`."""
"""The parameters of `semantic_layer.query`.

metrics/group_by and saved_query are mutually exclusive.
"""

saved_query: str
metrics: List[str]
group_by: List[str]
limit: int
Expand All @@ -12,6 +16,24 @@ class QueryParameters(TypedDict, total=False):
read_cache: bool


def validate_query_parameters(params: QueryParameters) -> None:
    """Validate a dict that should be QueryParameters.

    A query is either a saved query (`saved_query`) or an ad-hoc query
    (`metrics` and/or `group_by`): never both, and never neither.

    Args:
        params: The query parameters to check. Only the presence of the
            `saved_query`, `metrics` and `group_by` keys and the length of
            `metrics`/`group_by` are inspected.

    Raises:
        ValueError: If `saved_query` is combined with `metrics`/`group_by`,
            if none of the three is given, or if `metrics`/`group_by` is
            present but empty.
    """
    is_saved_query = "saved_query" in params
    is_adhoc_query = "metrics" in params or "group_by" in params
    if is_saved_query and is_adhoc_query:
        raise ValueError(
            "metrics/group_by and saved_query are mutually exclusive, "
            "since, by definition, saved queries already include "
            "metrics and group_by."
        )

    # With none of the three keys there is nothing to query: fail here with a
    # clear message instead of sending a request that selects nothing.
    if not is_saved_query and not is_adhoc_query:
        raise ValueError("You need to specify either saved_query or metrics/group_by.")

    if "metrics" in params and len(params["metrics"]) == 0:
        raise ValueError("You need to specify at least one metric.")

    if "group_by" in params and len(params["group_by"]) == 0:
        raise ValueError("You need to specify at least one dimension to group by.")


class DimensionValuesQueryParameters(TypedDict, total=False):
"""The parameters of `semantic_layer.dimension_values`."""

Expand Down
48 changes: 44 additions & 4 deletions dbtsl/client/asyncio.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from contextlib import AbstractAsyncContextManager
from typing import AsyncIterator, List
from typing import AsyncIterator, List, Optional

import pyarrow as pa
from typing_extensions import Self, Unpack
from typing_extensions import Self, Unpack, overload

from dbtsl.api.adbc.protocol import QueryParameters
from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery
Expand All @@ -14,12 +14,52 @@ class AsyncSemanticLayerClient:
auth_token: str,
host: str,
) -> None: ...
@overload
async def compile_sql(
self,
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
@overload
async def compile_sql(
self,
saved_query: str,
group_by: Optional[List[str]] = None,
serramatutu marked this conversation as resolved.
Show resolved Hide resolved
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
async def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
"""Get the compiled SQL that would be sent to the warehouse by a query."""
...

async def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table":
"""Query the Semantic Layer for a metric data."""
@overload
async def query(
self,
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
@overload
async def query(
self,
saved_query: str,
group_by: Optional[List[str]] = None,
serramatutu marked this conversation as resolved.
Show resolved Hide resolved
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
"""Query the Semantic Layer."""
...

async def metrics(self) -> List[Metric]:
Expand Down
46 changes: 42 additions & 4 deletions dbtsl/client/sync.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from contextlib import AbstractContextManager
from typing import Iterator, List
from typing import Iterator, List, Optional

import pyarrow as pa
from typing_extensions import Self, Unpack
from typing_extensions import Self, Unpack, overload

from dbtsl.api.adbc.protocol import QueryParameters
from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery
Expand All @@ -14,12 +14,50 @@ class SyncSemanticLayerClient:
auth_token: str,
host: str,
) -> None: ...
@overload
def compile_sql(
    self,
    metrics: List[str],
    group_by: Optional[List[str]] = None,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> str:
    """Compile the SQL of an ad-hoc query described by `metrics`."""

@overload
def compile_sql(
    self,
    saved_query: str,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> str:
    """Compile the SQL of the saved query named `saved_query`."""

def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
    """Return the compiled SQL that a query would send to the warehouse."""
    ...

def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table":
"""Query the Semantic Layer for a metric data."""
# Ad-hoc form: the caller lists the metrics (and optionally group_by) inline.
@overload
def query(
    self,
    metrics: List[str],
    group_by: Optional[List[str]] = None,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> "pa.Table": ...
# Saved-query form: metrics and group_by come from the saved query itself.
@overload
def query(
    self,
    saved_query: str,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> "pa.Table": ...
# This is the synchronous client, so the implementation signature must be a
# plain `def` like its overloads (an `async def` here would type the call as
# returning a coroutine instead of a table).
def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer.

    Accepts either an ad-hoc query (`metrics`/`group_by`) or a `saved_query`,
    as described by the overloads above.
    """
    ...

def metrics(self) -> List[Metric]:
Expand Down
Loading