Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finalize validations for custom granularities #370

Merged
merged 8 commits into from
Dec 2, 2024
50 changes: 26 additions & 24 deletions dbt_semantic_interfaces/implementations/metric.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Dict, List, Optional, Sequence, Set
from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Set

from typing_extensions import override

Expand Down Expand Up @@ -83,7 +84,7 @@ def _from_yaml_value(cls, input: PydanticParseableValueType) -> PydanticMetricTi
The MetricTimeWindow is always expected to be provided as a string in user-defined YAML configs.
"""
if isinstance(input, str):
return PydanticMetricTimeWindow.parse(window=input.lower(), custom_granularity_names=(), strict=False)
return PydanticMetricTimeWindow.parse(window=input.lower())
else:
raise ValueError(
f"MetricTimeWindow inputs from model configs are expected to always be of type string, but got "
Expand All @@ -101,12 +102,8 @@ def window_string(self) -> str:
return f"{self.count} {self.granularity}"

@staticmethod
def parse(window: str, custom_granularity_names: Sequence[str], strict: bool = True) -> PydanticMetricTimeWindow:
"""Returns window values if parsing succeeds, None otherwise.

If strict=True, then the granularity in the window must exist as a valid granularity.
Use strict=True for when you have all valid granularities, otherwise use strict=False.
"""
def parse(window: str) -> PydanticMetricTimeWindow:
"""Returns window values if parsing succeeds; raises a ParsingException otherwise."""
parts = window.lower().split(" ")
if len(parts) != 2:
raise ParsingException(
Expand All @@ -115,22 +112,6 @@ def parse(window: str, custom_granularity_names: Sequence[str], strict: bool = T
)

granularity = parts[1]

valid_time_granularities = {item.value.lower() for item in TimeGranularity} | set(
c.lower() for c in custom_granularity_names
)

# if we switched to python 3.9 this could just be `granularity = parts[1].removesuffix('s')`
if granularity.endswith("s") and granularity[:-1] in valid_time_granularities:
# months -> month
granularity = granularity[:-1]

if strict and granularity not in valid_time_granularities:
raise ParsingException(
f"Invalid time granularity {granularity} in metric window string: ({window})",
)
# If not strict and not a standard granularity, it may be a custom grain, so validation happens later

count = parts[0]
if not count.isdigit():
raise ParsingException(f"Invalid count ({count}) in cumulative metric window string: ({window})")
Expand Down Expand Up @@ -222,6 +203,27 @@ def _implements_protocol(self) -> Metric: # noqa: D
config: Optional[PydanticSemanticLayerElementConfig]
time_granularity: Optional[str] = None

@classmethod
def parse_obj(cls, input: Any) -> PydanticMetric:
"""Adds custom parsing to the default method."""
data = deepcopy(input)

# Ensure grain_to_date is lowercased
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we lowercase window here too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could but I think the way it's implemented now is less hacky! This one feels less optimal because we have to go through these dict keys that are not typed. But I couldn't use the same strategy as we did for window because grain is just a string, not an object.

type_params = data.get("type_params", {})
grain_to_date = type_params.get("cumulative_type_params", {}).get("grain_to_date")
if isinstance(grain_to_date, str):
data["type_params"]["cumulative_type_params"]["grain_to_date"] = grain_to_date.lower()

# Ensure offset_to_grain is lowercased
input_metrics = type_params.get("metrics", [])
if input_metrics:
for input_metric in input_metrics:
offset_to_grain = input_metric.get("offset_to_grain")
if offset_to_grain and isinstance(offset_to_grain, str):
input_metric["offset_to_grain"] = offset_to_grain.lower()

return super(HashableBaseModel, cls).parse_obj(data)

@property
def input_measures(self) -> Sequence[PydanticMetricInputMeasure]:
"""Return the complete list of input measure configurations for this metric."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from typing import Set

from typing_extensions import override

Expand All @@ -12,7 +12,7 @@
from dbt_semantic_interfaces.transformations.transform_rule import (
SemanticManifestTransformRule,
)
from dbt_semantic_interfaces.type_enums import MetricType
from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity


class RemovePluralFromWindowGranularityRule(ProtocolHint[SemanticManifestTransformRule[PydanticSemanticManifest]]):
Expand All @@ -30,15 +30,21 @@ def _implements_protocol(self) -> SemanticManifestTransformRule[PydanticSemantic

@staticmethod
def _update_metric(
semantic_manifest: PydanticSemanticManifest, metric_name: str, custom_granularity_names: Sequence[str]
semantic_manifest: PydanticSemanticManifest, metric_name: str, custom_granularity_names: Set[str]
) -> None:
"""Mutates all the MetricTimeWindow by reparsing to remove the trailing 's'."""
valid_time_granularities = {item.value.lower() for item in TimeGranularity} | set(
c.lower() for c in custom_granularity_names
)

def reparse_window(window: PydanticMetricTimeWindow) -> PydanticMetricTimeWindow:
def trim_trailing_s(window: PydanticMetricTimeWindow) -> PydanticMetricTimeWindow:
"""Remove the trailing 's' from the window's granularity when the singular form is a valid granularity."""
return PydanticMetricTimeWindow.parse(
window=window.window_string, custom_granularity_names=custom_granularity_names
)
granularity = window.granularity
if granularity.endswith("s") and granularity[:-1] in valid_time_granularities:
# months -> month
granularity = granularity[:-1]
window.granularity = granularity
return window

matched_metric = next(
iter((metric for metric in semantic_manifest.metrics if metric.name == metric_name)), None
Expand All @@ -49,22 +55,23 @@ def reparse_window(window: PydanticMetricTimeWindow) -> PydanticMetricTimeWindow
matched_metric.type_params.cumulative_type_params
and matched_metric.type_params.cumulative_type_params.window
):
matched_metric.type_params.cumulative_type_params.window = reparse_window(
matched_metric.type_params.cumulative_type_params.window = trim_trailing_s(
matched_metric.type_params.cumulative_type_params.window
)

elif matched_metric.type is MetricType.CONVERSION:
if (
matched_metric.type_params.conversion_type_params
and matched_metric.type_params.conversion_type_params.window
):
matched_metric.type_params.conversion_type_params.window = reparse_window(
matched_metric.type_params.conversion_type_params.window = trim_trailing_s(
matched_metric.type_params.conversion_type_params.window
)

elif matched_metric.type is MetricType.DERIVED or matched_metric.type is MetricType.RATIO:
for input_metric in matched_metric.input_metrics:
if input_metric.offset_window:
input_metric.offset_window = reparse_window(input_metric.offset_window)
input_metric.offset_window = trim_trailing_s(input_metric.offset_window)
elif matched_metric.type is MetricType.SIMPLE:
pass
else:
Expand All @@ -74,11 +81,11 @@ def reparse_window(window: PydanticMetricTimeWindow) -> PydanticMetricTimeWindow

@staticmethod
def transform_model(semantic_manifest: PydanticSemanticManifest) -> PydanticSemanticManifest: # noqa: D
custom_granularity_names = [
custom_granularity_names = {
granularity.name
for time_spine in semantic_manifest.project_configuration.time_spines
for granularity in time_spine.custom_granularities
]
}

for metric in semantic_manifest.metrics:
RemovePluralFromWindowGranularityRule._update_metric(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati

@staticmethod
@validate_safely(whats_being_done="checking aggregation time dimension for a semantic model")
def _validate_semantic_model(semantic_model: SemanticModel) -> List[ValidationIssue]:
def _validate_semantic_model(semantic_model: SemanticModel) -> Sequence[ValidationIssue]:
issues: List[ValidationIssue] = []

for measure in semantic_model.measures:
Expand Down
14 changes: 8 additions & 6 deletions dbt_semantic_interfaces/validations/common_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _check_entity(
entity: Entity,
semantic_model: SemanticModel,
entities_to_semantic_models: Dict[EntityReference, Set[str]],
) -> List[ValidationIssue]:
) -> Sequence[ValidationIssue]:
issues: List[ValidationIssue] = []
# If the entity is in the dict and if the set of semantic models minus this semantic model is empty,
# then we warn the user that their entity will be unused in joins
Expand Down Expand Up @@ -65,15 +65,17 @@ def _check_entity(
@validate_safely(whats_being_done="running model validation warning if entities are only one one semantic model")
def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]:
"""Issues a warning for any entity that is associated with only one semantic_model."""
issues = []
issues: List[ValidationIssue] = []

entities_to_semantic_models = CommonEntitysRule._map_semantic_model_entities(semantic_manifest.semantic_models)
for semantic_model in semantic_manifest.semantic_models or []:
for entity in semantic_model.entities or []:
issues += CommonEntitysRule._check_entity(
entity=entity,
semantic_model=semantic_model,
entities_to_semantic_models=entities_to_semantic_models,
issues.extend(
CommonEntitysRule._check_entity(
entity=entity,
semantic_model=semantic_model,
entities_to_semantic_models=entities_to_semantic_models,
)
)

return issues
4 changes: 2 additions & 2 deletions dbt_semantic_interfaces/validations/dimension_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _validate_dimension(
dimension: Dimension,
time_dims_to_granularity: Dict[DimensionReference, TimeGranularity],
semantic_model: SemanticModel,
) -> List[ValidationIssue]:
) -> Sequence[ValidationIssue]:
"""Check that time dimensions that share a name and aren't primary have the same time granularity.

Args:
Expand Down Expand Up @@ -104,7 +104,7 @@ def _validate_semantic_model(
semantic_model: SemanticModel,
dimension_to_invariant: Dict[DimensionReference, DimensionInvariants],
update_invariant_dict: bool,
) -> List[ValidationIssue]:
) -> Sequence[ValidationIssue]:
"""Checks that the given semantic model has dimensions consistent with the given invariants.

Args:
Expand Down
7 changes: 2 additions & 5 deletions dbt_semantic_interfaces/validations/element_const.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from collections import defaultdict
from typing import DefaultDict, Generic, List, Sequence

from dbt_semantic_interfaces.implementations.semantic_manifest import (
PydanticSemanticManifest,
)
from dbt_semantic_interfaces.protocols import SemanticManifestT
from dbt_semantic_interfaces.references import SemanticModelReference
from dbt_semantic_interfaces.validations.validator_helpers import (
Expand All @@ -28,7 +25,7 @@ class ElementConsistencyRule(SemanticManifestValidationRule[SemanticManifestT],

@staticmethod
@validate_safely(whats_being_done="running model validation ensuring model wide element consistency")
def validate_manifest(semantic_manifest: PydanticSemanticManifest) -> Sequence[ValidationIssue]: # noqa: D
def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]: # noqa: D
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<3 this

issues = []
element_name_to_types = ElementConsistencyRule._get_element_name_to_types(semantic_manifest=semantic_manifest)
invalid_elements = {
Expand All @@ -54,7 +51,7 @@ def validate_manifest(semantic_manifest: PydanticSemanticManifest) -> Sequence[V

@staticmethod
def _get_element_name_to_types(
semantic_manifest: PydanticSemanticManifest,
semantic_manifest: SemanticManifestT,
) -> DefaultDict[str, DefaultDict[SemanticModelElementType, List[SemanticModelContext]]]:
"""Create a mapping of element names in the semantic manifest to types with a list of associated contexts."""
element_types: DefaultDict[
Expand Down
2 changes: 1 addition & 1 deletion dbt_semantic_interfaces/validations/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class NaturalEntityConfigurationRule(SemanticManifestValidationRule[SemanticMani
"natural entities are used in the appropriate contexts"
)
)
def _validate_semantic_model_natural_entities(semantic_model: SemanticModel) -> List[ValidationIssue]:
def _validate_semantic_model_natural_entities(semantic_model: SemanticModel) -> Sequence[ValidationIssue]:
issues: List[ValidationIssue] = []
context = SemanticModelContext(
file_context=FileContext.from_metadata(metadata=semantic_model.metadata),
Expand Down
4 changes: 2 additions & 2 deletions dbt_semantic_interfaces/validations/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class MeasureConstraintAliasesRule(SemanticManifestValidationRule[SemanticManife

@staticmethod
@validate_safely(whats_being_done="ensuring measures aliases are set when required")
def _validate_required_aliases_are_set(metric: Metric, metric_context: MetricContext) -> List[ValidationIssue]:
def _validate_required_aliases_are_set(metric: Metric, metric_context: MetricContext) -> Sequence[ValidationIssue]:
"""Checks if valid aliases are set on the input measure references where they are required.

Aliases are required whenever there are 2 or more input measures with the same measure
Expand Down Expand Up @@ -188,7 +188,7 @@ class MetricMeasuresRule(SemanticManifestValidationRule[SemanticManifestT], Gene

@staticmethod
@validate_safely(whats_being_done="checking all measures referenced by the metric exist")
def _validate_metric_measure_references(metric: Metric, valid_measure_names: Set[str]) -> List[ValidationIssue]:
def _validate_metric_measure_references(metric: Metric, valid_measure_names: Set[str]) -> Sequence[ValidationIssue]:
issues: List[ValidationIssue] = []

for measure_reference in metric.measure_references:
Expand Down
Loading
Loading