diff --git a/CHANGELOG.md b/CHANGELOG.md index 7af2390..55fb68a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to ## [Unreleased] +- Add model filters to let models be ignored by certain rules. + ## [0.4.0] - 2024-08-08 - Add null check before calling `project_evaluated` in the `evaluate` method to diff --git a/docs/configuration.md b/docs/configuration.md index bbc051b..008dd17 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -86,6 +86,9 @@ Every rule can be configured with the following option: - `severity`: The severity of the rule. Rules have a default severity and can be overridden. It's an integer with a minimum value of 1 and a maximum value of 4. +- `model_filter_names`: Filters used by the rule. Takes a list of names that can + be found in the same namespace as the rules (see + [Package rules](package_rules.md)). Some rules have additional configuration options, e.g. [sql_has_reasonable_number_of_lines](rules/generic.md#sql_has_reasonable_number_of_lines). diff --git a/docs/create_rules.md b/docs/create_rules.md index 1efa1f8..cd5e5b4 100644 --- a/docs/create_rules.md +++ b/docs/create_rules.md @@ -90,3 +90,42 @@ def sql_has_reasonable_number_of_lines(model: Model, max_lines: int = 200) -> Ru message=f"SQL query too long: {count_lines} lines (> {max_lines})." ) ``` + +### Filtering models + +Custom and standard rules can be configured to have model filters. Filters allow +models to be ignored by one or multiple rules. + +Filters are created using the same discovery mechanism and interface as custom +rules, except they do not accept parameters. Similar to Python's built-in +`filter` function, when the filter evaluation returns `True` the model should be +evaluated, otherwise it should be ignored. + +```python +from dbt_score import ModelFilter, model_filter + +@model_filter +def only_schema_x(model: Model) -> bool: + """Only applies a rule to schema X.""" + return model.schema.lower() == 'x' + +class SkipSchemaY(ModelFilter): + description = "Applies a rule to every schema but Y." + def evaluate(self, model: Model) -> bool: + return model.schema.lower() != 'y' +``` + +Similar to setting a rule severity, standard rules can have filters set in the +[configuration file](configuration.md/#tooldbt-scorerulesrule_namespacerule_name), +while custom rules accept the configuration file or a decorator parameter. + +```python +from dbt_score import Model, rule, RuleViolation +from my_project import only_schema_x + +@rule(model_filters={only_schema_x()}) +def models_in_x_follow_naming_standard(model: Model) -> RuleViolation | None: + """Models in schema X must follow the naming standard.""" + if some_regex_fails(model.name): + return RuleViolation("Invalid model name.") +``` diff --git a/src/dbt_score/__init__.py b/src/dbt_score/__init__.py index 134923a..3f4d3b2 100644 --- a/src/dbt_score/__init__.py +++ b/src/dbt_score/__init__.py @@ -1,6 +1,15 @@ """Init dbt_score package.""" +from dbt_score.model_filter import ModelFilter, model_filter from dbt_score.models import Model from dbt_score.rule import Rule, RuleViolation, Severity, rule -__all__ = ["Model", "Rule", "RuleViolation", "Severity", "rule"] +__all__ = [ + "Model", + "ModelFilter", + "Rule", + "RuleViolation", + "Severity", + "model_filter", + "rule", +] diff --git a/src/dbt_score/evaluation.py b/src/dbt_score/evaluation.py index da35b3a..c583d06 100644 --- a/src/dbt_score/evaluation.py +++ b/src/dbt_score/evaluation.py @@ -57,11 +57,11 @@ def evaluate(self) -> None: self.results[model] = {} for rule in rules: try: - result: RuleViolation | None = rule.evaluate(model, **rule.config) + if rule.should_evaluate(model): # Consider model filter(s). + result = rule.evaluate(model, **rule.config) + self.results[model][rule.__class__] = result except Exception as e: self.results[model][rule.__class__] = e - else: - self.results[model][rule.__class__] = result self.scores[model] = self._scorer.score_model(self.results[model]) self._formatter.model_evaluated( diff --git a/src/dbt_score/model_filter.py b/src/dbt_score/model_filter.py new file mode 100644 index 0000000..102a44c --- /dev/null +++ b/src/dbt_score/model_filter.py @@ -0,0 +1,115 @@ +"""Model filtering to choose when to apply specific rules.""" + +from typing import Any, Callable, Type, TypeAlias, overload + +from dbt_score.models import Model + +FilterEvaluationType: TypeAlias = Callable[[Model], bool] + + +class ModelFilter: + """The Filter base class.""" + + description: str + + def __init__(self) -> None: + """Initialize the filter.""" + pass + + def __init_subclass__(cls, **kwargs) -> None: # type: ignore + """Initializes the subclass.""" + super().__init_subclass__(**kwargs) + if not hasattr(cls, "description"): + raise AttributeError("Subclass must define class attribute `description`.") + + def evaluate(self, model: Model) -> bool: + """Evaluates the filter.""" + raise NotImplementedError("Subclass must implement method `evaluate`.") + + @classmethod + def source(cls) -> str: + """Return the source of the filter, i.e. a fully qualified name.""" + return f"{cls.__module__}.{cls.__name__}" + + def __hash__(self) -> int: + """Compute a unique hash for a filter.""" + return hash(self.source()) + + +# Use @overload to have proper typing for both @model_filter and @model_filter(...) +# https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories + + +@overload +def model_filter(__func: FilterEvaluationType) -> Type[ModelFilter]: + ... + + +@overload +def model_filter( + *, + description: str | None = None, +) -> Callable[[FilterEvaluationType], Type[ModelFilter]]: + ... + + +def model_filter( + __func: FilterEvaluationType | None = None, + *, + description: str | None = None, +) -> Type[ModelFilter] | Callable[[FilterEvaluationType], Type[ModelFilter]]: + """Model-filter decorator. + + The model-filter decorator creates a filter class (subclass of ModelFilter) + and returns it. + + Using arguments or not are both supported: + - ``@model_filter`` + - ``@model_filter(description="...")`` + + Args: + __func: The filter evaluation function being decorated. + description: The description of the filter. + """ + + def decorator_filter( + func: FilterEvaluationType, + ) -> Type[ModelFilter]: + """Decorator function.""" + if func.__doc__ is None and description is None: + raise AttributeError( + "ModelFilter must define `description` or `func.__doc__`." + ) + + # Get description parameter, otherwise use the docstring + filter_description = description or ( + func.__doc__.split("\n")[0] if func.__doc__ else None + ) + + def wrapped_func(self: ModelFilter, *args: Any, **kwargs: Any) -> bool: + """Wrap func to add `self`.""" + return func(*args, **kwargs) + + # Create the filter class inheriting from ModelFilter + filter_class = type( + func.__name__, + (ModelFilter,), + { + "description": filter_description, + "evaluate": wrapped_func, + # Save provided evaluate function + "_orig_evaluate": func, + # Forward origin of the decorated function + "__qualname__": func.__qualname__, # https://peps.python.org/pep-3155/ + "__module__": func.__module__, + }, + ) + + return filter_class + + if __func is not None: + # The syntax @model_filter is used + return decorator_filter(__func) + else: + # The syntax @model_filter(...) is used + return decorator_filter diff --git a/src/dbt_score/rule.py b/src/dbt_score/rule.py index a78912c..3d68fdc 100644 --- a/src/dbt_score/rule.py +++ b/src/dbt_score/rule.py @@ -4,8 +4,9 @@ import typing from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Type, TypeAlias, overload +from typing import Any, Callable, Iterable, Type, TypeAlias, overload +from dbt_score.model_filter import ModelFilter from dbt_score.models import Model @@ -24,6 +25,7 @@ class RuleConfig: severity: Severity | None = None config: dict[str, Any] = field(default_factory=dict) + model_filter_names: list[str] = field(default_factory=list) @staticmethod def from_dict(rule_config: dict[str, Any]) -> "RuleConfig": @@ -34,8 +36,15 @@ def from_dict(rule_config: dict[str, Any]) -> "RuleConfig": if "severity" in rule_config else None ) + filter_names = ( + config.pop("model_filter_names", None) + if "model_filter_names" in rule_config + else [] + ) - return RuleConfig(severity=severity, config=config) + return RuleConfig( + severity=severity, config=config, model_filter_names=filter_names + ) @dataclass @@ -53,6 +62,8 @@ class Rule: description: str severity: Severity = Severity.MEDIUM + model_filter_names: list[str] + model_filters: frozenset[ModelFilter] = frozenset() default_config: typing.ClassVar[dict[str, Any]] = {} def __init__(self, rule_config: RuleConfig | None = None) -> None: @@ -83,17 +94,30 @@ def process_config(self, rule_config: RuleConfig) -> None: self.set_severity( rule_config.severity ) if rule_config.severity else rule_config.severity + self.model_filter_names = rule_config.model_filter_names self.config = config def evaluate(self, model: Model) -> RuleViolation | None: """Evaluates the rule.""" raise NotImplementedError("Subclass must implement method `evaluate`.") + @classmethod + def should_evaluate(cls, model: Model) -> bool: + """Checks if all filters in the rule allow evaluation.""" + if cls.model_filters: + return all(f.evaluate(model) for f in cls.model_filters) + return True + @classmethod def set_severity(cls, severity: Severity) -> None: """Set the severity of the rule.""" cls.severity = severity + @classmethod + def set_filters(cls, model_filters: Iterable[ModelFilter]) -> None: + """Set the filters of the rule.""" + cls.model_filters = frozenset(model_filters) + @classmethod def source(cls) -> str: """Return the source of the rule, i.e. a fully qualified name.""" @@ -118,6 +142,7 @@ def rule( *, description: str | None = None, severity: Severity = Severity.MEDIUM, + model_filters: set[ModelFilter] | None = None, ) -> Callable[[RuleEvaluationType], Type[Rule]]: ... @@ -127,6 +152,7 @@ def rule( *, description: str | None = None, severity: Severity = Severity.MEDIUM, + model_filters: set[ModelFilter] | None = None, ) -> Type[Rule] | Callable[[RuleEvaluationType], Type[Rule]]: """Rule decorator. @@ -140,6 +166,7 @@ def rule( __func: The rule evaluation function being decorated. description: The description of the rule. severity: The severity of the rule. + model_filters: Set of ModelFilter that filters the rule. """ def decorator_rule( @@ -172,6 +199,7 @@ def wrapped_func(self: Rule, *args: Any, **kwargs: Any) -> RuleViolation | None: { "description": rule_description, "severity": severity, + "model_filters": model_filters or frozenset(), "default_config": default_config, "evaluate": wrapped_func, # Save provided evaluate function diff --git a/src/dbt_score/rule_registry.py b/src/dbt_score/rule_registry.py index a8681b2..0e4557a 100644 --- a/src/dbt_score/rule_registry.py +++ b/src/dbt_score/rule_registry.py @@ -12,6 +12,7 @@ from dbt_score.config import Config from dbt_score.exceptions import DuplicatedRuleException +from dbt_score.model_filter import ModelFilter from dbt_score.rule import Rule, RuleConfig logger = logging.getLogger(__name__) @@ -24,12 +25,18 @@ def __init__(self, config: Config) -> None: """Instantiate a rule registry.""" self.config = config self._rules: dict[str, Rule] = {} + self._model_filters: dict[str, ModelFilter] = {} @property def rules(self) -> dict[str, Rule]: """Get all rules.""" return self._rules + @property + def model_filters(self) -> dict[str, ModelFilter]: + """Get all filters.""" + return self._model_filters + def _walk_packages(self, namespace_name: str) -> Iterator[str]: """Walk packages and sub-packages recursively.""" try: @@ -50,23 +57,36 @@ def _walk_packages(self, namespace_name: str) -> Iterator[str]: yield package.name def _load(self, namespace_name: str) -> None: - """Load rules found in a given namespace.""" + """Load rules and filters found in a given namespace.""" for module_name in self._walk_packages(namespace_name): module = importlib.import_module(module_name) for obj_name in dir(module): obj = module.__dict__[obj_name] if type(obj) is type and issubclass(obj, Rule) and obj is not Rule: self._add_rule(obj) + if ( + type(obj) is type + and issubclass(obj, ModelFilter) + and obj is not ModelFilter + ): + self._add_filter(obj) def _add_rule(self, rule: Type[Rule]) -> None: """Initialize and add a rule.""" rule_name = rule.source() if rule_name in self._rules: - raise DuplicatedRuleException(rule.source()) + raise DuplicatedRuleException(rule_name) if rule_name not in self.config.disabled_rules: rule_config = self.config.rules_config.get(rule_name, RuleConfig()) self._rules[rule_name] = rule(rule_config=rule_config) + def _add_filter(self, model_filter: Type[ModelFilter]) -> None: + """Initialize and add a filter.""" + filter_name = model_filter.source() + if filter_name in self._model_filters: + raise DuplicatedRuleException(filter_name) + self._model_filters[filter_name] = model_filter() + def load_all(self) -> None: """Load all rules, core and third-party.""" # Add cwd to Python path @@ -79,3 +99,21 @@ def load_all(self) -> None: # Restore original values sys.path = old_sys_path + + self._load_filters_into_rules() + + def _load_filters_into_rules(self) -> None: + """Loads ModelFilters into Rule objects. + + If the config of the rule has filter names in the `model_filter_names` key, + load those filters from the rule registry into the actual `model_filters` field. + Configuration overwrites any pre-existing filters. + """ + for rule in self._rules.values(): + filter_names: list[str] = rule.model_filter_names or [] + if len(filter_names) > 0: + rule.set_filters( + model_filter + for name, model_filter in self.model_filters.items() + if name in filter_names + ) diff --git a/src/dbt_score/scoring.py b/src/dbt_score/scoring.py index abf9363..60b05c3 100644 --- a/src/dbt_score/scoring.py +++ b/src/dbt_score/scoring.py @@ -39,7 +39,9 @@ def __init__(self, config: Config) -> None: def score_model(self, model_results: ModelResultsType) -> Score: """Compute the score of a given model.""" - if len(model_results) == 0: + rule_count = len(model_results) + + if rule_count == 0: # No rule? No problem score = self.max_score elif any( @@ -60,7 +62,7 @@ def score_model(self, model_results: ModelResultsType) -> Score: for rule, result in model_results.items() ] ) - / (self.score_cardinality * len(model_results)) + / (self.score_cardinality * rule_count) * self.max_score ) diff --git a/tests/conftest.py b/tests/conftest.py index e949327..4704d24 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from dbt_score import Model, Rule, RuleViolation, Severity, rule from dbt_score.config import Config +from dbt_score.model_filter import ModelFilter, model_filter from dbt_score.models import ManifestLoader from pytest import fixture @@ -210,3 +211,41 @@ def rule_error(model: Model) -> RuleViolation | None: raise Exception("Oh noes, something went wrong") return rule_error + + +@fixture +def rule_with_filter() -> Type[Rule]: + """An example rule that skips through a filter.""" + + @model_filter + def skip_model1(model: Model) -> bool: + """Skips for model1, passes for model2.""" + return model.name != "model1" + + @rule(model_filters={skip_model1()}) + def rule_with_filter(model: Model) -> RuleViolation | None: + """Rule that always fails when not filtered.""" + return RuleViolation(message="I always fail.") + + return rule_with_filter + + +@fixture +def class_rule_with_filter() -> Type[Rule]: + """Using class definitions for filters and rules.""" + + class SkipModel1(ModelFilter): + description = "Filter defined by a class." + + def evaluate(self, model: Model) -> bool: + """Skips for model1, passes for model2.""" + return model.name != "model1" + + class RuleWithFilter(Rule): + description = "Filter defined by a class." + model_filters = frozenset({SkipModel1()}) + + def evaluate(self, model: Model) -> RuleViolation | None: + return RuleViolation(message="I always fail.") + + return RuleWithFilter diff --git a/tests/formatters/test_ascii_formatter.py b/tests/formatters/test_ascii_formatter.py index f76cabb..2ab4b41 100644 --- a/tests/formatters/test_ascii_formatter.py +++ b/tests/formatters/test_ascii_formatter.py @@ -1,9 +1,9 @@ """Unit tests for the ASCII formatter.""" -from typing import Type +from dbt_score.evaluation import ModelResultsType from dbt_score.formatters.ascii_formatter import ASCIIFormatter -from dbt_score.rule import Rule, RuleViolation +from dbt_score.rule import RuleViolation from dbt_score.scoring import Score @@ -18,7 +18,7 @@ def test_ascii_formatter_model( ): """Ensure the formatter doesn't write anything after model evaluation.""" formatter = ASCIIFormatter(manifest_loader=manifest_loader, config=default_config) - results: dict[Type[Rule], RuleViolation | Exception | None] = { + results: ModelResultsType = { rule_severity_low: None, rule_severity_medium: Exception("Oh noes"), rule_severity_critical: RuleViolation("Error"), diff --git a/tests/formatters/test_human_readable_formatter.py b/tests/formatters/test_human_readable_formatter.py index 295e804..b4afeb1 100644 --- a/tests/formatters/test_human_readable_formatter.py +++ b/tests/formatters/test_human_readable_formatter.py @@ -1,9 +1,9 @@ """Unit tests for the human readable formatter.""" -from typing import Type +from dbt_score.evaluation import ModelResultsType from dbt_score.formatters.human_readable_formatter import HumanReadableFormatter -from dbt_score.rule import Rule, RuleViolation +from dbt_score.rule import RuleViolation from dbt_score.scoring import Score @@ -20,7 +20,7 @@ def test_human_readable_formatter_model( formatter = HumanReadableFormatter( manifest_loader=manifest_loader, config=default_config ) - results: dict[Type[Rule], RuleViolation | Exception | None] = { + results: ModelResultsType = { rule_severity_low: None, rule_severity_medium: Exception("Oh noes"), rule_severity_critical: RuleViolation("Error"), @@ -59,7 +59,7 @@ def test_human_readable_formatter_low_model_score( formatter = HumanReadableFormatter( manifest_loader=manifest_loader, config=default_config ) - results: dict[Type[Rule], RuleViolation | Exception | None] = { + results: ModelResultsType = { rule_severity_critical: RuleViolation("Error"), } formatter.model_evaluated(model1, results, Score(0.0, "🚧")) @@ -90,7 +90,7 @@ def test_human_readable_formatter_low_project_score( formatter = HumanReadableFormatter( manifest_loader=manifest_loader, config=default_config ) - results: dict[Type[Rule], RuleViolation | Exception | None] = { + results: ModelResultsType = { rule_severity_critical: RuleViolation("Error"), } formatter.model_evaluated(model1, results, Score(10.0, "🥇")) diff --git a/tests/formatters/test_manifest_formatter.py b/tests/formatters/test_manifest_formatter.py index 56c7daa..baaacc8 100644 --- a/tests/formatters/test_manifest_formatter.py +++ b/tests/formatters/test_manifest_formatter.py @@ -1,10 +1,10 @@ """Unit tests for the manifest formatter.""" import json -from typing import Type +from dbt_score.evaluation import ModelResultsType from dbt_score.formatters.manifest_formatter import ManifestFormatter -from dbt_score.rule import Rule, RuleViolation +from dbt_score.rule import RuleViolation from dbt_score.scoring import Score @@ -21,7 +21,7 @@ def test_manifest_formatter_model( formatter = ManifestFormatter( manifest_loader=manifest_loader, config=default_config ) - results = { + results: ModelResultsType = { rule_severity_low: None, rule_severity_medium: Exception("Oh noes"), rule_severity_critical: RuleViolation("Error"), @@ -45,12 +45,12 @@ def test_manifest_formatter_project( formatter = ManifestFormatter( manifest_loader=manifest_loader, config=default_config ) - result1: dict[Type[Rule], RuleViolation | Exception | None] = { + result1: ModelResultsType = { rule_severity_low: None, rule_severity_medium: Exception("Oh noes"), rule_severity_critical: RuleViolation("Error"), } - result2: dict[Type[Rule], RuleViolation | Exception | None] = { + result2: ModelResultsType = { rule_severity_low: None, rule_severity_medium: None, rule_severity_critical: None, diff --git a/tests/resources/pyproject.toml b/tests/resources/pyproject.toml index eb598f3..cdd6cdb 100644 --- a/tests/resources/pyproject.toml +++ b/tests/resources/pyproject.toml @@ -25,3 +25,4 @@ model_name="model2" [tool.dbt-score.rules."tests.rules.example.rule_test_example"] severity=4 +model_filter_names=["tests.rules.example.skip_model1"] diff --git a/tests/rules/example.py b/tests/rules/example.py index cadfea2..7c383ac 100644 --- a/tests/rules/example.py +++ b/tests/rules/example.py @@ -1,8 +1,14 @@ """Example rules.""" -from dbt_score import Model, RuleViolation, rule +from dbt_score import Model, RuleViolation, model_filter, rule @rule() def rule_test_example(model: Model) -> RuleViolation | None: """An example rule.""" + + +@model_filter +def skip_model1(model: Model) -> bool: + """An example filter.""" + return model.name != "model1" diff --git a/tests/test_config.py b/tests/test_config.py index 14c8392..634e99a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -27,6 +27,9 @@ def test_load_valid_toml_file(valid_config_path): assert config.badge_config.first.icon == "1️⃣" assert config.fail_project_under == 7.5 assert config.fail_any_model_under == 6.9 + assert config.rules_config[ + "tests.rules.example.rule_test_example" + ].model_filter_names == ["tests.rules.example.skip_model1"] def test_load_invalid_toml_file(caplog, invalid_config_path): diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index a7e7a4e..569124b 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -179,3 +179,59 @@ def test_evaluation_rule_with_config( ) assert evaluation.results[model1][rule_with_config] is not None assert evaluation.results[model2][rule_with_config] is None + + +def test_evaluation_with_filter(manifest_path, default_config, rule_with_filter): + """Test rule with filter.""" + manifest_loader = ManifestLoader(manifest_path) + mock_formatter = Mock() + mock_scorer = Mock() + + rule_registry = RuleRegistry(default_config) + rule_registry._add_rule(rule_with_filter) + + # Ensure we get a valid Score object from the Mock + mock_scorer.score_model.return_value = Score(10, "🥇") + + evaluation = Evaluation( + rule_registry=rule_registry, + manifest_loader=manifest_loader, + formatter=mock_formatter, + scorer=mock_scorer, + ) + evaluation.evaluate() + + model1 = manifest_loader.models[0] + model2 = manifest_loader.models[1] + + assert rule_with_filter not in evaluation.results[model1] + assert isinstance(evaluation.results[model2][rule_with_filter], RuleViolation) + + +def test_evaluation_with_class_filter( + manifest_path, default_config, class_rule_with_filter +): + """Test rule with filters and filtered rules defined by classes.""" + manifest_loader = ManifestLoader(manifest_path) + mock_formatter = Mock() + mock_scorer = Mock() + + rule_registry = RuleRegistry(default_config) + rule_registry._add_rule(class_rule_with_filter) + + # Ensure we get a valid Score object from the Mock + mock_scorer.score_model.return_value = Score(10, "🥇") + + evaluation = Evaluation( + rule_registry=rule_registry, + manifest_loader=manifest_loader, + formatter=mock_formatter, + scorer=mock_scorer, + ) + evaluation.evaluate() + + model1 = manifest_loader.models[0] + model2 = manifest_loader.models[1] + + assert class_rule_with_filter not in evaluation.results[model1] + assert isinstance(evaluation.results[model2][class_rule_with_filter], RuleViolation) diff --git a/tests/test_model_filter.py b/tests/test_model_filter.py new file mode 100644 index 0000000..86f1bf1 --- /dev/null +++ b/tests/test_model_filter.py @@ -0,0 +1,33 @@ +"""Test model filters.""" + +from dbt_score.model_filter import ModelFilter, model_filter +from dbt_score.models import Model + + +def test_basic_filter(model1, model2): + """Test basic filter testing for a specific model.""" + + @model_filter + def only_model1(model: Model) -> bool: + """Some description.""" + return model.name == "model1" + + instance = only_model1() # since the decorator returns a Type + assert instance.description == "Some description." + assert instance.evaluate(model1) + assert not instance.evaluate(model2) + + +def test_class_filter(model1, model2): + """Test basic filter using class.""" + + class OnlyModel1(ModelFilter): + description = "Some description." + + def evaluate(self, model: Model) -> bool: + return model.name == "model1" + + instance = OnlyModel1() + assert instance.description == "Some description." + assert instance.evaluate(model1) + assert not instance.evaluate(model2) diff --git a/tests/test_rule_registry.py b/tests/test_rule_registry.py index aeef2da..4e7abb2 100644 --- a/tests/test_rule_registry.py +++ b/tests/test_rule_registry.py @@ -15,6 +15,7 @@ def test_rule_registry_discovery(default_config): "tests.rules.example.rule_test_example", "tests.rules.nested.example.rule_test_nested_example", ] + assert list(r._model_filters.keys()) == ["tests.rules.example.skip_model1"] def test_disabled_rule_registry_discovery(): @@ -52,3 +53,15 @@ def test_rule_registry_core_rules(default_config): r = RuleRegistry(default_config) r.load_all() assert len(r.rules) > 0 + + +def test_rule_registry_model_filters(valid_config_path, model1, model2): + """Test config filters are loaded.""" + config = Config() + config._load_toml_file(str(valid_config_path)) + r = RuleRegistry(config) + r._load("tests.rules") + r._load_filters_into_rules() + + assert not r.rules["tests.rules.example.rule_test_example"].should_evaluate(model1) + assert r.rules["tests.rules.example.rule_test_example"].should_evaluate(model2)