Skip to content

Commit

Permalink
fix: bad config error at pipeline level in checks (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
julesbertrand authored Oct 6, 2023
1 parent 53b3372 commit d8dab84
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 22 deletions.
36 changes: 25 additions & 11 deletions deployer/pipeline_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from loguru import logger
from pydantic import Field, ValidationError, computed_field, model_validator
from pydantic.functional_validators import ModelWrapValidatorHandler
from pydantic_core import PydanticCustomError
from typing_extensions import Annotated

from deployer.constants import (
Expand All @@ -14,6 +15,7 @@
)
from deployer.pipeline_deployer import VertexPipelineDeployer
from deployer.utils.config import list_config_filepaths, load_config
from deployer.utils.exceptions import BadConfigError
from deployer.utils.logging import disable_logger
from deployer.utils.models import CustomBaseModel, create_model_from_pipeline
from deployer.utils.utils import (
Expand All @@ -26,10 +28,29 @@
PipelineName = make_enum_from_python_package_dir(PIPELINE_ROOT_PATH)


class ConfigDynamicModel(CustomBaseModel, Generic[PipelineConfigT]):
    """Validation model for a single pipeline config file.

    Generic over ``PipelineConfigT``, the dynamic model built from the
    pipeline's parameter definitions, so that ``config`` is validated
    against the actual pipeline signature.

    Fields:
        config_path: Path to the config file on disk.
        config: The loaded config values; populated automatically from
            ``config_path`` by the validator below when not provided.
    """

    config_path: Path
    config: PipelineConfigT

    @model_validator(mode="before")
    @classmethod
    def load_config_if_empty(cls, data: Any) -> Any:
        """Load the config from ``config_path`` when ``config`` is missing.

        Raises:
            PydanticCustomError: wraps a ``BadConfigError`` raised while
                loading, so the failure is attached to this config entry and
                reported per-file in the checks table instead of aborting
                validation of all configs.
        """
        if data.get("config") is None:
            try:
                parameter_values, input_artifacts = load_config(data["config_path"])
            except BadConfigError as e:
                # NOTE(review): str(e) is used as the PydanticCustomError message
                # template; literal braces in the error text would be treated as
                # template placeholders — confirm upstream messages contain none.
                raise PydanticCustomError("BadConfigError", str(e)) from e
            # parameter_values / input_artifacts may each be None; merge the
            # non-empty ones into a single flat dict for validation.
            data["config"] = {**(parameter_values or {}), **(input_artifacts or {})}
        return data


class ConfigsDynamicModel(CustomBaseModel, Generic[PipelineConfigT]):
    """Validation model for a collection of pipeline config files.

    Maps each config file name to a ``ConfigDynamicModel`` entry, so every
    config is loaded and checked against the pipeline's dynamic parameter
    model (``PipelineConfigT``) in a single ``model_validate`` call.
    """

    configs: Dict[str, ConfigDynamicModel[PipelineConfigT]]


class Pipeline(CustomBaseModel):
Expand All @@ -52,13 +73,6 @@ def pipeline(self) -> Any:
with disable_logger("deployer.utils.utils"):
return import_pipeline_from_dir(PIPELINE_ROOT_PATH, self.pipeline_name.value)

@computed_field()
def configs(self) -> Any:
"""Load configs"""
configs = [load_config(config_path) for config_path in self.config_paths]
configs = [{**(pv or {}), **(ia or {})} for pv, ia in configs]
return configs

@model_validator(mode="after")
def import_pipeline(self):
"""Validate that the pipeline can be imported by calling pipeline computed field"""
Expand Down Expand Up @@ -89,9 +103,9 @@ def validate_configs(self):
"""Validate configs against pipeline parameters definition"""
logger.debug(f"Validating configs for pipeline {self.pipeline_name.value}")
PipelineDynamicModel = create_model_from_pipeline(self.pipeline)
ConfigsModel = DynamicConfigsModel[PipelineDynamicModel]
ConfigsModel = ConfigsDynamicModel[PipelineDynamicModel]
ConfigsModel.model_validate(
{"configs": dict(zip([x.name for x in self.config_paths], self.configs))}
{"configs": {x.name: {"config_path": x} for x in self.config_paths}}
)
return self

Expand Down
6 changes: 3 additions & 3 deletions deployer/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import ValidationError
from pydantic_settings import BaseSettings, SettingsConfigDict

from deployer.utils.exceptions import UnsupportedConfigFileError
from deployer.utils.exceptions import BadConfigError, UnsupportedConfigFileError


class VertexPipelinesSettings(BaseSettings): # noqa: D101
Expand Down Expand Up @@ -119,15 +119,15 @@ def _load_config_python(config_filepath: Path) -> Tuple[Optional[dict], Optional
input_artifacts = getattr(module, "input_artifacts", None)

if parameter_values is None and input_artifacts is None:
raise ValueError(
raise BadConfigError(
f"{config_filepath}: Python config file must contain a `parameter_values` "
"and/or `input_artifacts` dict."
)

if parameter_values is not None and input_artifacts is not None:
common_keys = set(parameter_values.keys()).intersection(set(input_artifacts.keys()))
if common_keys:
raise ValueError(
raise BadConfigError(
f"{config_filepath}: Python config file must not contain common keys in "
"`parameter_values` and `input_artifacts` dict. Common keys: {common_keys}"
)
Expand Down
4 changes: 4 additions & 0 deletions deployer/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ class MissingGoogleArtifactRegistryHostError(Exception):

class UnsupportedConfigFileError(Exception):
    """Raised when the config file is not supported (unrecognized extension/format)."""


class BadConfigError(ValueError):
    """Raised when a config is invalid.

    Subclasses ``ValueError``, so callers catching ``ValueError`` —
    presumably the pre-existing behavior — still catch this error.
    """
1 change: 1 addition & 0 deletions deployer/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Literal

import kfp.components.graph_component
import kfp.dsl
from pydantic import BaseModel, ConfigDict, create_model
from typing_extensions import _AnnotatedAlias

Expand Down
13 changes: 5 additions & 8 deletions deployer/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def print_pipelines_list(pipelines_dict: Dict[str, list], with_configs: bool = F
console.print(table)


def print_check_results_table(
def print_check_results_table( # noqa: C901
to_check: Dict[str, list], validation_error: Optional[ValidationError] = None
) -> None:
"""This function prints a table of check results to the console.
Expand Down Expand Up @@ -126,7 +126,6 @@ def print_check_results_table(
table.add_row(*row.model_dump().values(), style="bold yellow")

elif len(errors) == 1 and len(errors[0]["loc"]) == 2:
print(errors)
row = ChecksTableRow(
status="❌",
pipeline=pipeline_name,
Expand All @@ -140,19 +139,17 @@ def print_check_results_table(
error_rows = []
for error in errors:
if error["loc"][3] == config_filepath.name:
error_row = {
"type": error["type"],
"attribute": error["loc"][4],
"msg": error["msg"],
}
error_row = {"type": error["type"], "msg": error["msg"]}
if len(error["loc"]) > 4:
error_row["attribute"] = error["loc"][5]
error_rows.append(error_row)
if error_rows:
row = ChecksTableRow(
status="❌",
pipeline=pipeline_name,
config_file=config_filepath.name,
config_error_type="\n".join([er["type"] for er in error_rows]),
attribute="\n".join([er["attribute"] for er in error_rows]),
attribute="\n".join([er.get("attribute", "") for er in error_rows]),
config_error_message="\n".join([er["msg"] for er in error_rows]),
)
table.add_row(*row.model_dump().values(), style="red")
Expand Down

0 comments on commit d8dab84

Please sign in to comment.