Skip to content

Commit

Permalink
Addresses all comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
theyostalservice committed Oct 31, 2024
1 parent ad329c5 commit b33ac7e
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 110 deletions.
70 changes: 48 additions & 22 deletions dbt_semantic_interfaces/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity
from dbt_semantic_interfaces.validations.validator_helpers import (
SemanticManifestValidationResults,
ValidationIssue,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -174,37 +175,62 @@ def semantic_model_with_guaranteed_meta(
)


def check_only_one_error_with_message( # noqa: D
results: SemanticManifestValidationResults, target_message: str
def _assert_expected_validation_message( # noqa: D
issues: Sequence[ValidationIssue],
message_fragment: str,
) -> None:
assert len(results.warnings) == 0
assert len(results.errors) == 1
assert len(results.future_errors) == 0

found_match = results.errors[0].message.find(target_message) != -1
found_match = any([issue.message.find(message_fragment) != -1 for issue in issues])
# Adding this dict to the assert so that when it does not match, pytest prints the expected and actual values.
assert {
"expected": target_message,
"actual": results.errors[0].message,
"expected": message_fragment,
"actual_messages": [issue.message for issue in issues],
} and found_match


def check_only_one_warning_with_message( # noqa: D
def check_expected_issues(  # noqa: D
    results: SemanticManifestValidationResults,
    num_expected_errors: int = 0,
    num_expected_warnings: int = 0,
    expected_error_msgs: Sequence[str] = (),
    expected_warning_msgs: Sequence[str] = (),
) -> None:
    """Validate the number, type, and content of ValidationIssues.

    Currently assumes zero future_errors as there are no future_errors
    implemented, but this function can be expanded to cover those if needed.

    Args:
        results: validation results whose errors, warnings, and future_errors are inspected.
        num_expected_errors: exact number of errors that must be present.
        num_expected_warnings: exact number of warnings that must be present.
        expected_error_msgs: message fragments, each checked against the collected errors.
        expected_warning_msgs: message fragments, each checked against the collected warnings.

    Raises:
        AssertionError: if any count differs from the expected value or a fragment check fails.
    """
    # Defaults are immutable empty tuples rather than list literals so no mutable
    # object is shared between calls.
    assert len(results.warnings) == num_expected_warnings
    assert len(results.errors) == num_expected_errors
    assert len(results.future_errors) == 0, "validation function expects zero future_errors to be implemented."

    for expected_error_msg in expected_error_msgs:
        _assert_expected_validation_message(issues=results.errors, message_fragment=expected_error_msg)
    for expected_warning_msg in expected_warning_msgs:
        _assert_expected_validation_message(issues=results.warnings, message_fragment=expected_warning_msg)


def check_only_one_error_with_message( # noqa: D
results: SemanticManifestValidationResults, target_message: str
) -> None:
assert len(results.errors) == 0
assert len(results.warnings) == 1
assert len(results.future_errors) == 0
check_expected_issues(
results=results,
num_expected_errors=1,
expected_error_msgs=[target_message],
)

found_match = results.warnings[0].message.find(target_message) != -1
# Adding this dict to the assert so that when it does not match, pytest prints the expected and actual values.
assert {
"expected": target_message,
"actual": results.warnings[0].message,
} and found_match

def check_only_one_warning_with_message(  # noqa: D
    results: SemanticManifestValidationResults, target_message: str
) -> None:
    """Assert that `results` holds exactly one warning whose message contains `target_message`.

    Delegates to `check_expected_issues`, leaving the expected error count at its default.
    """
    expected_fragments = [target_message]
    check_expected_issues(results=results, num_expected_warnings=1, expected_warning_msgs=expected_fragments)


def check_no_errors_or_warnings(results: SemanticManifestValidationResults) -> None: # noqa: D
assert len(results.errors) == 0
assert len(results.warnings) == 0
assert len(results.future_errors) == 0
# no num arguments required since all defaults are zero
check_expected_issues(
results=results,
)
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@
SemanticManifestValidationResults,
SemanticManifestValidationRule,
)
from dbt_semantic_interfaces.validations.where_filters import (
WhereFiltersAreParseableRule,
)
from dbt_semantic_interfaces.validations.where_filters import WhereFiltersAreParseable

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,7 +86,7 @@ class SemanticManifestValidator(Generic[SemanticManifestT]):
SemanticModelDefaultsRule[SemanticManifestT](),
PrimaryEntityRule[SemanticManifestT](),
PrimaryEntityDimensionPairs[SemanticManifestT](),
WhereFiltersAreParseableRule[SemanticManifestT](),
WhereFiltersAreParseable[SemanticManifestT](),
SavedQueryRule[SemanticManifestT](),
MetricLabelsRule[SemanticManifestT](),
SemanticModelLabelsRule[SemanticManifestT](),
Expand Down
144 changes: 75 additions & 69 deletions dbt_semantic_interfaces/validations/where_filters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import traceback
from enum import StrEnum, auto
from typing import Generic, List, Sequence, Tuple

from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets
Expand All @@ -20,83 +21,84 @@
)


class WhereFiltersAreParseableRule(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]):
"""Validates that all Metric WhereFilters are parseable."""
class SemanticManifestNodeType(StrEnum):
"""Types of objects to validate (used for validation messages)."""

SAVED_QUERY = auto()
METRIC = auto()


class WhereFiltersAreParseable(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]):
"""Validates that all WhereFilters are parseable."""

@staticmethod
def _validate_time_granularity_names_impl(
location_label_for_errors: str,
filter_call_parameter_sets_with_context: Sequence[Tuple[ValidationContext, FilterCallParameterSets]],
custom_granularity_names: List[str],
def _validate_time_granularity_names(
element_name: str,
object_type: SemanticManifestNodeType,
context: ValidationContext,
filter_call_param_sets: FilterCallParameterSets,
valid_granularity_names: List[str],
) -> Sequence[ValidationIssue]:
issues: List[ValidationIssue] = []

valid_granularity_names = [
standard_granularity.value for standard_granularity in TimeGranularity
] + custom_granularity_names

for context, parameter_set in filter_call_parameter_sets_with_context:
for time_dim_call_parameter_set in parameter_set.time_dimension_call_parameter_sets:
if not time_dim_call_parameter_set.time_granularity_name:
continue
if time_dim_call_parameter_set.time_granularity_name.lower() not in valid_granularity_names:
issues.append(
ValidationWarning(
context=context,
message=f"Filter for `{location_label_for_errors}` is not valid. "
f"`{time_dim_call_parameter_set.time_granularity_name}` is not a valid granularity name. "
f"Valid granularity options: {valid_granularity_names}",
)
for time_dim_call_parameter_set in filter_call_param_sets.time_dimension_call_parameter_sets:
if not time_dim_call_parameter_set.time_granularity_name:
continue
if time_dim_call_parameter_set.time_granularity_name.lower() not in valid_granularity_names:
issues.append(
ValidationWarning(
context=context,
message=f"Filter for {object_type} `{element_name}` is not valid. "
f"`{time_dim_call_parameter_set.time_granularity_name}` is not a valid granularity name. "
f"Valid granularity options: {valid_granularity_names}",
)
)
return issues

@staticmethod
def _validate_time_granularity_names_for_saved_query(
saved_query: SavedQuery, custom_granularity_names: List[str]
saved_query: SavedQuery, valid_granularity_names: List[str]
) -> Sequence[ValidationIssue]:
where_param = saved_query.query_params.where
if where_param is None:
return []

return WhereFiltersAreParseableRule._validate_time_granularity_names_impl(
location_label_for_errors="saved query `{saved_query.name}`",
filter_call_parameter_sets_with_context=[
(
SavedQueryContext(
file_context=FileContext.from_metadata(metadata=saved_query.metadata),
element_type=SavedQueryElementType.WHERE,
element_value=where_filter.where_sql_template,
),
where_filter.call_parameter_sets,
)
for where_filter in where_param.where_filters
],
custom_granularity_names=custom_granularity_names,
)
issues: List[ValidationIssue] = []
for where_filter in where_param.where_filters:
issues += WhereFiltersAreParseable._validate_time_granularity_names(
element_name=saved_query.name,
object_type=SemanticManifestNodeType.SAVED_QUERY,
context=SavedQueryContext(
file_context=FileContext.from_metadata(metadata=saved_query.metadata),
element_type=SavedQueryElementType.WHERE,
element_value=where_filter.where_sql_template,
),
filter_call_param_sets=where_filter.call_parameter_sets,
valid_granularity_names=valid_granularity_names,
)

return issues

@staticmethod
def _validate_time_granularity_names_for_metric(
context: MetricContext,
filter_expression_parameter_sets: Sequence[Tuple[str, FilterCallParameterSets]],
custom_granularity_names: List[str],
valid_granularity_names: List[str],
) -> Sequence[ValidationIssue]:
return WhereFiltersAreParseableRule._validate_time_granularity_names_impl(
location_label_for_errors="metric `{context.metric.metric_name}`",
filter_call_parameter_sets_with_context=[
(
context,
param_set[1],
)
for param_set in filter_expression_parameter_sets
],
custom_granularity_names=custom_granularity_names,
)
issues: List[ValidationIssue] = []
for _, param_set in filter_expression_parameter_sets:
issues += WhereFiltersAreParseable._validate_time_granularity_names(
element_name=context.metric.metric_name,
object_type=SemanticManifestNodeType.METRIC,
context=context,
filter_call_param_sets=param_set,
valid_granularity_names=valid_granularity_names,
)
return issues

@staticmethod
@validate_safely("validating the where field in a saved query.")
def _validate_saved_query(
saved_query: SavedQuery, custom_granularity_names: List[str]
) -> Sequence[ValidationIssue]:
def _validate_saved_query(saved_query: SavedQuery, valid_granularity_names: List[str]) -> Sequence[ValidationIssue]:
issues: List[ValidationIssue] = []
if saved_query.query_params.where is None:
return issues
Expand All @@ -119,8 +121,8 @@ def _validate_saved_query(
)
)
else:
issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_saved_query(
saved_query, custom_granularity_names
issues += WhereFiltersAreParseable._validate_time_granularity_names_for_saved_query(
saved_query, valid_granularity_names
)

return issues
Expand All @@ -129,7 +131,7 @@ def _validate_saved_query(
@validate_safely(
whats_being_done="running model validation ensuring a metric's filter properties are configured properly"
)
def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Sequence[ValidationIssue]: # noqa: D
def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequence[ValidationIssue]: # noqa: D
issues: List[ValidationIssue] = []
context = MetricContext(
file_context=FileContext.from_metadata(metadata=metric.metadata),
Expand All @@ -151,10 +153,10 @@ def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Seq
)
)
else:
issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric(
issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric(
context=context,
filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets,
custom_granularity_names=custom_granularity_names,
valid_granularity_names=valid_granularity_names,
)

if metric.type_params:
Expand All @@ -175,10 +177,10 @@ def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Seq
)
)
else:
issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric(
issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric(
context=context,
filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets,
custom_granularity_names=custom_granularity_names,
valid_granularity_names=valid_granularity_names,
)

numerator = metric.type_params.numerator
Expand All @@ -197,10 +199,10 @@ def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Seq
)
)
else:
issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric(
issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric(
context=context,
filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets,
custom_granularity_names=custom_granularity_names,
valid_granularity_names=valid_granularity_names,
)

denominator = metric.type_params.denominator
Expand All @@ -219,10 +221,10 @@ def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Seq
)
)
else:
issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric(
issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric(
context=context,
filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets,
custom_granularity_names=custom_granularity_names,
valid_granularity_names=valid_granularity_names,
)

for input_metric in metric.type_params.metrics or []:
Expand All @@ -242,10 +244,10 @@ def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Seq
)
)
else:
issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric(
issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric(
context=context,
filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets,
custom_granularity_names=custom_granularity_names,
valid_granularity_names=valid_granularity_names,
)
return issues

Expand All @@ -258,11 +260,15 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati
for time_spine in semantic_manifest.project_configuration.time_spines
for granularity in time_spine.custom_granularities
]
valid_granularity_names = [
standard_granularity.value for standard_granularity in TimeGranularity
] + custom_granularity_names

for metric in semantic_manifest.metrics or []:
issues += WhereFiltersAreParseableRule._validate_metric(
metric=metric, custom_granularity_names=custom_granularity_names
issues += WhereFiltersAreParseable._validate_metric(
metric=metric, valid_granularity_names=valid_granularity_names
)
for saved_query in semantic_manifest.saved_queries:
issues += WhereFiltersAreParseableRule._validate_saved_query(saved_query, custom_granularity_names)
issues += WhereFiltersAreParseable._validate_saved_query(saved_query, valid_granularity_names)

return issues
Loading

0 comments on commit b33ac7e

Please sign in to comment.