feat: project governance llm checks #45

Merged: 13 commits, Dec 30, 2024
13 changes: 13 additions & 0 deletions src/datapilot/clients/altimate/client.py
@@ -91,3 +91,16 @@ def validate_upload_to_integration(self):
def start_dbt_ingestion(self, params=None):
endpoint = "/dbt/v1/start_dbt_ingestion"
return self.post(endpoint, data=params)

def get_project_governance_llm_checks(self, params=None):
endpoint = "/project_governance/checks"
return self.get(endpoint, params=params)

def run_project_governance_llm_checks(self, manifest, catalog, check_names):
endpoint = "/project_governance/run_checks"
data = {
"manifest": manifest,
"catalog": catalog,
"check_names": check_names,
}
return self.post(endpoint, data=data)
21 changes: 21 additions & 0 deletions src/datapilot/clients/altimate/utils.py
@@ -103,3 +103,24 @@ def start_dbt_ingestion(api_token, tenant, dbt_core_integration_id, dbt_core_int
"ok": False,
"message": "Error starting dbt ingestion worker. ",
}


def get_project_governance_llm_checks(
api_token,
tenant,
backend_url,
):
api_client = APIClient(api_token=api_token, base_url=backend_url, tenant=tenant)
return api_client.get_project_governance_llm_checks()


def run_project_governance_llm_checks(
api_token,
tenant,
backend_url,
manifest,
catalog,
check_names,
):
api_client = APIClient(api_token=api_token, base_url=backend_url, tenant=tenant)
return api_client.run_project_governance_llm_checks(manifest, catalog, check_names)
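
Taken together with the new APIClient methods above, these thin wrappers let callers fetch and run the governance checks without constructing the client themselves. A minimal usage sketch, assuming the token, tenant, and backend URL shown are placeholders and that manifest and catalog are the dbt artifacts already loaded elsewhere (as in cli.py):

from datapilot.clients.altimate.utils import get_project_governance_llm_checks
from datapilot.clients.altimate.utils import run_project_governance_llm_checks

api_token = "<api-token>"                   # placeholder credentials
tenant = "<instance-name>"
backend_url = "https://api.myaltimate.com"

# manifest and catalog are assumed to be the already-loaded dbt artifacts
# (cli.py obtains them via load_manifest / load_catalog).

# List the available governance checks, then run them against the artifacts.
checks = get_project_governance_llm_checks(api_token, tenant, backend_url)
check_names = [check["name"] for check in checks]
results = run_project_governance_llm_checks(
    api_token, tenant, backend_url, manifest, catalog, check_names
)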
33 changes: 31 additions & 2 deletions src/datapilot/core/platforms/dbt/cli/cli.py
@@ -8,6 +8,7 @@
from datapilot.clients.altimate.utils import validate_credentials
from datapilot.clients.altimate.utils import validate_permissions
from datapilot.config.config import load_config
from datapilot.core.platforms.dbt.constants import LLM
from datapilot.core.platforms.dbt.constants import MODEL
from datapilot.core.platforms.dbt.constants import PROJECT
from datapilot.core.platforms.dbt.executor import DBTInsightGenerator
@@ -28,6 +29,8 @@ def dbt():


@dbt.command("project-health")
@click.option("--token", required=False, prompt="API Token", help="Your API token for authentication.")
@click.option("--instance-name", required=False, prompt="Instance Name", help="Your tenant ID.")
@click.option(
"--manifest-path",
required=True,
@@ -49,7 +52,10 @@ def dbt():
default=None,
help="Selective model testing. Specify one or more models to run tests on.",
)
def project_health(manifest_path, catalog_path, config_path=None, select=None):
@click.option("--backend-url", required=False, help="Altimate's Backend URL", default="https://api.myaltimate.com")
def project_health(
token, instance_name, manifest_path, catalog_path, config_path=None, select=None, backend_url="https://api.myaltimate.com"
):
"""
Validate the DBT project's configuration and structure.
:param manifest_path: Path to the DBT manifest file.
@@ -62,11 +68,21 @@ def project_health(manifest_path, catalog_path, config_path=None, select=None):
selected_models = select.split(" ")
manifest = load_manifest(manifest_path)
catalog = load_catalog(catalog_path) if catalog_path else None
insight_generator = DBTInsightGenerator(manifest=manifest, catalog=catalog, config=config, selected_models=selected_models)

insight_generator = DBTInsightGenerator(
manifest=manifest,
catalog=catalog,
config=config,
selected_models=selected_models,
token=token,
instance_name=instance_name,
backend_url=backend_url,
)
reports = insight_generator.run()

package_insights = reports[PROJECT]
model_insights = reports[MODEL]
llm_insights = reports[LLM]
model_report = generate_model_insights_table(model_insights)
if len(model_report) > 0:
click.echo("--" * 50)
@@ -85,6 +101,19 @@ def project_health(manifest_path, catalog_path, config_path=None, select=None):
click.echo("--" * 50)
click.echo(tabulate_data(project_report, headers="keys"))

if len(llm_insights):
click.echo("--" * 50)
click.echo("Project Governance LLM Insights")
click.echo("--" * 50)
for check in llm_insights:
click.echo(f"Check: {check['name']}")
for answer in check["answer"]:
click.echo(f"Rule: {answer['Rule']}")
click.echo(f"Location: {answer['Location']}")
click.echo(f"Issue: {answer['Issue']}")
click.echo(f"Fix: {answer['Fix']}")
click.echo("\n")


@dbt.command("onboard")
@click.option("--token", prompt="API Token", help="Your API token for authentication.")
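
With the new --token, --instance-name, and --backend-url options wired in, the project-health command can be exercised end to end. A hedged sketch using click's test runner; the artifact path and credential values are placeholders, and the import path is assumed to follow the file location shown above:

from click.testing import CliRunner
from datapilot.core.platforms.dbt.cli.cli import dbt

runner = CliRunner()
result = runner.invoke(
    dbt,
    [
        "project-health",
        "--manifest-path", "target/manifest.json",  # placeholder path to a real manifest
        "--token", "<api-token>",                    # placeholder credentials
        "--instance-name", "<tenant>",
        "--backend-url", "https://api.myaltimate.com",
    ],
)
print(result.output)

Passing --token and --instance-name explicitly also keeps the run non-interactive, since both options otherwise prompt.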
2 changes: 2 additions & 0 deletions src/datapilot/core/platforms/dbt/constants.py
@@ -4,6 +4,8 @@
MODEL = "model"
SOURCE = "source"

LLM = "llm"


PROJECT = "project"
SQL = "sql"
23 changes: 23 additions & 0 deletions src/datapilot/core/platforms/dbt/executor.py
@@ -5,6 +5,9 @@
from typing import List
from typing import Optional

from datapilot.clients.altimate.utils import get_project_governance_llm_checks
from datapilot.clients.altimate.utils import run_project_governance_llm_checks
from datapilot.core.platforms.dbt.constants import LLM
from datapilot.core.platforms.dbt.constants import MODEL
from datapilot.core.platforms.dbt.constants import PROJECT
from datapilot.core.platforms.dbt.exceptions import AltimateCLIArgumentError
@@ -29,11 +32,17 @@ def __init__(
target: str = "dev",
selected_models: Optional[str] = None,
selected_model_ids: Optional[List[str]] = None,
token: Optional[str] = None,
instance_name: Optional[str] = None,
backend_url: Optional[str] = None,
):
self.run_results_path = run_results_path
self.target = target
self.env = env
self.config = config or {}
self.token = token
self.instance_name = instance_name
self.backend_url = backend_url

self.manifest_wrapper = DBTFactory.get_manifest_wrapper(manifest)
self.manifest_present = True
@@ -85,10 +94,19 @@ def _check_if_skipped(self, insight):
return True
return False

def run_llm_checks(self):
llm_checks = get_project_governance_llm_checks(self.token, self.instance_name, self.backend_url)
check_names = [check["name"] for check in llm_checks if check["alias"] not in self.config.get("disabled_insights", [])]
llm_check_results = run_project_governance_llm_checks(
self.token, self.instance_name, self.backend_url, self.manifest, self.catalog, check_names
)
return llm_check_results

def run(self):
reports = {
MODEL: {},
PROJECT: [],
LLM: [],
}
for insight_class in INSIGHTS:
# TODO: Skip insight based on config
@@ -154,4 +172,9 @@ def run(self):
else:
self.logger.info(color_text(f"Skipping insight {insight_class.NAME} as {message}", YELLOW))

if self.token and self.instance_name and self.backend_url:
llm_check_results = self.run_llm_checks()
if llm_check_results:
reports[LLM].extend(llm_check_results["results"])

return reports
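
A sketch of the same flow driven through the executor directly, mirroring the cli.py changes above; manifest, catalog, and the credential values are placeholders, and the LLM section of the report is only populated when token, instance_name, and backend_url are all provided:

from datapilot.core.platforms.dbt.constants import LLM
from datapilot.core.platforms.dbt.executor import DBTInsightGenerator

insight_generator = DBTInsightGenerator(
    manifest=manifest,                   # loaded dbt manifest artifact
    catalog=catalog,                     # loaded dbt catalog artifact, or None
    config={},
    token="<api-token>",                 # placeholder credentials
    instance_name="<tenant>",
    backend_url="https://api.myaltimate.com",
)
reports = insight_generator.run()

for check in reports[LLM]:
    print(check["name"])
    for answer in check["answer"]:
        print(answer["Rule"], answer["Location"], answer["Issue"], answer["Fix"])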