Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add compile_sql via GraphQL #44

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changes/unreleased/Features-20240920-180550.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Features
body: '`compile_sql` method for getting the compiled SQL of a query'
time: 2024-09-20T18:05:50.976574+02:00
8 changes: 7 additions & 1 deletion dbtsl/api/graphql/client/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,10 @@ class AsyncGraphQLClient:
"""Get a list of all available saved queries."""
...

async def compile_sql(self, **params: Unpack[QueryParameters]) -> str:
    """Get the compiled SQL that would be sent to the warehouse by a query.

    Takes the same keyword parameters as `query` and returns the SQL as a string.
    """
    ...

# `query` is declared exactly once; a second declaration with the same
# signature would only shadow the first.
async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer.

    Takes the query keyword parameters and returns the result as a `pa.Table`.
    """
    ...
8 changes: 7 additions & 1 deletion dbtsl/api/graphql/client/sync.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,10 @@ class SyncGraphQLClient:
"""Get a list of all available saved queries."""
...

def compile_sql(self, **params: Unpack[QueryParameters]) -> str:
    """Get the compiled SQL that would be sent to the warehouse by a query.

    Takes the same keyword parameters as `query` and returns the SQL as a string.
    """
    ...

# `query` is declared exactly once; a second declaration with the same
# signature would only shadow the first.
def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer.

    Takes the query keyword parameters and returns the result as a `pa.Table`.
    """
    ...
48 changes: 48 additions & 0 deletions dbtsl/api/graphql/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,53 @@ def parse_response(self, data: Dict[str, Any]) -> QueryResult:
return decode_to_dataclass(data["query"], QueryResult)


class CompileSqlOperation(ProtocolOperation[QueryParameters, str]):
    """Operation that asks the API for the SQL a query would compile to.

    The request is the `compileSql` GraphQL mutation, and the parsed result is
    the text found under `compileSql.sql` in the response data.
    """

    @override
    def get_request_text(self) -> str:
        """Return the GraphQL document for the `compileSql` mutation."""
        return """
mutation compileSql(
$environmentId: BigInt!,
$metrics: [MetricInput!]!,
$groupBy: [GroupByInput!]!,
$where: [WhereInput!]!,
$orderBy: [OrderByInput!]!,
$limit: Int,
$readCache: Boolean,
) {
compileSql(
environmentId: $environmentId,
metrics: $metrics,
groupBy: $groupBy,
where: $where,
orderBy: $orderBy,
limit: $limit,
readCache: $readCache,
) {
sql
}
}
"""

    @override
    def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]:
        """Build the mutation variables from the environment ID and query parameters.

        List parameters that were not supplied become empty lists, a missing
        `limit` is sent as `None`, and `read_cache` defaults to `True`.
        """
        # Each list parameter is wrapped into the input-object shape the
        # mutation declares ({"name": ...} or {"sql": ...}).
        metric_inputs = [{"name": metric} for metric in kwargs.get("metrics", [])]
        group_by_inputs = [{"name": group} for group in kwargs.get("group_by", [])]
        where_inputs = [{"sql": clause} for clause in kwargs.get("where", [])]
        order_by_inputs = [{"name": order} for order in kwargs.get("order_by", [])]

        variables: Dict[str, Any] = {"environmentId": environment_id}
        variables["metrics"] = metric_inputs
        variables["groupBy"] = group_by_inputs
        variables["where"] = where_inputs
        variables["orderBy"] = order_by_inputs
        variables["limit"] = kwargs.get("limit")
        variables["readCache"] = kwargs.get("read_cache", True)
        return variables

    @override
    def parse_response(self, data: Dict[str, Any]) -> str:
        """Extract the SQL text from the `compileSql` entry of the response data."""
        compiled = data["compileSql"]
        return compiled["sql"]


class GraphQLProtocol:
"""Holds the GraphQL implementation for each of method in the API.

Expand All @@ -291,3 +338,4 @@ class GraphQLProtocol:
saved_queries = ListSavedQueriesOperation()
create_query = CreateQueryOperation()
get_query_result = GetQueryResultOperation()
compile_sql = CompileSqlOperation()
4 changes: 4 additions & 0 deletions dbtsl/client/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ class AsyncSemanticLayerClient:
auth_token: str,
host: str,
) -> None: ...
async def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
    """Get the compiled SQL that would be sent to the warehouse by a query.

    Takes the same keyword parameters as `query` and returns the SQL as a string.
    """
    ...

async def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer for metric data.

    Takes the query keyword parameters and returns the result as a `pa.Table`.
    """
    ...
Expand Down
3 changes: 2 additions & 1 deletion dbtsl/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ class BaseSemanticLayerClient(ABC, Generic[TGQLClient, TADBCClient]):
"""

# Associates each client method name with the protocol constant (GRAPHQL or
# ADBC) listed for it. Every method appears exactly once, in alphabetical
# order; a repeated key in a dict literal is silently collapsed to the last
# value, which hides mistakes.
_METHOD_MAP = {
    "compile_sql": GRAPHQL,
    "dimension_values": ADBC,
    "dimensions": GRAPHQL,
    "entities": GRAPHQL,
    "measures": GRAPHQL,
    "metrics": GRAPHQL,
    "query": ADBC,
    "saved_queries": GRAPHQL,
}

Expand Down
4 changes: 4 additions & 0 deletions dbtsl/client/sync.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ class SyncSemanticLayerClient:
auth_token: str,
host: str,
) -> None: ...
def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
    """Get the compiled SQL that would be sent to the warehouse by a query.

    Takes the same keyword parameters as `query` and returns the SQL as a string.
    """
    ...

def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer for metric data.

    Takes the query keyword parameters and returns the result as a `pa.Table`.
    """
    ...
Expand Down
41 changes: 41 additions & 0 deletions examples/compile_query_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Compile a query and display the SQL."""

from argparse import ArgumentParser

from dbtsl import SemanticLayerClient


def get_arg_parser() -> ArgumentParser:
    """Create the command-line parser used by this example.

    The parser takes two positionals (`metric`, `group_by`) and three required
    options (`--env-id`, parsed as an int, plus `--token` and `--host`).
    """
    parser = ArgumentParser()

    # Positionals that describe the query to compile.
    positionals = (
        ("metric", "The metric to fetch"),
        ("group_by", "A dimension to group by"),
    )
    for dest, description in positionals:
        parser.add_argument(dest, help=description)

    # Connection settings; all of them are mandatory.
    parser.add_argument("--env-id", type=int, required=True, help="The dbt environment ID")
    for flag, description in (("--token", "The API auth token"), ("--host", "The API host")):
        parser.add_argument(flag, required=True, help=description)

    return parser


def main() -> None:
    """Compile a single-metric query and print the resulting SQL."""
    args = get_arg_parser().parse_args()

    client = SemanticLayerClient(environment_id=args.env_id, auth_token=args.token, host=args.host)

    # The compile request and the output both happen while the session is open.
    with client.session():
        compiled_sql = client.compile_sql(metrics=[args.metric], group_by=[args.group_by], limit=15)
        print(f"Compiled SQL for {args.metric} grouped by {args.group_by}, limit 15:")
        print(compiled_sql)


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
15 changes: 15 additions & 0 deletions tests/integration/test_sl_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,18 @@ async def test_client_query_works(api: str, client: BothClients) -> None:
)
)
assert len(table) > 0


async def test_client_compile_sql_works(client: BothClients) -> None:
    """Check that compile_sql returns non-empty text containing SELECT."""
    available = await maybe_await(client.metrics())
    assert len(available) > 0

    # Compile a minimal query: the first metric grouped by its first dimension.
    first_metric = available[0]
    pending = client.compile_sql(
        metrics=[first_metric.name],
        group_by=[first_metric.dimensions[0].name],
        limit=1,
    )
    compiled = await maybe_await(pending)

    assert len(compiled) > 0
    assert "SELECT" in compiled