diff --git a/.changes/unreleased/Features-20240920-180550.yaml b/.changes/unreleased/Features-20240920-180550.yaml new file mode 100644 index 0000000..78cc925 --- /dev/null +++ b/.changes/unreleased/Features-20240920-180550.yaml @@ -0,0 +1,3 @@ +kind: Features +body: '`compile_sql` method for getting the compiled SQL of a query' +time: 2024-09-20T18:05:50.976574+02:00 diff --git a/dbtsl/api/graphql/client/asyncio.pyi b/dbtsl/api/graphql/client/asyncio.pyi index 37464e0..ab9b2e4 100644 --- a/dbtsl/api/graphql/client/asyncio.pyi +++ b/dbtsl/api/graphql/client/asyncio.pyi @@ -44,4 +44,10 @@ class AsyncGraphQLClient: """Get a list of all available saved queries.""" ... - async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": ... + async def compile_sql(self, **params: Unpack[QueryParameters]) -> str: + """Get the compiled SQL that would be sent to the warehouse by a query.""" + ... + + async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": + """Query the Semantic Layer.""" + ... diff --git a/dbtsl/api/graphql/client/sync.pyi b/dbtsl/api/graphql/client/sync.pyi index f37bb14..062470c 100644 --- a/dbtsl/api/graphql/client/sync.pyi +++ b/dbtsl/api/graphql/client/sync.pyi @@ -44,4 +44,10 @@ class SyncGraphQLClient: """Get a list of all available saved queries.""" ... - def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": ... + def compile_sql(self, **params: Unpack[QueryParameters]) -> str: + """Get the compiled SQL that would be sent to the warehouse by a query.""" + ... + + def query(self, **params: Unpack[QueryParameters]) -> "pa.Table": + """Query the Semantic Layer.""" + ... diff --git a/dbtsl/api/graphql/protocol.py b/dbtsl/api/graphql/protocol.py index 2a56408..a7fea2c 100644 --- a/dbtsl/api/graphql/protocol.py +++ b/dbtsl/api/graphql/protocol.py @@ -277,6 +277,53 @@ def parse_response(self, data: Dict[str, Any]) -> QueryResult: return decode_to_dataclass(data["query"], QueryResult) +class CompileSqlOperation(ProtocolOperation[QueryParameters, str]): + """Get the compiled SQL that would be sent to the warehouse by a query.""" + + @override + def get_request_text(self) -> str: + query = """ + mutation compileSql( + $environmentId: BigInt!, + $metrics: [MetricInput!]!, + $groupBy: [GroupByInput!]!, + $where: [WhereInput!]!, + $orderBy: [OrderByInput!]!, + $limit: Int, + $readCache: Boolean, + ) { + compileSql( + environmentId: $environmentId, + metrics: $metrics, + groupBy: $groupBy, + where: $where, + orderBy: $orderBy, + limit: $limit, + readCache: $readCache, + ) { + sql + } + } + """ + return query + + @override + def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]: + return { + "environmentId": environment_id, + "metrics": [{"name": m} for m in kwargs.get("metrics", [])], + "groupBy": [{"name": g} for g in kwargs.get("group_by", [])], + "where": [{"sql": sql} for sql in kwargs.get("where", [])], + "orderBy": [{"name": o} for o in kwargs.get("order_by", [])], + "limit": kwargs.get("limit", None), + "readCache": kwargs.get("read_cache", True), + } + + @override + def parse_response(self, data: Dict[str, Any]) -> str: + return data["compileSql"]["sql"] + + class GraphQLProtocol: """Holds the GraphQL implementation for each of method in the API. @@ -291,3 +338,4 @@ class GraphQLProtocol: saved_queries = ListSavedQueriesOperation() create_query = CreateQueryOperation() get_query_result = GetQueryResultOperation() + compile_sql = CompileSqlOperation() diff --git a/dbtsl/client/asyncio.pyi b/dbtsl/client/asyncio.pyi index 720fde1..330a0cc 100644 --- a/dbtsl/client/asyncio.pyi +++ b/dbtsl/client/asyncio.pyi @@ -14,6 +14,10 @@ class AsyncSemanticLayerClient: auth_token: str, host: str, ) -> None: ... + async def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str: + """Get the compiled SQL that would be sent to the warehouse by a query.""" + ... + async def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table": """Query the Semantic Layer for a metric data.""" ... diff --git a/dbtsl/client/base.py b/dbtsl/client/base.py index 49d1e19..b10b97e 100644 --- a/dbtsl/client/base.py +++ b/dbtsl/client/base.py @@ -21,12 +21,13 @@ class BaseSemanticLayerClient(ABC, Generic[TGQLClient, TADBCClient]): """ _METHOD_MAP = { + "compile_sql": GRAPHQL, "dimension_values": ADBC, - "query": ADBC, "dimensions": GRAPHQL, "entities": GRAPHQL, "measures": GRAPHQL, "metrics": GRAPHQL, + "query": ADBC, "saved_queries": GRAPHQL, } diff --git a/dbtsl/client/sync.pyi b/dbtsl/client/sync.pyi index dc0cebf..ec1c511 100644 --- a/dbtsl/client/sync.pyi +++ b/dbtsl/client/sync.pyi @@ -14,6 +14,10 @@ class SyncSemanticLayerClient: auth_token: str, host: str, ) -> None: ... + def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str: + """Get the compiled SQL that would be sent to the warehouse by a query.""" + ... + def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table": """Query the Semantic Layer for a metric data.""" ... diff --git a/examples/compile_query_sync.py b/examples/compile_query_sync.py new file mode 100644 index 0000000..048bfdb --- /dev/null +++ b/examples/compile_query_sync.py @@ -0,0 +1,41 @@ +"""Compile a query and display the SQL.""" + +from argparse import ArgumentParser + +from dbtsl import SemanticLayerClient + + +def get_arg_parser() -> ArgumentParser: + p = ArgumentParser() + + p.add_argument("metric", help="The metric to fetch") + p.add_argument("group_by", help="A dimension to group by") + p.add_argument("--env-id", required=True, help="The dbt environment ID", type=int) + p.add_argument("--token", required=True, help="The API auth token") + p.add_argument("--host", required=True, help="The API host") + + return p + + +def main() -> None: + arg_parser = get_arg_parser() + args = arg_parser.parse_args() + + client = SemanticLayerClient( + environment_id=args.env_id, + auth_token=args.token, + host=args.host, + ) + + with client.session(): + sql = client.compile_sql( + metrics=[args.metric], + group_by=[args.group_by], + limit=15, + ) + print(f"Compiled SQL for {args.metric} grouped by {args.group_by}, limit 15:") + print(sql) + + +if __name__ == "__main__": + main() diff --git a/tests/integration/test_sl_client.py b/tests/integration/test_sl_client.py index 78dcb05..ce76148 100644 --- a/tests/integration/test_sl_client.py +++ b/tests/integration/test_sl_client.py @@ -92,3 +92,18 @@ async def test_client_query_works(api: str, client: BothClients) -> None: ) ) assert len(table) > 0 + + +async def test_client_compile_sql_works(client: BothClients) -> None: + metrics = await maybe_await(client.metrics()) + assert len(metrics) > 0 + + sql = await maybe_await( + client.compile_sql( + metrics=[metrics[0].name], + group_by=[metrics[0].dimensions[0].name], + limit=1, + ) + ) + assert len(sql) > 0 + assert "SELECT" in sql