diff --git a/src/preset_cli/cli/superset/sync/dbt/command.py b/src/preset_cli/cli/superset/sync/dbt/command.py index f9b7a1e8..3f1c8253 100644 --- a/src/preset_cli/cli/superset/sync/dbt/command.py +++ b/src/preset_cli/cli/superset/sync/dbt/command.py @@ -3,10 +3,11 @@ """ import os.path +import subprocess import sys import warnings from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import click import yaml @@ -17,6 +18,7 @@ JobSchema, MetricSchema, MFMetricWithSQLSchema, + MFSQLEngine, ModelSchema, ) from preset_cli.api.clients.superset import SupersetClient @@ -137,6 +139,9 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-branches, too-many if profiles is None: profiles = os.path.expanduser("~/.dbt/profiles.yml") + with open(profiles, encoding="utf-8") as inp: + config = yaml.safe_load(inp) + dialect = MFSQLEngine(config[project]["outputs"][target]["type"].upper()) file_path = Path(file) @@ -200,14 +205,18 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-branches, too-many ] else: og_metrics = [] + sl_metrics = [] metric_schema = MetricSchema() for config in configs["metrics"].values(): - # conform to the same schema that dbt Cloud uses for metrics - config["dependsOn"] = config.pop("depends_on")["nodes"] - config["uniqueId"] = config.pop("unique_id") - og_metrics.append(metric_schema.load(config)) + if "calculation_method" in config or "sql" in config: + # conform to the same schema that dbt Cloud uses for metrics + config["dependsOn"] = config.pop("depends_on")["nodes"] + config["uniqueId"] = config.pop("unique_id") + og_metrics.append(metric_schema.load(config)) + elif sl_metric := get_sl_metric(config, model_map, dialect): + sl_metrics.append(sl_metric) - superset_metrics = get_superset_metrics_per_model(og_metrics) + superset_metrics = get_superset_metrics_per_model(og_metrics, sl_metrics) try: database = sync_database( @@ -338,7 +347,44 @@ def get_job( raise ValueError(f"Job {job_id} not available") -def process_sl_metrics( +def get_sl_metric( + metric: Dict[str, Any], + model_map: Dict[ModelKey, ModelSchema], + dialect: MFSQLEngine, +) -> Optional[MFMetricWithSQLSchema]: + """ + Compute a SL metric using the ``mf`` CLI. + """ + mf_metric_schema = MFMetricWithSQLSchema() + + command = ["mf", "query", "--explain", "--metrics", metric["name"]] + try: + result = subprocess.run(command, capture_output=True, text=True, check=True) + except subprocess.CalledProcessError: + return None + + output = result.stdout.strip() + start = output.find("SELECT") + sql = output[start:] + + models = get_models_from_sql(sql, dialect, model_map) + if len(models) > 1: + return None + model = models[0] + + return mf_metric_schema.load( + { + "name": metric["name"], + "type": metric["type"], + "description": metric["description"], + "sql": sql, + "dialect": dialect.value, + "model": model["unique_id"], + }, + ) + + +def fetch_sl_metrics( dbt_client: DBTClient, environment_id: int, model_map: Dict[ModelKey, ModelSchema], @@ -498,7 +544,7 @@ def dbt_cloud( # pylint: disable=too-many-arguments, too-many-locals model_map = {ModelKey(model["schema"], model["name"]): model for model in models} og_metrics = dbt_client.get_og_metrics(job["id"]) - sl_metrics = process_sl_metrics(dbt_client, job["environment_id"], model_map) + sl_metrics = fetch_sl_metrics(dbt_client, job["environment_id"], model_map) superset_metrics = get_superset_metrics_per_model(og_metrics, sl_metrics) if exposures_only: diff --git a/src/preset_cli/cli/superset/sync/dbt/metrics.py b/src/preset_cli/cli/superset/sync/dbt/metrics.py index 79f0be97..a9cc160f 100644 --- a/src/preset_cli/cli/superset/sync/dbt/metrics.py +++ b/src/preset_cli/cli/superset/sync/dbt/metrics.py @@ -51,10 +51,12 @@ def get_metric_expression(unique_id: str, metrics: Dict[str, MetricSchema]) -> s # dbt >= 1.3 type_ = metric["calculation_method"] sql = metric["expression"] - else: + elif "sql" in metric: # dbt < 1.3 type_ = metric["type"] sql = metric["sql"] + else: + raise Exception(f"Unable to generate metric expression from: {metric}") if metric.get("filters"): sql = apply_filters(sql, metric["filters"])