Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: project governance llm checks #45

Merged
merged 13 commits into from
Dec 30, 2024
13 changes: 13 additions & 0 deletions src/datapilot/clients/altimate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,16 @@ def validate_upload_to_integration(self):
def start_dbt_ingestion(self, params=None):
endpoint = "/dbt/v1/start_dbt_ingestion"
return self.post(endpoint, data=params)

def get_project_governance_llm_checks(self, params=None):
endpoint = "/project_governance/checks"
return self.get(endpoint, params=params)

def run_project_governance_llm_checks(self, manifest, catalog, check_names):
endpoint = "/project_governance/check/run"
data = {
"manifest": manifest,
"catalog": catalog,
"check_names": check_names,
}
return self.post(endpoint, data=data)
21 changes: 21 additions & 0 deletions src/datapilot/clients/altimate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,24 @@ def start_dbt_ingestion(api_token, tenant, dbt_core_integration_id, dbt_core_int
"ok": False,
"message": "Error starting dbt ingestion worker. ",
}


def get_project_governance_llm_checks(
api_token,
tenant,
backend_url,
):
api_client = APIClient(api_token=api_token, base_url=backend_url, tenant=tenant)
return api_client.get_project_governance_llm_checks()


def run_project_governance_llm_checks(
api_token,
tenant,
backend_url,
manifest,
catalog,
check_names,
):
api_client = APIClient(api_token=api_token, base_url=backend_url, tenant=tenant)
return api_client.run_project_governance_llm_checks(manifest, catalog, check_names)
18 changes: 16 additions & 2 deletions src/datapilot/core/platforms/dbt/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def dbt():


@dbt.command("project-health")
@click.option("--token", required=False, help="Your API token for authentication.")
@click.option("--instance-name", required=False, help="Your tenant ID.")
@click.option(
"--manifest-path",
required=True,
Expand All @@ -49,7 +51,10 @@ def dbt():
default=None,
help="Selective model testing. Specify one or more models to run tests on.",
)
def project_health(manifest_path, catalog_path, config_path=None, select=None):
@click.option("--backend-url", required=False, help="Altimate's Backend URL", default="https://api.myaltimate.com")
def project_health(
token, instance_name, manifest_path, catalog_path, config_path=None, select=None, backend_url="https://api.myaltimate.com"
):
"""
Validate the DBT project's configuration and structure.
:param manifest_path: Path to the DBT manifest file.
Expand All @@ -62,7 +67,16 @@ def project_health(manifest_path, catalog_path, config_path=None, select=None):
selected_models = select.split(" ")
manifest = load_manifest(manifest_path)
catalog = load_catalog(catalog_path) if catalog_path else None
insight_generator = DBTInsightGenerator(manifest=manifest, catalog=catalog, config=config, selected_models=selected_models)

insight_generator = DBTInsightGenerator(
manifest=manifest,
catalog=catalog,
config=config,
selected_models=selected_models,
token=token,
instance_name=instance_name,
backend_url=backend_url,
)
reports = insight_generator.run()

package_insights = reports[PROJECT]
Expand Down
2 changes: 2 additions & 0 deletions src/datapilot/core/platforms/dbt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
MODEL = "model"
SOURCE = "source"

LLM = "llm"


PROJECT = "project"
SQL = "sql"
Expand Down
67 changes: 67 additions & 0 deletions src/datapilot/core/platforms/dbt/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@
from typing import List
from typing import Optional

from datapilot.clients.altimate.utils import get_project_governance_llm_checks
from datapilot.clients.altimate.utils import run_project_governance_llm_checks
from datapilot.core.platforms.dbt.constants import LLM
from datapilot.core.platforms.dbt.constants import MODEL
from datapilot.core.platforms.dbt.constants import PROJECT
from datapilot.core.platforms.dbt.exceptions import AltimateCLIArgumentError
from datapilot.core.platforms.dbt.factory import DBTFactory
from datapilot.core.platforms.dbt.insights import INSIGHTS
from datapilot.core.platforms.dbt.insights.schema import DBTInsightResult
from datapilot.core.platforms.dbt.insights.schema import DBTModelInsightResponse
from datapilot.core.platforms.dbt.schemas.manifest import Catalog
from datapilot.core.platforms.dbt.schemas.manifest import Manifest
from datapilot.core.platforms.dbt.utils import get_models
Expand All @@ -29,11 +34,19 @@ def __init__(
target: str = "dev",
selected_models: Optional[str] = None,
selected_model_ids: Optional[List[str]] = None,
token: Optional[str] = None,
instance_name: Optional[str] = None,
backend_url: Optional[str] = None,
):
self.run_results_path = run_results_path
self.target = target
self.env = env
self.config = config or {}
self.token = token
self.instance_name = instance_name
self.backend_url = backend_url
self.manifest = manifest
self.catalog = catalog

self.manifest_wrapper = DBTFactory.get_manifest_wrapper(manifest)
self.manifest_present = True
Expand Down Expand Up @@ -85,6 +98,22 @@ def _check_if_skipped(self, insight):
return True
return False

def run_llm_checks(self):
llm_checks = get_project_governance_llm_checks(self.token, self.instance_name, self.backend_url)
check_names = [check["name"] for check in llm_checks if check["alias"] not in self.config.get("disabled_insights", [])]
if len(check_names) == 0:
return {"results": []}

llm_check_results = run_project_governance_llm_checks(
self.token,
self.instance_name,
self.backend_url,
self.manifest.json() if self.manifest else "",
self.catalog.json() if self.catalog else "",
check_names,
)
return llm_check_results

def run(self):
reports = {
MODEL: {},
Expand Down Expand Up @@ -154,4 +183,42 @@ def run(self):
else:
self.logger.info(color_text(f"Skipping insight {insight_class.NAME} as {message}", YELLOW))

if self.token and self.instance_name and self.backend_url:
llm_check_results = self.run_llm_checks()
llm_reports = llm_check_results.get("results", [])
llm_insights = {}
for report in llm_reports:
for answer in report["answer"]:
location = answer["unique_id"]
if location not in llm_insights:
llm_insights[location] = []
metadata = answer.get("metadata", {})
metadata["source"] = LLM
metadata["teammate_check_id"] = report["id"]
metadata["category"] = report["type"]
llm_insights[location].append(
DBTModelInsightResponse(
insight=DBTInsightResult(
type="Custom",
name=report["name"],
message=answer["message"],
reason_to_flag=answer["reason_to_flag"],
recommendation=answer["recommendation"],
metadata=metadata,
),
severity=answer["severity"],
path=answer["path"] if answer.get("path") else "",
original_file_path=answer["original_file_path"] if answer.get("original_file_path") else "",
package_name=answer.get["package_name"] if answer.get("package_name") else "",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect syntax for accessing dictionary values. Use parentheses for the get method.

Suggested change
package_name=answer.get["package_name"] if answer.get("package_name") else "",
package_name=answer.get("package_name") if answer.get("package_name") else "",

unique_id=answer["unique_id"],
)
)

if llm_insights:
for key, value in llm_insights.items():
if key in reports[MODEL]:
reports[MODEL][key].extend(value)
else:
reports[MODEL][key] = value

return reports
Loading