Skip to content

Commit

Permalink
feat: add compile_sql via GraphQL (#44)
Browse files Browse the repository at this point in the history
* feat: add `compile_sql` via GraphQL

This commit adds a new `compile_sql` method to the SDK, which uses the
GraphQL API to generate the compiled SQL given a set of query parameters.

* test: add integration test for `compile_sql`

Added an integration test for `compile_sql` that ensures we get data
back from the API for a valid query which contains a `SELECT` statement.

* docs: add example for `compile_sql`

This commit adds a usage example for `compile_sql`.

* docs: add changelog entry
  • Loading branch information
serramatutu authored Sep 20, 2024
1 parent 13ebfe5 commit 39b6f91
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 3 deletions.
3 changes: 3 additions & 0 deletions .changes/unreleased/Features-20240920-180550.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Features
body: '`compile_sql` method for getting the compiled SQL of a query'
time: 2024-09-20T18:05:50.976574+02:00
8 changes: 7 additions & 1 deletion dbtsl/api/graphql/client/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,10 @@ class AsyncGraphQLClient:
"""Get a list of all available saved queries."""
...

async def compile_sql(self, **params: Unpack[QueryParameters]) -> str:
    """Get the compiled SQL that would be sent to the warehouse by a query.

    Accepts the same keyword parameters as `query` and returns the SQL
    as a string.
    """
    ...

async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer.

    Accepts the keyword parameters described by `QueryParameters` and
    returns the result as a `pa.Table`.
    """
    ...
8 changes: 7 additions & 1 deletion dbtsl/api/graphql/client/sync.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,10 @@ class SyncGraphQLClient:
"""Get a list of all available saved queries."""
...

def compile_sql(self, **params: Unpack[QueryParameters]) -> str:
    """Get the compiled SQL that would be sent to the warehouse by a query.

    Accepts the same keyword parameters as `query` and returns the SQL
    as a string.
    """
    ...

def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer.

    Accepts the keyword parameters described by `QueryParameters` and
    returns the result as a `pa.Table`.
    """
    ...
48 changes: 48 additions & 0 deletions dbtsl/api/graphql/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,53 @@ def parse_response(self, data: Dict[str, Any]) -> QueryResult:
return decode_to_dataclass(data["query"], QueryResult)


class CompileSqlOperation(ProtocolOperation[QueryParameters, str]):
    """Operation returning the compiled SQL a query would send to the warehouse.

    Issues the `compileSql` GraphQL mutation and extracts the `sql` field
    from its result.
    """

    @override
    def get_request_text(self) -> str:
        """Return the GraphQL document for the `compileSql` mutation."""
        # GraphQL treats indentation as insignificant whitespace, so the
        # flush-left layout of the literal below does not affect the request.
        return """
mutation compileSql(
$environmentId: BigInt!,
$metrics: [MetricInput!]!,
$groupBy: [GroupByInput!]!,
$where: [WhereInput!]!,
$orderBy: [OrderByInput!]!,
$limit: Int,
$readCache: Boolean,
) {
compileSql(
environmentId: $environmentId,
metrics: $metrics,
groupBy: $groupBy,
where: $where,
orderBy: $orderBy,
limit: $limit,
readCache: $readCache,
) {
sql
}
}
"""

    @override
    def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]:
        """Translate SDK query parameters into the mutation's variables.

        List parameters that were not supplied become empty lists, `limit`
        falls back to `None` and `read_cache` falls back to `True`.
        """
        metric_inputs = [{"name": metric} for metric in kwargs.get("metrics", [])]
        group_by_inputs = [{"name": group} for group in kwargs.get("group_by", [])]
        where_inputs = [{"sql": clause} for clause in kwargs.get("where", [])]
        order_by_inputs = [{"name": order} for order in kwargs.get("order_by", [])]

        variables: Dict[str, Any] = {
            "environmentId": environment_id,
            "metrics": metric_inputs,
            "groupBy": group_by_inputs,
            "where": where_inputs,
            "orderBy": order_by_inputs,
            "limit": kwargs.get("limit"),
            "readCache": kwargs.get("read_cache", True),
        }
        return variables

    @override
    def parse_response(self, data: Dict[str, Any]) -> str:
        """Return the `sql` string nested under `compileSql` in the response data."""
        compile_result = data["compileSql"]
        return compile_result["sql"]


class GraphQLProtocol:
"""Holds the GraphQL implementation for each of method in the API.
Expand All @@ -291,3 +338,4 @@ class GraphQLProtocol:
saved_queries = ListSavedQueriesOperation()
create_query = CreateQueryOperation()
get_query_result = GetQueryResultOperation()
compile_sql = CompileSqlOperation()
4 changes: 4 additions & 0 deletions dbtsl/client/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ class AsyncSemanticLayerClient:
auth_token: str,
host: str,
) -> None: ...
async def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
    """Get the compiled SQL that would be sent to the warehouse by a query.

    Takes the same keyword parameters as `query` and returns the SQL as a
    string.
    """
    ...

async def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer for metric data.

    Takes the keyword parameters described by `QueryParameters` and returns
    the result as a `pa.Table`.
    """
    ...
Expand Down
3 changes: 2 additions & 1 deletion dbtsl/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ class BaseSemanticLayerClient(ABC, Generic[TGQLClient, TADBCClient]):
"""

# Associates each client method name with either GRAPHQL or ADBC
# (presumably the API used to serve that method -- confirm against the
# code that reads this map).
# Keys are unique and kept in alphabetical order; a repeated key in a dict
# literal is silently overwritten by the last occurrence.
_METHOD_MAP = {
    "compile_sql": GRAPHQL,
    "dimension_values": ADBC,
    "dimensions": GRAPHQL,
    "entities": GRAPHQL,
    "measures": GRAPHQL,
    "metrics": GRAPHQL,
    "query": ADBC,
    "saved_queries": GRAPHQL,
}

Expand Down
4 changes: 4 additions & 0 deletions dbtsl/client/sync.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ class SyncSemanticLayerClient:
auth_token: str,
host: str,
) -> None: ...
def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
    """Get the compiled SQL that would be sent to the warehouse by a query.

    Takes the same keyword parameters as `query` and returns the SQL as a
    string.
    """
    ...

def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer for metric data.

    Takes the keyword parameters described by `QueryParameters` and returns
    the result as a `pa.Table`.
    """
    ...
Expand Down
41 changes: 41 additions & 0 deletions examples/compile_query_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Compile a query and display the SQL."""

from argparse import ArgumentParser

from dbtsl import SemanticLayerClient


def get_arg_parser() -> ArgumentParser:
    """Build the command-line parser for this example.

    The parser takes two positional arguments, `metric` and `group_by`, and
    three required options: `--env-id` (parsed as an int), `--token` and
    `--host`.
    """
    parser = ArgumentParser()

    positionals = (
        ("metric", "The metric to fetch"),
        ("group_by", "A dimension to group by"),
    )
    for name, description in positionals:
        parser.add_argument(name, help=description)

    # Each entry: flag, help text, and any extra add_argument keywords.
    required_options = (
        ("--env-id", "The dbt environment ID", {"type": int}),
        ("--token", "The API auth token", {}),
        ("--host", "The API host", {}),
    )
    for flag, description, extra in required_options:
        parser.add_argument(flag, required=True, help=description, **extra)

    return parser


def main() -> None:
    """Parse the CLI arguments, compile the requested query and print its SQL."""
    args = get_arg_parser().parse_args()

    connection = {
        "environment_id": args.env_id,
        "auth_token": args.token,
        "host": args.host,
    }
    client = SemanticLayerClient(**connection)

    # compile_sql is called while the client session is open.
    with client.session():
        compiled = client.compile_sql(metrics=[args.metric], group_by=[args.group_by], limit=15)
        print(f"Compiled SQL for {args.metric} grouped by {args.group_by}, limit 15:")
        print(compiled)


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
15 changes: 15 additions & 0 deletions tests/integration/test_sl_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,18 @@ async def test_client_query_works(api: str, client: BothClients) -> None:
)
)
assert len(table) > 0


async def test_client_compile_sql_works(client: BothClients) -> None:
    """Check that `compile_sql` returns non-empty SQL containing `SELECT`.

    The query uses the first metric reported by the client, grouped by that
    metric's first dimension, with a limit of 1.
    """
    available = await maybe_await(client.metrics())
    assert len(available) > 0

    metric_name = available[0].name
    dimension_name = available[0].dimensions[0].name
    pending = client.compile_sql(metrics=[metric_name], group_by=[dimension_name], limit=1)
    compiled = await maybe_await(pending)

    assert len(compiled) > 0
    assert "SELECT" in compiled

0 comments on commit 39b6f91

Please sign in to comment.