From 3baf371735a6f1925b42cb73f58fd50d7df0445e Mon Sep 17 00:00:00 2001 From: Patrick Yost Date: Wed, 23 Oct 2024 17:26:44 -0700 Subject: [PATCH 1/8] Validate time spines in saved query where filters Validations for time spines in WHERE filters in Saved Queries. This mimics the where filter time spine validation for metrics and applies it to saved queries. This also bumps the version. (I assume this is necessary?) --- .../validations/saved_query.py | 41 +++++++++++++++++- tests/validations/test_saved_query.py | 43 +++++++++++++++++++ 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/dbt_semantic_interfaces/validations/saved_query.py b/dbt_semantic_interfaces/validations/saved_query.py index 7d3716bb..fe23c68d 100644 --- a/dbt_semantic_interfaces/validations/saved_query.py +++ b/dbt_semantic_interfaces/validations/saved_query.py @@ -23,6 +23,7 @@ ) from dbt_semantic_interfaces.protocols import SemanticManifestT from dbt_semantic_interfaces.protocols.saved_query import SavedQuery +from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from dbt_semantic_interfaces.validations.validator_helpers import ( FileContext, SavedQueryContext, @@ -30,6 +31,7 @@ SemanticManifestValidationRule, ValidationError, ValidationIssue, + ValidationWarning, generate_exception_issue, validate_safely, ) @@ -114,7 +116,7 @@ def _check_metrics(valid_metric_names: Set[str], saved_query: SavedQuery) -> Seq @staticmethod @validate_safely("Validate the where field in a saved query.") - def _check_where(saved_query: SavedQuery) -> Sequence[ValidationIssue]: + def _check_where(saved_query: SavedQuery, custom_granularity_names: list[str]) -> Sequence[ValidationIssue]: issues: List[ValidationIssue] = [] if saved_query.query_params.where is None: return issues @@ -136,9 +138,39 @@ def _check_where(saved_query: SavedQuery) -> Sequence[ValidationIssue]: }, ) ) + else: + issues += SavedQueryRule._check_where_timespine(saved_query, custom_granularity_names) return issues + def _check_where_timespine( + saved_query: SavedQuery, custom_granularity_names: list[str] + ) -> Sequence[ValidationIssue]: + issues: List[ValidationIssue] = [] + + valid_granularity_names = [ + standard_granularity.name for standard_granularity in TimeGranularity + ] + custom_granularity_names + for where_filter in saved_query.query_params.where.where_filters: + for time_dim_call_parameter_set in where_filter.call_parameter_sets.time_dimension_call_parameter_sets: + if not time_dim_call_parameter_set.time_granularity_name: + continue + if time_dim_call_parameter_set.time_granularity_name not in valid_granularity_names: + issues.append( + ValidationWarning( + context=SavedQueryContext( + file_context=FileContext.from_metadata(metadata=saved_query.metadata), + element_type=SavedQueryElementType.WHERE, + element_value=where_filter.where_sql_template, + ), + # message=f"Filter for metric `{context.metric.metric_name}` is not valid. " + message=f"Filter for saved query `{saved_query.name}` is not valid. " + f"`{time_dim_call_parameter_set.time_granularity_name}` is not a valid granularity name. " + f"Valid granularity options: {valid_granularity_names}", + ) + ) + return issues + @staticmethod def _parse_query_item( saved_query: SavedQuery, @@ -280,6 +312,11 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati for entity in semantic_model.entities: valid_group_by_element_names.add(entity.name) + custom_granularity_names = [ + granularity.name + for time_spine in semantic_manifest.project_configuration.time_spines + for granularity in time_spine.custom_granularities + ] for saved_query in semantic_manifest.saved_queries: issues += SavedQueryRule._check_metrics( valid_metric_names=valid_metric_names, @@ -289,7 +326,7 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati valid_group_by_element_names=valid_group_by_element_names, saved_query=saved_query, ) - issues += SavedQueryRule._check_where(saved_query) + issues += SavedQueryRule._check_where(saved_query, custom_granularity_names) issues += SavedQueryRule._check_order_by(saved_query) issues += SavedQueryRule._check_limit(saved_query) return issues diff --git a/tests/validations/test_saved_query.py b/tests/validations/test_saved_query.py index e07288e5..56f50dea 100644 --- a/tests/validations/test_saved_query.py +++ b/tests/validations/test_saved_query.py @@ -38,6 +38,21 @@ def check_only_one_error_with_message( # noqa: D } and found_match +def check_only_one_warning_with_message( # noqa: D + results: SemanticManifestValidationResults, target_message: str +) -> None: + assert len(results.warnings) == 1 + assert len(results.errors) == 0 + assert len(results.future_errors) == 0 + + found_match = results.warnings[0].message.find(target_message) != -1 + # Adding this dict to the assert so that when it does not match, pytest prints the expected and actual values. + assert { + "expected": target_message, + "actual": results.warnings[0].message, + } and found_match + + def test_invalid_metric_in_saved_query( # noqa: D simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, ) -> None: @@ -87,6 +102,34 @@ def test_invalid_where_in_saved_query( # noqa: D ) +def test_where_filter_validations_invalid_granularity( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + manifest.saved_queries = [ + PydanticSavedQuery( + name="Example Saved Query", + description="Example description.", + query_params=PydanticSavedQueryQueryParams( + metrics=["bookings"], + group_by=["Dimension('booking__is_instant')"], + where=PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'cool') }}"), + ] + ), + ), + ), + ] + + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([SavedQueryRule()]) + check_only_one_warning_with_message( + manifest_validator.validate_semantic_manifest(manifest), + "is not a valid granularity name", + ) + + def test_invalid_group_by_element_in_saved_query( # noqa: D simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, ) -> None: From fe81c4212a7f93d3e131f1ff570db3d05accb599 Mon Sep 17 00:00:00 2001 From: Patrick Yost Date: Wed, 23 Oct 2024 17:43:51 -0700 Subject: [PATCH 2/8] fix lint --- dbt_semantic_interfaces/validations/saved_query.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dbt_semantic_interfaces/validations/saved_query.py b/dbt_semantic_interfaces/validations/saved_query.py index fe23c68d..e062b39f 100644 --- a/dbt_semantic_interfaces/validations/saved_query.py +++ b/dbt_semantic_interfaces/validations/saved_query.py @@ -146,12 +146,17 @@ def _check_where(saved_query: SavedQuery, custom_granularity_names: list[str]) - def _check_where_timespine( saved_query: SavedQuery, custom_granularity_names: list[str] ) -> Sequence[ValidationIssue]: + where_param = saved_query.query_params.where + if where_param is None: + return [] + issues: List[ValidationIssue] = [] valid_granularity_names = [ standard_granularity.name for standard_granularity in TimeGranularity ] + custom_granularity_names - for where_filter in saved_query.query_params.where.where_filters: + + for where_filter in where_param.where_filters: for time_dim_call_parameter_set in where_filter.call_parameter_sets.time_dimension_call_parameter_sets: if not time_dim_call_parameter_set.time_granularity_name: continue From c5fca352d7a46239624f935ebde1d9a156dd4260 Mon Sep 17 00:00:00 2001 From: Patrick Yost Date: Wed, 23 Oct 2024 18:04:57 -0700 Subject: [PATCH 3/8] Added changie file --- .changes/unreleased/Under the Hood-20241023-180425.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Under the Hood-20241023-180425.yaml diff --git a/.changes/unreleased/Under the Hood-20241023-180425.yaml b/.changes/unreleased/Under the Hood-20241023-180425.yaml new file mode 100644 index 00000000..33ebed54 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20241023-180425.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Added validation warnings for invalid time spines in where filters of saved queries. +time: 2024-10-23T18:04:25.235887-07:00 +custom: + Author: theyostalservice + Issue: "360" From ad329c5e862c934d4259a417827e7a536f73daf8 Mon Sep 17 00:00:00 2001 From: Patrick Yost Date: Sun, 27 Oct 2024 17:44:57 -0700 Subject: [PATCH 4/8] Centralize Where filter validation + comments This is not *exactly* what was described in Courtney's comments, but I think there was a little roughness in that plan once I began implementing (but I'm happy to change course if reviewers disagree!). This commit promotes the `WhereFiltersAreParseable` to be `WhereFiltersAreParseableRule` (a free-standing 'rule' in a separate file). It was weird to be passing things in from other classes but somehow centralizing the manifest, so instead, I just moved ALL of the relevant checks here. This moves the tests for where filters to a new file specifically for this rule (again, I'm open to the idea that this would be better just divided amongst the existing tests, but they share so much conceptually that it seems nice to group them together and to have a test that is pointed 1:1 at a rule where possible). Finally, I also moved some of the test support functions (`check_only_one_error_with_message`, for example) to `dbt_semantic_interfaces.test_utils` because they seem useful and reusable. --- .../Under the Hood-20241023-180425.yaml | 2 +- dbt_semantic_interfaces/test_utils.py | 39 +++ .../validations/metrics.py | 175 +--------- .../validations/saved_query.py | 70 ---- .../semantic_manifest_validator.py | 6 +- .../validations/where_filters.py | 268 +++++++++++++++ tests/validations/test_metrics.py | 142 -------- tests/validations/test_saved_query.py | 111 +------ .../test_where_filters_are_parseable.py | 309 ++++++++++++++++++ 9 files changed, 623 insertions(+), 499 deletions(-) create mode 100644 dbt_semantic_interfaces/validations/where_filters.py create mode 100644 tests/validations/test_where_filters_are_parseable.py diff --git a/.changes/unreleased/Under the Hood-20241023-180425.yaml b/.changes/unreleased/Under the Hood-20241023-180425.yaml index 33ebed54..25b933b8 100644 --- a/.changes/unreleased/Under the Hood-20241023-180425.yaml +++ b/.changes/unreleased/Under the Hood-20241023-180425.yaml @@ -1,5 +1,5 @@ kind: Under the Hood -body: Added validation warnings for invalid time spines in where filters of saved queries. +body: Added validation warnings for invalid granularity names in where filters of saved queries. time: 2024-10-23T18:04:25.235887-07:00 custom: Author: theyostalservice diff --git a/dbt_semantic_interfaces/test_utils.py b/dbt_semantic_interfaces/test_utils.py index ced0cd21..88cdd625 100644 --- a/dbt_semantic_interfaces/test_utils.py +++ b/dbt_semantic_interfaces/test_utils.py @@ -25,6 +25,9 @@ ) from dbt_semantic_interfaces.parsing.objects import YamlConfigFile from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity +from dbt_semantic_interfaces.validations.validator_helpers import ( + SemanticManifestValidationResults, +) logger = logging.getLogger(__name__) @@ -169,3 +172,39 @@ def semantic_model_with_guaranteed_meta( dimensions=dimensions, metadata=metadata, ) + + +def check_only_one_error_with_message( # noqa: D + results: SemanticManifestValidationResults, target_message: str +) -> None: + assert len(results.warnings) == 0 + assert len(results.errors) == 1 + assert len(results.future_errors) == 0 + + found_match = results.errors[0].message.find(target_message) != -1 + # Adding this dict to the assert so that when it does not match, pytest prints the expected and actual values. + assert { + "expected": target_message, + "actual": results.errors[0].message, + } and found_match + + +def check_only_one_warning_with_message( # noqa: D + results: SemanticManifestValidationResults, target_message: str +) -> None: + assert len(results.errors) == 0 + assert len(results.warnings) == 1 + assert len(results.future_errors) == 0 + + found_match = results.warnings[0].message.find(target_message) != -1 + # Adding this dict to the assert so that when it does not match, pytest prints the expected and actual values. + assert { + "expected": target_message, + "actual": results.warnings[0].message, + } and found_match + + +def check_no_errors_or_warnings(results: SemanticManifestValidationResults) -> None: # noqa: D + assert len(results.errors) == 0 + assert len(results.warnings) == 0 + assert len(results.future_errors) == 0 diff --git a/dbt_semantic_interfaces/validations/metrics.py b/dbt_semantic_interfaces/validations/metrics.py index f0eb0dff..da9153a8 100644 --- a/dbt_semantic_interfaces/validations/metrics.py +++ b/dbt_semantic_interfaces/validations/metrics.py @@ -1,7 +1,6 @@ import traceback -from typing import Dict, Generic, List, Optional, Sequence, Tuple +from typing import Dict, Generic, List, Optional, Sequence -from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets from dbt_semantic_interfaces.errors import ParsingException from dbt_semantic_interfaces.implementations.metric import ( PydanticMetric, @@ -35,7 +34,6 @@ ValidationError, ValidationIssue, ValidationWarning, - generate_exception_issue, validate_safely, ) @@ -244,177 +242,6 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati return issues -class WhereFiltersAreParseable(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]): - """Validates that all Metric WhereFilters are parseable.""" - - @staticmethod - def _validate_time_granularity_names( - context: MetricContext, - filter_expression_parameter_sets: Sequence[Tuple[str, FilterCallParameterSets]], - custom_granularity_names: List[str], - ) -> Sequence[ValidationIssue]: - issues: List[ValidationIssue] = [] - - valid_granularity_names = [ - standard_granularity.value for standard_granularity in TimeGranularity - ] + custom_granularity_names - for _, parameter_set in filter_expression_parameter_sets: - for time_dim_call_parameter_set in parameter_set.time_dimension_call_parameter_sets: - if not time_dim_call_parameter_set.time_granularity_name: - continue - if time_dim_call_parameter_set.time_granularity_name.lower() not in valid_granularity_names: - issues.append( - ValidationWarning( - context=context, - message=f"Filter for metric `{context.metric.metric_name}` is not valid. " - f"`{time_dim_call_parameter_set.time_granularity_name}` is not a valid granularity name. " - f"Valid granularity options: {valid_granularity_names}", - ) - ) - return issues - - @staticmethod - @validate_safely( - whats_being_done="running model validation ensuring a metric's filter properties are configured properly" - ) - def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Sequence[ValidationIssue]: # noqa: D - issues: List[ValidationIssue] = [] - context = MetricContext( - file_context=FileContext.from_metadata(metadata=metric.metadata), - metric=MetricModelReference(metric_name=metric.name), - ) - - if metric.filter is not None: - try: - metric.filter.filter_expression_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse filter of metric `{metric.name}`", - e=e, - context=context, - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - else: - issues += WhereFiltersAreParseable._validate_time_granularity_names( - context=context, - filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, - ) - - if metric.type_params: - measure = metric.type_params.measure - if measure is not None and measure.filter is not None: - try: - measure.filter.filter_expression_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse filter of measure input `{measure.name}` " - f"on metric `{metric.name}`", - e=e, - context=context, - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - else: - issues += WhereFiltersAreParseable._validate_time_granularity_names( - context=context, - filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, - ) - - numerator = metric.type_params.numerator - if numerator is not None and numerator.filter is not None: - try: - numerator.filter.filter_expression_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse the numerator filter on metric `{metric.name}`", - e=e, - context=context, - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - else: - issues += WhereFiltersAreParseable._validate_time_granularity_names( - context=context, - filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, - ) - - denominator = metric.type_params.denominator - if denominator is not None and denominator.filter is not None: - try: - denominator.filter.filter_expression_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse the denominator filter on metric `{metric.name}`", - e=e, - context=context, - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - else: - issues += WhereFiltersAreParseable._validate_time_granularity_names( - context=context, - filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, - ) - - for input_metric in metric.type_params.metrics or []: - if input_metric.filter is not None: - try: - input_metric.filter.filter_expression_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse filter for input metric `{input_metric.name}` " - f"on metric `{metric.name}`", - e=e, - context=context, - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - else: - issues += WhereFiltersAreParseable._validate_time_granularity_names( - context=context, - filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, - ) - - # TODO: Are saved query filters being validated? Task: SL-2932 - return issues - - @staticmethod - @validate_safely(whats_being_done="running manifest validation ensuring all metric where filters are parseable") - def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]: # noqa: D - issues: List[ValidationIssue] = [] - custom_granularity_names = [ - granularity.name - for time_spine in semantic_manifest.project_configuration.time_spines - for granularity in time_spine.custom_granularities - ] - for metric in semantic_manifest.metrics or []: - issues += WhereFiltersAreParseable._validate_metric( - metric=metric, custom_granularity_names=custom_granularity_names - ) - return issues - - class ConversionMetricRule(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]): """Checks that conversion metrics are configured properly.""" diff --git a/dbt_semantic_interfaces/validations/saved_query.py b/dbt_semantic_interfaces/validations/saved_query.py index e062b39f..c805aa8f 100644 --- a/dbt_semantic_interfaces/validations/saved_query.py +++ b/dbt_semantic_interfaces/validations/saved_query.py @@ -23,7 +23,6 @@ ) from dbt_semantic_interfaces.protocols import SemanticManifestT from dbt_semantic_interfaces.protocols.saved_query import SavedQuery -from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from dbt_semantic_interfaces.validations.validator_helpers import ( FileContext, SavedQueryContext, @@ -31,7 +30,6 @@ SemanticManifestValidationRule, ValidationError, ValidationIssue, - ValidationWarning, generate_exception_issue, validate_safely, ) @@ -114,68 +112,6 @@ def _check_metrics(valid_metric_names: Set[str], saved_query: SavedQuery) -> Seq ) return issues - @staticmethod - @validate_safely("Validate the where field in a saved query.") - def _check_where(saved_query: SavedQuery, custom_granularity_names: list[str]) -> Sequence[ValidationIssue]: - issues: List[ValidationIssue] = [] - if saved_query.query_params.where is None: - return issues - for where_filter in saved_query.query_params.where.where_filters: - try: - where_filter.call_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse a filter in saved query `{saved_query.name}`", - e=e, - context=SavedQueryContext( - file_context=FileContext.from_metadata(metadata=saved_query.metadata), - element_type=SavedQueryElementType.WHERE, - element_value=where_filter.where_sql_template, - ), - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - else: - issues += SavedQueryRule._check_where_timespine(saved_query, custom_granularity_names) - - return issues - - def _check_where_timespine( - saved_query: SavedQuery, custom_granularity_names: list[str] - ) -> Sequence[ValidationIssue]: - where_param = saved_query.query_params.where - if where_param is None: - return [] - - issues: List[ValidationIssue] = [] - - valid_granularity_names = [ - standard_granularity.name for standard_granularity in TimeGranularity - ] + custom_granularity_names - - for where_filter in where_param.where_filters: - for time_dim_call_parameter_set in where_filter.call_parameter_sets.time_dimension_call_parameter_sets: - if not time_dim_call_parameter_set.time_granularity_name: - continue - if time_dim_call_parameter_set.time_granularity_name not in valid_granularity_names: - issues.append( - ValidationWarning( - context=SavedQueryContext( - file_context=FileContext.from_metadata(metadata=saved_query.metadata), - element_type=SavedQueryElementType.WHERE, - element_value=where_filter.where_sql_template, - ), - # message=f"Filter for metric `{context.metric.metric_name}` is not valid. " - message=f"Filter for saved query `{saved_query.name}` is not valid. " - f"`{time_dim_call_parameter_set.time_granularity_name}` is not a valid granularity name. " - f"Valid granularity options: {valid_granularity_names}", - ) - ) - return issues - @staticmethod def _parse_query_item( saved_query: SavedQuery, @@ -317,11 +253,6 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati for entity in semantic_model.entities: valid_group_by_element_names.add(entity.name) - custom_granularity_names = [ - granularity.name - for time_spine in semantic_manifest.project_configuration.time_spines - for granularity in time_spine.custom_granularities - ] for saved_query in semantic_manifest.saved_queries: issues += SavedQueryRule._check_metrics( valid_metric_names=valid_metric_names, @@ -331,7 +262,6 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati valid_group_by_element_names=valid_group_by_element_names, saved_query=saved_query, ) - issues += SavedQueryRule._check_where(saved_query, custom_granularity_names) issues += SavedQueryRule._check_order_by(saved_query) issues += SavedQueryRule._check_limit(saved_query) return issues diff --git a/dbt_semantic_interfaces/validations/semantic_manifest_validator.py b/dbt_semantic_interfaces/validations/semantic_manifest_validator.py index 26f81452..cfad51ed 100644 --- a/dbt_semantic_interfaces/validations/semantic_manifest_validator.py +++ b/dbt_semantic_interfaces/validations/semantic_manifest_validator.py @@ -27,7 +27,6 @@ ConversionMetricRule, CumulativeMetricRule, DerivedMetricRule, - WhereFiltersAreParseable, ) from dbt_semantic_interfaces.validations.non_empty import NonEmptyRule from dbt_semantic_interfaces.validations.primary_entity import PrimaryEntityRule @@ -47,6 +46,9 @@ SemanticManifestValidationResults, SemanticManifestValidationRule, ) +from dbt_semantic_interfaces.validations.where_filters import ( + WhereFiltersAreParseableRule, +) logger = logging.getLogger(__name__) @@ -86,7 +88,7 @@ class SemanticManifestValidator(Generic[SemanticManifestT]): SemanticModelDefaultsRule[SemanticManifestT](), PrimaryEntityRule[SemanticManifestT](), PrimaryEntityDimensionPairs[SemanticManifestT](), - WhereFiltersAreParseable[SemanticManifestT](), + WhereFiltersAreParseableRule[SemanticManifestT](), SavedQueryRule[SemanticManifestT](), MetricLabelsRule[SemanticManifestT](), SemanticModelLabelsRule[SemanticManifestT](), diff --git a/dbt_semantic_interfaces/validations/where_filters.py b/dbt_semantic_interfaces/validations/where_filters.py new file mode 100644 index 00000000..333be457 --- /dev/null +++ b/dbt_semantic_interfaces/validations/where_filters.py @@ -0,0 +1,268 @@ +import traceback +from typing import Generic, List, Sequence, Tuple + +from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets +from dbt_semantic_interfaces.protocols import Metric, SemanticManifestT +from dbt_semantic_interfaces.protocols.saved_query import SavedQuery +from dbt_semantic_interfaces.references import MetricModelReference +from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.validations.validator_helpers import ( + FileContext, + MetricContext, + SavedQueryContext, + SavedQueryElementType, + SemanticManifestValidationRule, + ValidationContext, + ValidationIssue, + ValidationWarning, + generate_exception_issue, + validate_safely, +) + + +class WhereFiltersAreParseableRule(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]): + """Validates that all Metric WhereFilters are parseable.""" + + @staticmethod + def _validate_time_granularity_names_impl( + location_label_for_errors: str, + filter_call_parameter_sets_with_context: Sequence[Tuple[ValidationContext, FilterCallParameterSets]], + custom_granularity_names: List[str], + ) -> Sequence[ValidationIssue]: + issues: List[ValidationIssue] = [] + + valid_granularity_names = [ + standard_granularity.value for standard_granularity in TimeGranularity + ] + custom_granularity_names + + for context, parameter_set in filter_call_parameter_sets_with_context: + for time_dim_call_parameter_set in parameter_set.time_dimension_call_parameter_sets: + if not time_dim_call_parameter_set.time_granularity_name: + continue + if time_dim_call_parameter_set.time_granularity_name.lower() not in valid_granularity_names: + issues.append( + ValidationWarning( + context=context, + message=f"Filter for `{location_label_for_errors}` is not valid. " + f"`{time_dim_call_parameter_set.time_granularity_name}` is not a valid granularity name. " + f"Valid granularity options: {valid_granularity_names}", + ) + ) + return issues + + @staticmethod + def _validate_time_granularity_names_for_saved_query( + saved_query: SavedQuery, custom_granularity_names: List[str] + ) -> Sequence[ValidationIssue]: + where_param = saved_query.query_params.where + if where_param is None: + return [] + + return WhereFiltersAreParseableRule._validate_time_granularity_names_impl( + location_label_for_errors="saved query `{saved_query.name}`", + filter_call_parameter_sets_with_context=[ + ( + SavedQueryContext( + file_context=FileContext.from_metadata(metadata=saved_query.metadata), + element_type=SavedQueryElementType.WHERE, + element_value=where_filter.where_sql_template, + ), + where_filter.call_parameter_sets, + ) + for where_filter in where_param.where_filters + ], + custom_granularity_names=custom_granularity_names, + ) + + @staticmethod + def _validate_time_granularity_names_for_metric( + context: MetricContext, + filter_expression_parameter_sets: Sequence[Tuple[str, FilterCallParameterSets]], + custom_granularity_names: List[str], + ) -> Sequence[ValidationIssue]: + return WhereFiltersAreParseableRule._validate_time_granularity_names_impl( + location_label_for_errors="metric `{context.metric.metric_name}`", + filter_call_parameter_sets_with_context=[ + ( + context, + param_set[1], + ) + for param_set in filter_expression_parameter_sets + ], + custom_granularity_names=custom_granularity_names, + ) + + @staticmethod + @validate_safely("validating the where field in a saved query.") + def _validate_saved_query( + saved_query: SavedQuery, custom_granularity_names: List[str] + ) -> Sequence[ValidationIssue]: + issues: List[ValidationIssue] = [] + if saved_query.query_params.where is None: + return issues + for where_filter in saved_query.query_params.where.where_filters: + try: + where_filter.call_parameter_sets + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse a filter in saved query `{saved_query.name}`", + e=e, + context=SavedQueryContext( + file_context=FileContext.from_metadata(metadata=saved_query.metadata), + element_type=SavedQueryElementType.WHERE, + element_value=where_filter.where_sql_template, + ), + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_saved_query( + saved_query, custom_granularity_names + ) + + return issues + + @staticmethod + @validate_safely( + whats_being_done="running model validation ensuring a metric's filter properties are configured properly" + ) + def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Sequence[ValidationIssue]: # noqa: D + issues: List[ValidationIssue] = [] + context = MetricContext( + file_context=FileContext.from_metadata(metadata=metric.metadata), + metric=MetricModelReference(metric_name=metric.name), + ) + + if metric.filter is not None: + try: + metric.filter.filter_expression_parameter_sets + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse filter of metric `{metric.name}`", + e=e, + context=context, + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric( + context=context, + filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets, + custom_granularity_names=custom_granularity_names, + ) + + if metric.type_params: + measure = metric.type_params.measure + if measure is not None and measure.filter is not None: + try: + measure.filter.filter_expression_parameter_sets + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse filter of measure input `{measure.name}` " + f"on metric `{metric.name}`", + e=e, + context=context, + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric( + context=context, + filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets, + custom_granularity_names=custom_granularity_names, + ) + + numerator = metric.type_params.numerator + if numerator is not None and numerator.filter is not None: + try: + numerator.filter.filter_expression_parameter_sets + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse the numerator filter on metric `{metric.name}`", + e=e, + context=context, + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric( + context=context, + filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets, + custom_granularity_names=custom_granularity_names, + ) + + denominator = metric.type_params.denominator + if denominator is not None and denominator.filter is not None: + try: + denominator.filter.filter_expression_parameter_sets + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse the denominator filter on metric `{metric.name}`", + e=e, + context=context, + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric( + context=context, + filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets, + custom_granularity_names=custom_granularity_names, + ) + + for input_metric in metric.type_params.metrics or []: + if input_metric.filter is not None: + try: + input_metric.filter.filter_expression_parameter_sets + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse filter for input metric `{input_metric.name}` " + f"on metric `{metric.name}`", + e=e, + context=context, + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric( + context=context, + filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets, + custom_granularity_names=custom_granularity_names, + ) + return issues + + @staticmethod + @validate_safely(whats_being_done="running manifest validation ensuring all metric where filters are parseable") + def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]: # noqa: D + issues: List[ValidationIssue] = [] + custom_granularity_names = [ + granularity.name + for time_spine in semantic_manifest.project_configuration.time_spines + for granularity in time_spine.custom_granularities + ] + for metric in semantic_manifest.metrics or []: + issues += WhereFiltersAreParseableRule._validate_metric( + metric=metric, custom_granularity_names=custom_granularity_names + ) + for saved_query in semantic_manifest.saved_queries: + issues += WhereFiltersAreParseableRule._validate_saved_query(saved_query, custom_granularity_names) + + return issues diff --git a/tests/validations/test_metrics.py b/tests/validations/test_metrics.py index efafeef0..c96d23b3 100644 --- a/tests/validations/test_metrics.py +++ b/tests/validations/test_metrics.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import List, Tuple import pytest @@ -31,7 +30,6 @@ TimeDimensionReference, ) from dbt_semantic_interfaces.test_utils import ( - find_metric_with, metric_with_guaranteed_meta, semantic_model_with_guaranteed_meta, ) @@ -48,7 +46,6 @@ CumulativeMetricRule, DerivedMetricRule, MetricTimeGranularityRule, - WhereFiltersAreParseable, ) from dbt_semantic_interfaces.validations.semantic_manifest_validator import ( SemanticManifestValidator, @@ -345,145 +342,6 @@ def test_derived_metric() -> None: # noqa: D check_error_in_issues(error_substrings=expected_substrings, issues=build_issues) -def test_where_filter_validations_happy( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - results = validator.validate_semantic_manifest(simple_semantic_manifest__with_primary_transforms) - assert not results.has_blocking_issues - - -def test_where_filter_validations_bad_base_filter( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with(manifest, lambda metric: metric.filter is not None) - assert metric.filter is not None - assert len(metric.filter.where_filters) > 0 - metric.filter.where_filters[0].where_sql_template = "{{ dimension('too', 'many', 'variables', 'to', 'handle') }}" - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - with pytest.raises(SemanticManifestValidationException, match=f"trying to parse filter of metric `{metric.name}`"): - validator.checked_validations(manifest) - - -def test_where_filter_validations_bad_measure_filter( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with( - manifest, lambda metric: metric.type_params is not None and metric.type_params.measure is not None - ) - assert metric.type_params.measure is not None - metric.type_params.measure.filter = PydanticWhereFilterIntersection( - where_filters=[ - PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") - ] - ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - with pytest.raises( - SemanticManifestValidationException, - match=f"trying to parse filter of measure input `{metric.type_params.measure.name}` on metric `{metric.name}`", - ): - validator.checked_validations(manifest) - - -def test_where_filter_validations_bad_numerator_filter( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with( - manifest, lambda metric: metric.type_params is not None and metric.type_params.numerator is not None - ) - assert metric.type_params.numerator is not None - metric.type_params.numerator.filter = PydanticWhereFilterIntersection( - where_filters=[ - PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") - ] - ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - with pytest.raises( - SemanticManifestValidationException, match=f"trying to parse the numerator filter on metric `{metric.name}`" - ): - validator.checked_validations(manifest) - - -def test_where_filter_validations_bad_denominator_filter( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with( - manifest, lambda metric: metric.type_params is not None and metric.type_params.denominator is not None - ) - assert metric.type_params.denominator is not None - metric.type_params.denominator.filter = PydanticWhereFilterIntersection( - where_filters=[ - PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") - ] - ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - with pytest.raises( - SemanticManifestValidationException, match=f"trying to parse the denominator filter on metric `{metric.name}`" - ): - validator.checked_validations(manifest) - - -def test_where_filter_validations_bad_input_metric_filter( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with( - manifest, - lambda metric: metric.type_params is not None - and metric.type_params.metrics is not None - and len(metric.type_params.metrics) > 0, - ) - assert metric.type_params.metrics is not None - input_metric = metric.type_params.metrics[0] - input_metric.filter = PydanticWhereFilterIntersection( - where_filters=[ - PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") - ] - ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - with pytest.raises( - SemanticManifestValidationException, - match=f"trying to parse filter for input metric `{input_metric.name}` on metric `{metric.name}`", - ): - validator.checked_validations(manifest) - - -def test_where_filter_validations_invalid_granularity( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with( - manifest, - lambda metric: metric.type_params is not None - and metric.type_params.metrics is not None - and len(metric.type_params.metrics) > 0, - ) - assert metric.type_params.metrics is not None - input_metric = metric.type_params.metrics[0] - input_metric.filter = PydanticWhereFilterIntersection( - where_filters=[ - PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'cool') }}"), - PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'month') }}"), - PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'MONTH') }}"), - ] - ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - issues = validator.validate_semantic_manifest(manifest) - assert not issues.has_blocking_issues - assert len(issues.warnings) == 1 - assert "`cool` is not a valid granularity name" in issues.warnings[0].message - - def test_conversion_metrics() -> None: # noqa: D base_measure_name = "base_measure" conversion_measure_name = "conversion_measure" diff --git a/tests/validations/test_saved_query.py b/tests/validations/test_saved_query.py index 56f50dea..566407f9 100644 --- a/tests/validations/test_saved_query.py +++ b/tests/validations/test_saved_query.py @@ -12,47 +12,15 @@ from dbt_semantic_interfaces.implementations.semantic_manifest import ( PydanticSemanticManifest, ) +from dbt_semantic_interfaces.test_utils import check_only_one_error_with_message from dbt_semantic_interfaces.validations.saved_query import SavedQueryRule from dbt_semantic_interfaces.validations.semantic_manifest_validator import ( SemanticManifestValidator, ) -from dbt_semantic_interfaces.validations.validator_helpers import ( - SemanticManifestValidationResults, -) logger = logging.getLogger(__name__) -def check_only_one_error_with_message( # noqa: D - results: SemanticManifestValidationResults, target_message: str -) -> None: - assert len(results.warnings) == 0 - assert len(results.errors) == 1 - assert len(results.future_errors) == 0 - - found_match = results.errors[0].message.find(target_message) != -1 - # Adding this dict to the assert so that when it does not match, pytest prints the expected and actual values. - assert { - "expected": target_message, - "actual": results.errors[0].message, - } and found_match - - -def check_only_one_warning_with_message( # noqa: D - results: SemanticManifestValidationResults, target_message: str -) -> None: - assert len(results.warnings) == 1 - assert len(results.errors) == 0 - assert len(results.future_errors) == 0 - - found_match = results.warnings[0].message.find(target_message) != -1 - # Adding this dict to the assert so that when it does not match, pytest prints the expected and actual values. - assert { - "expected": target_message, - "actual": results.warnings[0].message, - } and found_match - - def test_invalid_metric_in_saved_query( # noqa: D simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, ) -> None: @@ -77,59 +45,6 @@ def test_invalid_metric_in_saved_query( # noqa: D ) -def test_invalid_where_in_saved_query( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) - manifest.saved_queries = [ - PydanticSavedQuery( - name="Example Saved Query", - description="Example description.", - query_params=PydanticSavedQueryQueryParams( - metrics=["bookings"], - group_by=["Dimension('booking__is_instant')"], - where=PydanticWhereFilterIntersection( - where_filters=[PydanticWhereFilter(where_sql_template="{{ invalid_jinja }}")], - ), - ), - ), - ] - - manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([SavedQueryRule()]) - check_only_one_error_with_message( - manifest_validator.validate_semantic_manifest(manifest), - "trying to parse a filter in saved query", - ) - - -def test_where_filter_validations_invalid_granularity( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) - - manifest.saved_queries = [ - PydanticSavedQuery( - name="Example Saved Query", - description="Example description.", - query_params=PydanticSavedQueryQueryParams( - metrics=["bookings"], - group_by=["Dimension('booking__is_instant')"], - where=PydanticWhereFilterIntersection( - where_filters=[ - PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'cool') }}"), - ] - ), - ), - ), - ] - - manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([SavedQueryRule()]) - check_only_one_warning_with_message( - manifest_validator.validate_semantic_manifest(manifest), - "is not a valid granularity name", - ) - - def test_invalid_group_by_element_in_saved_query( # noqa: D simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, ) -> None: @@ -180,30 +95,6 @@ def test_invalid_group_by_format_in_saved_query( # noqa: D ) -def test_metric_filter_error( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) - manifest.saved_queries = [ - PydanticSavedQuery( - name="Example Saved Query", - description="Example description.", - query_params=PydanticSavedQueryQueryParams( - metrics=["listings"], - where=PydanticWhereFilterIntersection( - where_filters=[PydanticWhereFilter(where_sql_template="{{ Metric('bookings') }} > 2")], - ), - ), - ), - ] - - manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([SavedQueryRule()]) - check_only_one_error_with_message( - manifest_validator.validate_semantic_manifest(manifest), - "An error occurred while trying to parse a filter in saved query", - ) - - def test_metric_filter_success( # noqa: D simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, ) -> None: diff --git a/tests/validations/test_where_filters_are_parseable.py b/tests/validations/test_where_filters_are_parseable.py new file mode 100644 index 00000000..15b75986 --- /dev/null +++ b/tests/validations/test_where_filters_are_parseable.py @@ -0,0 +1,309 @@ +import copy +import logging + +import pytest + +from dbt_semantic_interfaces.implementations.filters.where_filter import ( + PydanticWhereFilter, + PydanticWhereFilterIntersection, +) +from dbt_semantic_interfaces.implementations.saved_query import ( + PydanticSavedQuery, + PydanticSavedQueryQueryParams, +) +from dbt_semantic_interfaces.implementations.semantic_manifest import ( + PydanticSemanticManifest, +) +from dbt_semantic_interfaces.test_utils import ( + check_no_errors_or_warnings, + check_only_one_error_with_message, + check_only_one_warning_with_message, + find_metric_with, +) +from dbt_semantic_interfaces.validations.semantic_manifest_validator import ( + SemanticManifestValidator, +) +from dbt_semantic_interfaces.validations.validator_helpers import ( + SemanticManifestValidationException, +) +from dbt_semantic_interfaces.validations.where_filters import ( + WhereFiltersAreParseableRule, +) + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------------------ +# Metric validations +# ------------------------------------------------------------------------------ + + +def test_metric_where_filter_validations_happy( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + results = validator.validate_semantic_manifest(simple_semantic_manifest__with_primary_transforms) + assert not results.has_blocking_issues + + +def test_where_filter_validations_bad_base_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with(manifest, lambda metric: metric.filter is not None) + assert metric.filter is not None + assert len(metric.filter.where_filters) > 0 + metric.filter.where_filters[0].where_sql_template = "{{ dimension('too', 'many', 'variables', 'to', 'handle') }}" + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + with pytest.raises(SemanticManifestValidationException, match=f"trying to parse filter of metric `{metric.name}`"): + validator.checked_validations(manifest) + + +def test_where_filter_validations_bad_measure_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with( + manifest, lambda metric: metric.type_params is not None and metric.type_params.measure is not None + ) + assert metric.type_params.measure is not None + metric.type_params.measure.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] + ) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + with pytest.raises( + SemanticManifestValidationException, + match=f"trying to parse filter of measure input `{metric.type_params.measure.name}` on metric `{metric.name}`", + ): + validator.checked_validations(manifest) + + +def test_where_filter_validations_bad_numerator_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with( + manifest, lambda metric: metric.type_params is not None and metric.type_params.numerator is not None + ) + assert metric.type_params.numerator is not None + metric.type_params.numerator.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] + ) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + with pytest.raises( + SemanticManifestValidationException, match=f"trying to parse the numerator filter on metric `{metric.name}`" + ): + validator.checked_validations(manifest) + + +def test_where_filter_validations_bad_denominator_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with( + manifest, lambda metric: metric.type_params is not None and metric.type_params.denominator is not None + ) + assert metric.type_params.denominator is not None + metric.type_params.denominator.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] + ) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + with pytest.raises( + SemanticManifestValidationException, match=f"trying to parse the denominator filter on metric `{metric.name}`" + ): + validator.checked_validations(manifest) + + +def test_where_filter_validations_bad_input_metric_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with( + manifest, + lambda metric: metric.type_params is not None + and metric.type_params.metrics is not None + and len(metric.type_params.metrics) > 0, + ) + assert metric.type_params.metrics is not None + input_metric = metric.type_params.metrics[0] + input_metric.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] + ) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + with pytest.raises( + SemanticManifestValidationException, + match=f"trying to parse filter for input metric `{input_metric.name}` on metric `{metric.name}`", + ): + validator.checked_validations(manifest) + + +def test_metric_where_filter_validations_invalid_granularity( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with( + manifest, + lambda metric: metric.type_params is not None + and metric.type_params.metrics is not None + and len(metric.type_params.metrics) > 0, + ) + assert metric.type_params.metrics is not None + input_metric = metric.type_params.metrics[0] + input_metric.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'cool') }}"), + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'month') }}"), + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'MONTH') }}"), + ] + ) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + issues = validator.validate_semantic_manifest(manifest) + assert not issues.has_blocking_issues + assert len(issues.warnings) == 1 + assert "`cool` is not a valid granularity name" in issues.warnings[0].message + + +# ------------------------------------------------------------------------------ +# Saved Query validations +# ------------------------------------------------------------------------------ + + +def test_saved_query_with_happy_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + manifest.saved_queries = [ + PydanticSavedQuery( + name="Example Saved Query", + description="Example description.", + query_params=PydanticSavedQueryQueryParams( + metrics=["bookings"], + group_by=["Dimension('booking__is_instant')"], + where=PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'hour') }}"), + ] + ), + ), + ), + ] + + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + check_no_errors_or_warnings(manifest_validator.validate_semantic_manifest(manifest)) + + +def test_saved_query_validates_granularity_name_despite_case( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + manifest.saved_queries = [ + PydanticSavedQuery( + name="Example Saved Query", + description="Example description.", + query_params=PydanticSavedQueryQueryParams( + metrics=["bookings"], + group_by=["Dimension('booking__is_instant')"], + where=PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'DAY') }}"), + ] + ), + ), + ), + ] + + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + check_no_errors_or_warnings(manifest_validator.validate_semantic_manifest(manifest)) + + +def test_invalid_where_in_saved_query( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + manifest.saved_queries = [ + PydanticSavedQuery( + name="Example Saved Query", + description="Example description.", + query_params=PydanticSavedQueryQueryParams( + metrics=["bookings"], + group_by=["Dimension('booking__is_instant')"], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ invalid_jinja }}")], + ), + ), + ), + ] + + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + check_only_one_error_with_message( + manifest_validator.validate_semantic_manifest(manifest), + "trying to parse a filter in saved query", + ) + + +def test_saved_query_where_filter_validations_invalid_granularity( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + manifest.saved_queries = [ + PydanticSavedQuery( + name="Example Saved Query", + description="Example description.", + query_params=PydanticSavedQueryQueryParams( + metrics=["bookings"], + group_by=["Dimension('booking__is_instant')"], + where=PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'cool') }}"), + ] + ), + ), + ), + ] + + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + check_only_one_warning_with_message( + manifest_validator.validate_semantic_manifest(manifest), + "is not a valid granularity name", + ) + + +def test_metric_filter_error( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + manifest.saved_queries = [ + PydanticSavedQuery( + name="Example Saved Query", + description="Example description.", + query_params=PydanticSavedQueryQueryParams( + metrics=["listings"], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ Metric('bookings') }} > 2")], + ), + ), + ), + ] + + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + check_only_one_error_with_message( + manifest_validator.validate_semantic_manifest(manifest), + "An error occurred while trying to parse a filter in saved query", + ) From b33ac7e26e8782aaa5940e91e41c72a44bd6f405 Mon Sep 17 00:00:00 2001 From: Patrick Yost Date: Wed, 30 Oct 2024 18:54:09 -0700 Subject: [PATCH 5/8] Addresses all comments. --- dbt_semantic_interfaces/test_utils.py | 70 ++++++--- .../semantic_manifest_validator.py | 6 +- .../validations/where_filters.py | 144 +++++++++--------- .../test_where_filters_are_parseable.py | 28 ++-- 4 files changed, 138 insertions(+), 110 deletions(-) diff --git a/dbt_semantic_interfaces/test_utils.py b/dbt_semantic_interfaces/test_utils.py index 88cdd625..fc9a01b2 100644 --- a/dbt_semantic_interfaces/test_utils.py +++ b/dbt_semantic_interfaces/test_utils.py @@ -27,6 +27,7 @@ from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity from dbt_semantic_interfaces.validations.validator_helpers import ( SemanticManifestValidationResults, + ValidationIssue, ) logger = logging.getLogger(__name__) @@ -174,37 +175,62 @@ def semantic_model_with_guaranteed_meta( ) -def check_only_one_error_with_message( # noqa: D - results: SemanticManifestValidationResults, target_message: str +def _assert_expected_validation_message( # noqa: D + issues: Sequence[ValidationIssue], + message_fragment: str, ) -> None: - assert len(results.warnings) == 0 - assert len(results.errors) == 1 - assert len(results.future_errors) == 0 - - found_match = results.errors[0].message.find(target_message) != -1 + found_match = any([issue.message.find(message_fragment) != -1 for issue in issues]) # Adding this dict to the assert so that when it does not match, pytest prints the expected and actual values. assert { - "expected": target_message, - "actual": results.errors[0].message, + "expected": message_fragment, + "actual_messages": [issue.message for issue in issues], } and found_match -def check_only_one_warning_with_message( # noqa: D +def check_expected_issues( # noqa: D + results: SemanticManifestValidationResults, + num_expected_errors: int = 0, + num_expected_warnings: int = 0, + expected_error_msgs: Sequence[str] = [], + expected_warning_msgs: Sequence[str] = [], +) -> None: + """Validates the number, type, and content of ValidationIssues. + + Currently assumes zero future_errors as there are no future_errors + implemented, but this function can be expanded to cover those if needed. + """ + assert len(results.warnings) == num_expected_warnings + assert len(results.errors) == num_expected_errors + assert len(results.future_errors) == 0, "validation function expects zero future_errors to be implemented." + + for expected_error_msg in expected_error_msgs: + _assert_expected_validation_message(issues=results.errors, message_fragment=expected_error_msg) + for expected_warning_msg in expected_warning_msgs: + _assert_expected_validation_message(issues=results.warnings, message_fragment=expected_warning_msg) + + +def check_only_one_error_with_message( # noqa: D results: SemanticManifestValidationResults, target_message: str ) -> None: - assert len(results.errors) == 0 - assert len(results.warnings) == 1 - assert len(results.future_errors) == 0 + check_expected_issues( + results=results, + num_expected_errors=1, + expected_error_msgs=[target_message], + ) - found_match = results.warnings[0].message.find(target_message) != -1 - # Adding this dict to the assert so that when it does not match, pytest prints the expected and actual values. - assert { - "expected": target_message, - "actual": results.warnings[0].message, - } and found_match + +def check_only_one_warning_with_message( # noqa: D + results: SemanticManifestValidationResults, target_message: str +) -> None: + check_expected_issues( + results=results, + num_expected_warnings=1, + expected_warning_msgs=[target_message], + ) def check_no_errors_or_warnings(results: SemanticManifestValidationResults) -> None: # noqa: D - assert len(results.errors) == 0 - assert len(results.warnings) == 0 - assert len(results.future_errors) == 0 + # no num arguments required since all defaults are zero + check_expected_issues( + results=results, + ) diff --git a/dbt_semantic_interfaces/validations/semantic_manifest_validator.py b/dbt_semantic_interfaces/validations/semantic_manifest_validator.py index cfad51ed..44de62cf 100644 --- a/dbt_semantic_interfaces/validations/semantic_manifest_validator.py +++ b/dbt_semantic_interfaces/validations/semantic_manifest_validator.py @@ -46,9 +46,7 @@ SemanticManifestValidationResults, SemanticManifestValidationRule, ) -from dbt_semantic_interfaces.validations.where_filters import ( - WhereFiltersAreParseableRule, -) +from dbt_semantic_interfaces.validations.where_filters import WhereFiltersAreParseable logger = logging.getLogger(__name__) @@ -88,7 +86,7 @@ class SemanticManifestValidator(Generic[SemanticManifestT]): SemanticModelDefaultsRule[SemanticManifestT](), PrimaryEntityRule[SemanticManifestT](), PrimaryEntityDimensionPairs[SemanticManifestT](), - WhereFiltersAreParseableRule[SemanticManifestT](), + WhereFiltersAreParseable[SemanticManifestT](), SavedQueryRule[SemanticManifestT](), MetricLabelsRule[SemanticManifestT](), SemanticModelLabelsRule[SemanticManifestT](), diff --git a/dbt_semantic_interfaces/validations/where_filters.py b/dbt_semantic_interfaces/validations/where_filters.py index 333be457..01ba5364 100644 --- a/dbt_semantic_interfaces/validations/where_filters.py +++ b/dbt_semantic_interfaces/validations/where_filters.py @@ -1,4 +1,5 @@ import traceback +from enum import StrEnum, auto from typing import Generic, List, Sequence, Tuple from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets @@ -20,83 +21,84 @@ ) -class WhereFiltersAreParseableRule(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]): - """Validates that all Metric WhereFilters are parseable.""" +class SemanticManifestNodeType(StrEnum): + """Types of objects to validate (used for validation messages).""" + + SAVED_QUERY = auto() + METRIC = auto() + + +class WhereFiltersAreParseable(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]): + """Validates that all WhereFilters are parseable.""" @staticmethod - def _validate_time_granularity_names_impl( - location_label_for_errors: str, - filter_call_parameter_sets_with_context: Sequence[Tuple[ValidationContext, FilterCallParameterSets]], - custom_granularity_names: List[str], + def _validate_time_granularity_names( + element_name: str, + object_type: SemanticManifestNodeType, + context: ValidationContext, + filter_call_param_sets: FilterCallParameterSets, + valid_granularity_names: List[str], ) -> Sequence[ValidationIssue]: issues: List[ValidationIssue] = [] - valid_granularity_names = [ - standard_granularity.value for standard_granularity in TimeGranularity - ] + custom_granularity_names - - for context, parameter_set in filter_call_parameter_sets_with_context: - for time_dim_call_parameter_set in parameter_set.time_dimension_call_parameter_sets: - if not time_dim_call_parameter_set.time_granularity_name: - continue - if time_dim_call_parameter_set.time_granularity_name.lower() not in valid_granularity_names: - issues.append( - ValidationWarning( - context=context, - message=f"Filter for `{location_label_for_errors}` is not valid. " - f"`{time_dim_call_parameter_set.time_granularity_name}` is not a valid granularity name. " - f"Valid granularity options: {valid_granularity_names}", - ) + for time_dim_call_parameter_set in filter_call_param_sets.time_dimension_call_parameter_sets: + if not time_dim_call_parameter_set.time_granularity_name: + continue + if time_dim_call_parameter_set.time_granularity_name.lower() not in valid_granularity_names: + issues.append( + ValidationWarning( + context=context, + message=f"Filter for {object_type} `{element_name}` is not valid. " + f"`{time_dim_call_parameter_set.time_granularity_name}` is not a valid granularity name. " + f"Valid granularity options: {valid_granularity_names}", ) + ) return issues @staticmethod def _validate_time_granularity_names_for_saved_query( - saved_query: SavedQuery, custom_granularity_names: List[str] + saved_query: SavedQuery, valid_granularity_names: List[str] ) -> Sequence[ValidationIssue]: where_param = saved_query.query_params.where if where_param is None: return [] - return WhereFiltersAreParseableRule._validate_time_granularity_names_impl( - location_label_for_errors="saved query `{saved_query.name}`", - filter_call_parameter_sets_with_context=[ - ( - SavedQueryContext( - file_context=FileContext.from_metadata(metadata=saved_query.metadata), - element_type=SavedQueryElementType.WHERE, - element_value=where_filter.where_sql_template, - ), - where_filter.call_parameter_sets, - ) - for where_filter in where_param.where_filters - ], - custom_granularity_names=custom_granularity_names, - ) + issues: List[ValidationIssue] = [] + for where_filter in where_param.where_filters: + issues += WhereFiltersAreParseable._validate_time_granularity_names( + element_name=saved_query.name, + object_type=SemanticManifestNodeType.SAVED_QUERY, + context=SavedQueryContext( + file_context=FileContext.from_metadata(metadata=saved_query.metadata), + element_type=SavedQueryElementType.WHERE, + element_value=where_filter.where_sql_template, + ), + filter_call_param_sets=where_filter.call_parameter_sets, + valid_granularity_names=valid_granularity_names, + ) + + return issues @staticmethod def _validate_time_granularity_names_for_metric( context: MetricContext, filter_expression_parameter_sets: Sequence[Tuple[str, FilterCallParameterSets]], - custom_granularity_names: List[str], + valid_granularity_names: List[str], ) -> Sequence[ValidationIssue]: - return WhereFiltersAreParseableRule._validate_time_granularity_names_impl( - location_label_for_errors="metric `{context.metric.metric_name}`", - filter_call_parameter_sets_with_context=[ - ( - context, - param_set[1], - ) - for param_set in filter_expression_parameter_sets - ], - custom_granularity_names=custom_granularity_names, - ) + issues: List[ValidationIssue] = [] + for _, param_set in filter_expression_parameter_sets: + issues += WhereFiltersAreParseable._validate_time_granularity_names( + element_name=context.metric.metric_name, + object_type=SemanticManifestNodeType.METRIC, + context=context, + filter_call_param_sets=param_set, + valid_granularity_names=valid_granularity_names, + ) + return issues @staticmethod @validate_safely("validating the where field in a saved query.") - def _validate_saved_query( - saved_query: SavedQuery, custom_granularity_names: List[str] - ) -> Sequence[ValidationIssue]: + def _validate_saved_query(saved_query: SavedQuery, valid_granularity_names: List[str]) -> Sequence[ValidationIssue]: issues: List[ValidationIssue] = [] if saved_query.query_params.where is None: return issues @@ -119,8 +121,8 @@ def _validate_saved_query( ) ) else: - issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_saved_query( - saved_query, custom_granularity_names + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_saved_query( + saved_query, valid_granularity_names ) return issues @@ -129,7 +131,7 @@ def _validate_saved_query( @validate_safely( whats_being_done="running model validation ensuring a metric's filter properties are configured properly" ) - def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Sequence[ValidationIssue]: # noqa: D + def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequence[ValidationIssue]: # noqa: D issues: List[ValidationIssue] = [] context = MetricContext( file_context=FileContext.from_metadata(metadata=metric.metadata), @@ -151,10 +153,10 @@ def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Seq ) ) else: - issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric( + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, + valid_granularity_names=valid_granularity_names, ) if metric.type_params: @@ -175,10 +177,10 @@ def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Seq ) ) else: - issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric( + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, + valid_granularity_names=valid_granularity_names, ) numerator = metric.type_params.numerator @@ -197,10 +199,10 @@ def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Seq ) ) else: - issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric( + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, + valid_granularity_names=valid_granularity_names, ) denominator = metric.type_params.denominator @@ -219,10 +221,10 @@ def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Seq ) ) else: - issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric( + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, + valid_granularity_names=valid_granularity_names, ) for input_metric in metric.type_params.metrics or []: @@ -242,10 +244,10 @@ def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Seq ) ) else: - issues += WhereFiltersAreParseableRule._validate_time_granularity_names_for_metric( + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, + valid_granularity_names=valid_granularity_names, ) return issues @@ -258,11 +260,15 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati for time_spine in semantic_manifest.project_configuration.time_spines for granularity in time_spine.custom_granularities ] + valid_granularity_names = [ + standard_granularity.value for standard_granularity in TimeGranularity + ] + custom_granularity_names + for metric in semantic_manifest.metrics or []: - issues += WhereFiltersAreParseableRule._validate_metric( - metric=metric, custom_granularity_names=custom_granularity_names + issues += WhereFiltersAreParseable._validate_metric( + metric=metric, valid_granularity_names=valid_granularity_names ) for saved_query in semantic_manifest.saved_queries: - issues += WhereFiltersAreParseableRule._validate_saved_query(saved_query, custom_granularity_names) + issues += WhereFiltersAreParseable._validate_saved_query(saved_query, valid_granularity_names) return issues diff --git a/tests/validations/test_where_filters_are_parseable.py b/tests/validations/test_where_filters_are_parseable.py index 15b75986..ae948a99 100644 --- a/tests/validations/test_where_filters_are_parseable.py +++ b/tests/validations/test_where_filters_are_parseable.py @@ -26,9 +26,7 @@ from dbt_semantic_interfaces.validations.validator_helpers import ( SemanticManifestValidationException, ) -from dbt_semantic_interfaces.validations.where_filters import ( - WhereFiltersAreParseableRule, -) +from dbt_semantic_interfaces.validations.where_filters import WhereFiltersAreParseable logger = logging.getLogger(__name__) @@ -41,7 +39,7 @@ def test_metric_where_filter_validations_happy( # noqa: D simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, ) -> None: - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) results = validator.validate_semantic_manifest(simple_semantic_manifest__with_primary_transforms) assert not results.has_blocking_issues @@ -55,7 +53,7 @@ def test_where_filter_validations_bad_base_filter( # noqa: D assert metric.filter is not None assert len(metric.filter.where_filters) > 0 metric.filter.where_filters[0].where_sql_template = "{{ dimension('too', 'many', 'variables', 'to', 'handle') }}" - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) with pytest.raises(SemanticManifestValidationException, match=f"trying to parse filter of metric `{metric.name}`"): validator.checked_validations(manifest) @@ -74,7 +72,7 @@ def test_where_filter_validations_bad_measure_filter( # noqa: D PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") ] ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) with pytest.raises( SemanticManifestValidationException, match=f"trying to parse filter of measure input `{metric.type_params.measure.name}` on metric `{metric.name}`", @@ -96,7 +94,7 @@ def test_where_filter_validations_bad_numerator_filter( # noqa: D PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") ] ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) with pytest.raises( SemanticManifestValidationException, match=f"trying to parse the numerator filter on metric `{metric.name}`" ): @@ -117,7 +115,7 @@ def test_where_filter_validations_bad_denominator_filter( # noqa: D PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") ] ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) with pytest.raises( SemanticManifestValidationException, match=f"trying to parse the denominator filter on metric `{metric.name}`" ): @@ -142,7 +140,7 @@ def test_where_filter_validations_bad_input_metric_filter( # noqa: D PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") ] ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) with pytest.raises( SemanticManifestValidationException, match=f"trying to parse filter for input metric `{input_metric.name}` on metric `{metric.name}`", @@ -170,7 +168,7 @@ def test_metric_where_filter_validations_invalid_granularity( # noqa: D PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'MONTH') }}"), ] ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) issues = validator.validate_semantic_manifest(manifest) assert not issues.has_blocking_issues assert len(issues.warnings) == 1 @@ -203,7 +201,7 @@ def test_saved_query_with_happy_filter( # noqa: D ), ] - manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) check_no_errors_or_warnings(manifest_validator.validate_semantic_manifest(manifest)) @@ -228,7 +226,7 @@ def test_saved_query_validates_granularity_name_despite_case( # noqa: D ), ] - manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) check_no_errors_or_warnings(manifest_validator.validate_semantic_manifest(manifest)) @@ -250,7 +248,7 @@ def test_invalid_where_in_saved_query( # noqa: D ), ] - manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) check_only_one_error_with_message( manifest_validator.validate_semantic_manifest(manifest), "trying to parse a filter in saved query", @@ -278,7 +276,7 @@ def test_saved_query_where_filter_validations_invalid_granularity( # noqa: D ), ] - manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) check_only_one_warning_with_message( manifest_validator.validate_semantic_manifest(manifest), "is not a valid granularity name", @@ -302,7 +300,7 @@ def test_metric_filter_error( # noqa: D ), ] - manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseableRule()]) + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) check_only_one_error_with_message( manifest_validator.validate_semantic_manifest(manifest), "An error occurred while trying to parse a filter in saved query", From 80f352fcaa8cfabb99c6b65decc329fddebe3ed1 Mon Sep 17 00:00:00 2001 From: Patrick Yost Date: Fri, 1 Nov 2024 14:22:16 -0700 Subject: [PATCH 6/8] Import WhereFiltersAreParseable to metrics.py to avoid making breaking change. --- dbt_semantic_interfaces/validations/metrics.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dbt_semantic_interfaces/validations/metrics.py b/dbt_semantic_interfaces/validations/metrics.py index da9153a8..80278a51 100644 --- a/dbt_semantic_interfaces/validations/metrics.py +++ b/dbt_semantic_interfaces/validations/metrics.py @@ -37,6 +37,11 @@ validate_safely, ) +# Avoids breaking change from moving this class out of this file. +from dbt_semantic_interfaces.validations.where_filters import ( + WhereFiltersAreParseable, # noQa +) + class CumulativeMetricRule(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]): """Checks that cumulative metrics are configured properly.""" From 40bd63c5f02c609bb07bdd73fd920a911166dd72 Mon Sep 17 00:00:00 2001 From: Patrick Yost Date: Fri, 1 Nov 2024 14:27:58 -0700 Subject: [PATCH 7/8] Update SemanticManifestNodeType for compatibility with older python versions --- dbt_semantic_interfaces/validations/where_filters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dbt_semantic_interfaces/validations/where_filters.py b/dbt_semantic_interfaces/validations/where_filters.py index 01ba5364..57ebaa30 100644 --- a/dbt_semantic_interfaces/validations/where_filters.py +++ b/dbt_semantic_interfaces/validations/where_filters.py @@ -1,5 +1,5 @@ import traceback -from enum import StrEnum, auto +from enum import Enum from typing import Generic, List, Sequence, Tuple from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets @@ -21,11 +21,11 @@ ) -class SemanticManifestNodeType(StrEnum): +class SemanticManifestNodeType(Enum): """Types of objects to validate (used for validation messages).""" - SAVED_QUERY = auto() - METRIC = auto() + SAVED_QUERY = "saved query" + METRIC = "metric" class WhereFiltersAreParseable(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]): From 5083005d95bff2587a88b60f818367f197b74220 Mon Sep 17 00:00:00 2001 From: William Deng <33618746+WilliamDee@users.noreply.github.com> Date: Mon, 11 Nov 2024 18:19:23 -0500 Subject: [PATCH 8/8] Support custom grain in DSI callsites (#363) ### Description Given that custom granularity is available, we need to start supporting it when parsing through names and filters. ### Non-breaking changes - Removed `WhereFilterTimeDimensionFactory` as it's not used anywhere (MetricFlow has it's own in [metricflow-semantics](https://github.com/dbt-labs/metricflow/blob/main/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_time_dimension.py#L56)) - Added tests to test parsing of group by's using custom grain and filtering with custom grain ### Breaking changes - Removed `DunderedNameFormatter` in favour of `StructuredDunderedName` as it's duplicate logic - Added `custom_granularity_names` to `StructuredDunderedName.parse_name` as a parameter to parse out any custom grain - Updated callsites for this - Changed `PydanticWhereFilter.call_parameter_sets` and `PydanticWhereFilterIntersection.filter_expression_parameter_sets` from a property to a method that takes in the valid custom grain names - Updated callsites for this ### Checklist - [x] I have read [the contributing guide](https://github.com/dbt-labs/dbt-semantic-interfaces/blob/main/CONTRIBUTING.md) and understand what's expected of me - [x] I have signed the [CLA](https://docs.getdbt.com/docs/contributor-license-agreements) - [x] This PR includes tests, or tests are not required/relevant for this PR - [x] I have run `changie new` to [create a changelog entry](https://github.com/dbt-labs/dbt-semantic-interfaces/blob/main/CONTRIBUTING.md#adding-a-changelog-entry) Resolves SL-2989 --- .../Breaking Changes-20241105-180727.yaml | 6 ++ .../implementations/filters/where_filter.py | 21 +++-- dbt_semantic_interfaces/naming/dundered.py | 77 ++++--------------- .../parsing/text_input/ti_description.py | 7 +- .../parsing/text_input/ti_processor.py | 46 ----------- .../where_filter/parameter_set_factory.py | 15 ++-- .../where_filter/where_filter_parser.py | 6 +- .../where_filter_time_dimension.py | 57 -------------- .../protocols/where_filter.py | 8 +- .../validations/saved_query.py | 14 +++- .../validations/where_filters.py | 40 +++++++--- tests/example_project_configuration.py | 8 ++ .../project_configuration.yaml | 8 ++ .../where_filter/test_parse_calls.py | 46 +++++++---- tests/parsing/test_saved_query_parsing.py | 29 +++++++ tests/parsing/test_where_filter_parsing.py | 30 ++++++-- .../test_where_filters_are_parseable.py | 2 + 17 files changed, 198 insertions(+), 222 deletions(-) create mode 100644 .changes/unreleased/Breaking Changes-20241105-180727.yaml delete mode 100644 dbt_semantic_interfaces/parsing/where_filter/where_filter_time_dimension.py diff --git a/.changes/unreleased/Breaking Changes-20241105-180727.yaml b/.changes/unreleased/Breaking Changes-20241105-180727.yaml new file mode 100644 index 00000000..b6f58a59 --- /dev/null +++ b/.changes/unreleased/Breaking Changes-20241105-180727.yaml @@ -0,0 +1,6 @@ +kind: Breaking Changes +body: Update PydanticWhereFilter.call_parameter_sets and PydanticWhereFilterIntersection.filter_expression_parameter_sets from property to a method +time: 2024-11-05T18:07:27.325103-05:00 +custom: + Author: WilliamDee + Issue: None diff --git a/dbt_semantic_interfaces/implementations/filters/where_filter.py b/dbt_semantic_interfaces/implementations/filters/where_filter.py index 96a74581..f49e237f 100644 --- a/dbt_semantic_interfaces/implementations/filters/where_filter.py +++ b/dbt_semantic_interfaces/implementations/filters/where_filter.py @@ -2,7 +2,7 @@ import textwrap import traceback -from typing import Callable, Generator, List, Tuple +from typing import Callable, Generator, List, Sequence, Tuple from typing_extensions import Self @@ -49,9 +49,10 @@ def _from_yaml_value( else: raise ValueError(f"Expected input to be of type string, but got type {type(input)} with value: {input}") - @property - def call_parameter_sets(self) -> FilterCallParameterSets: # noqa: D - return WhereFilterParser.parse_call_parameter_sets(self.where_sql_template) + def call_parameter_sets(self, custom_granularity_names: Sequence[str]) -> FilterCallParameterSets: # noqa: D + return WhereFilterParser.parse_call_parameter_sets( + where_sql_template=self.where_sql_template, custom_granularity_names=custom_granularity_names + ) class PydanticWhereFilterIntersection(HashableBaseModel): @@ -115,14 +116,20 @@ def _convert_legacy_and_yaml_input(cls, input: PydanticParseableValueType) -> Se f"or dict but got {type(input)} with value {input}" ) - @property - def filter_expression_parameter_sets(self) -> List[Tuple[str, FilterCallParameterSets]]: + def filter_expression_parameter_sets( + self, custom_granularity_names: Sequence[str] + ) -> List[Tuple[str, FilterCallParameterSets]]: """Gets the call parameter sets for each filter expression.""" filter_parameter_sets: List[Tuple[str, FilterCallParameterSets]] = [] invalid_filter_expressions: List[Tuple[str, Exception]] = [] for where_filter in self.where_filters: try: - filter_parameter_sets.append((where_filter.where_sql_template, where_filter.call_parameter_sets)) + filter_parameter_sets.append( + ( + where_filter.where_sql_template, + where_filter.call_parameter_sets(custom_granularity_names=custom_granularity_names), + ) + ) except Exception as e: invalid_filter_expressions.append((where_filter.where_sql_template, e)) diff --git a/dbt_semantic_interfaces/naming/dundered.py b/dbt_semantic_interfaces/naming/dundered.py index 1095d3a3..d8c03e47 100644 --- a/dbt_semantic_interfaces/naming/dundered.py +++ b/dbt_semantic_interfaces/naming/dundered.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Sequence, Tuple from dbt_semantic_interfaces.naming.keywords import DUNDER from dbt_semantic_interfaces.references import EntityReference @@ -19,20 +19,14 @@ class StructuredDunderedName: entity_links: ["listing"] element_name: "ds" granularity: TimeGranularity.WEEK - - The time granularity is part of legacy query syntax and there are plans to migrate away from this format. As such, - this will not be updated to allow for custom granularity values. This implies that any query paths that push named - parameters through this class will not support a custom grain reference of the form `metric_time__martian_year`, - and users wishing to use their martian year grain will have to explicitly reference it via a separate parameter - instead of gluing it onto the end of the name. """ entity_links: Tuple[EntityReference, ...] element_name: str - time_granularity: Optional[TimeGranularity] = None + time_granularity: Optional[str] = None @staticmethod - def parse_name(name: str) -> StructuredDunderedName: + def parse_name(name: str, custom_granularity_names: Sequence[str] = ()) -> StructuredDunderedName: """Construct from a string like 'listing__ds__month'.""" name_parts = name.split(DUNDER) @@ -40,11 +34,17 @@ def parse_name(name: str) -> StructuredDunderedName: if len(name_parts) == 1: return StructuredDunderedName((), name_parts[0]) - associated_granularity = None - granularity: TimeGranularity + associated_granularity: Optional[str] = None for granularity in TimeGranularity: if name_parts[-1] == granularity.value: - associated_granularity = granularity + associated_granularity = granularity.value + break + + if associated_granularity is None: + for custom_grain in custom_granularity_names: + if name_parts[-1] == custom_grain: + associated_granularity = custom_grain + break # Has a time granularity if associated_granularity: @@ -69,7 +69,7 @@ def dundered_name(self) -> str: """Return the full name form. e.g. ds or listing__ds__month.""" items = [entity_reference.element_name for entity_reference in self.entity_links] + [self.element_name] if self.time_granularity: - items.append(self.time_granularity.value) + items.append(self.time_granularity) return DUNDER.join(items) @property @@ -82,7 +82,7 @@ def dundered_name_without_granularity(self) -> str: @property def dundered_name_without_entity(self) -> str: """Return the name without the entity. e.g. listing__ds__month -> ds__month.""" - return DUNDER.join((self.element_name,) + ((self.time_granularity.value,) if self.time_granularity else ())) + return DUNDER.join((self.element_name,) + ((self.time_granularity,) if self.time_granularity else ())) @property def entity_prefix(self) -> Optional[str]: @@ -91,52 +91,3 @@ def entity_prefix(self) -> Optional[str]: return DUNDER.join(tuple(entity_reference.element_name for entity_reference in self.entity_links)) return None - - -class DunderedNameFormatter: - """Helps to parse names into StructuredDunderedName and vice versa.""" - - @staticmethod - def parse_name(name: str) -> StructuredDunderedName: - """Construct from a string like 'listing__ds__month'.""" - name_parts = name.split(DUNDER) - - # No dunder, e.g. "ds" - if len(name_parts) == 1: - return StructuredDunderedName((), name_parts[0]) - - associated_granularity = None - granularity: TimeGranularity - for granularity in TimeGranularity: - if name_parts[-1] == granularity.value: - associated_granularity = granularity - - # Has a time granularity - if associated_granularity: - # e.g. "ds__month" - if len(name_parts) == 2: - return StructuredDunderedName((), name_parts[0], associated_granularity) - # e.g. "messages__ds__month" - return StructuredDunderedName( - entity_links=tuple(EntityReference(element_name=entity_name) for entity_name in name_parts[:-2]), - element_name=name_parts[-2], - time_granularity=associated_granularity, - ) - # e.g. "messages__ds" - else: - return StructuredDunderedName( - entity_links=tuple(EntityReference(element_name=entity_name) for entity_name in name_parts[:-1]), - element_name=name_parts[-1], - ) - - @staticmethod - def create_structured_name( # noqa: D - element_name: str, - entity_links: Tuple[EntityReference, ...] = (), - time_granularity: Optional[TimeGranularity] = None, - ) -> StructuredDunderedName: - return StructuredDunderedName( - entity_links=entity_links, - element_name=element_name, - time_granularity=time_granularity, - ) diff --git a/dbt_semantic_interfaces/parsing/text_input/ti_description.py b/dbt_semantic_interfaces/parsing/text_input/ti_description.py index 62fff662..cdb619aa 100644 --- a/dbt_semantic_interfaces/parsing/text_input/ti_description.py +++ b/dbt_semantic_interfaces/parsing/text_input/ti_description.py @@ -56,13 +56,14 @@ def __post_init__(self) -> None: # noqa: D105 else: assert_values_exhausted(item_type) - structured_item_name = StructuredDunderedName.parse_name(self.item_name) - # Check that metrics do not have an entity prefix or entity path. if item_type is QueryItemType.METRIC: if len(self.entity_path) > 0: raise InvalidQuerySyntax("The entity path should not be specified for a metric.") - if len(structured_item_name.entity_links) > 0: + if ( + len(StructuredDunderedName.parse_name(name=self.item_name, custom_granularity_names=()).entity_links) + > 0 + ): raise InvalidQuerySyntax("The name of the metric should not have entity links.") # Check that dimensions / time dimensions have a valid date part. elif item_type is QueryItemType.DIMENSION or item_type is QueryItemType.TIME_DIMENSION: diff --git a/dbt_semantic_interfaces/parsing/text_input/ti_processor.py b/dbt_semantic_interfaces/parsing/text_input/ti_processor.py index 6823a6a4..cac7c122 100644 --- a/dbt_semantic_interfaces/parsing/text_input/ti_processor.py +++ b/dbt_semantic_interfaces/parsing/text_input/ti_processor.py @@ -10,9 +10,6 @@ from typing_extensions import override from dbt_semantic_interfaces.errors import InvalidQuerySyntax -from dbt_semantic_interfaces.parsing.text_input.description_renderer import ( - QueryItemDescriptionRenderer, -) from dbt_semantic_interfaces.parsing.text_input.rendering_helper import ( ObjectBuilderJinjaRenderHelper, ) @@ -77,34 +74,6 @@ def collect_descriptions_from_template( ) return description_collector.collected_descriptions() - def render_template( - self, - jinja_template: str, - renderer: QueryItemDescriptionRenderer, - valid_method_mapping: ValidMethodMapping, - ) -> str: - """Renders the Jinja template using the specified renderer. - - Args: - jinja_template: A Jinja template string like `{{ Dimension('listing__country') }} = 'US'`. - renderer: The renderer to use for rendering each item. - valid_method_mapping: Mapping from the builder object to the valid methods. See - `ConfiguredValidMethodMapping`. - - Returns: - The rendered Jinja template. - - Raises: - QueryItemJinjaException: See definition. - InvalidBuilderMethodException: See definition. - """ - render_processor = _RendererProcessor(renderer) - return self._process_template( - jinja_template=jinja_template, - valid_method_mapping=valid_method_mapping, - description_processor=render_processor, - ) - def _process_template( self, jinja_template: str, @@ -161,18 +130,3 @@ def process_description(self, item_description: ObjectBuilderItemDescription) -> self._items.append(item_description) return "" - - -class _RendererProcessor(ObjectBuilderItemDescriptionProcessor): - """Processor that renders the descriptions in a Jinja template using the given renderer. - - This is just a pass-through, but it allows `QueryItemDescriptionRenderer` to be a facade that has more appropriate - method names. - """ - - def __init__(self, renderer: QueryItemDescriptionRenderer) -> None: # noqa: D107 - self._renderer = renderer - - @override - def process_description(self, item_description: ObjectBuilderItemDescription) -> str: - return self._renderer.render_description(item_description) diff --git a/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py index c59dc016..565ddaad 100644 --- a/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py +++ b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py @@ -7,7 +7,7 @@ ParseWhereFilterException, TimeDimensionCallParameterSet, ) -from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter +from dbt_semantic_interfaces.naming.dundered import StructuredDunderedName from dbt_semantic_interfaces.naming.keywords import is_metric_time_name from dbt_semantic_interfaces.references import ( DimensionReference, @@ -46,6 +46,7 @@ def _exception_message_for_incorrect_format(element_name: str) -> str: @staticmethod def create_time_dimension( time_dimension_name: str, + custom_granularity_names: Sequence[str], time_granularity_name: Optional[str] = None, entity_path: Sequence[str] = (), date_part_name: Optional[str] = None, @@ -65,14 +66,14 @@ def create_time_dimension( for parsing where filters. When we solve the problems with our current where filter spec this will persist as a backwards compatibility model, but nothing more. """ - group_by_item_name = DunderedNameFormatter.parse_name(time_dimension_name) + group_by_item_name = StructuredDunderedName.parse_name( + name=time_dimension_name, custom_granularity_names=custom_granularity_names + ) if len(group_by_item_name.entity_links) != 1 and not is_metric_time_name(group_by_item_name.element_name): raise ParseWhereFilterException( ParameterSetFactory._exception_message_for_incorrect_format(time_dimension_name) ) - grain_parsed_from_name = ( - group_by_item_name.time_granularity.value if group_by_item_name.time_granularity else None - ) + grain_parsed_from_name = group_by_item_name.time_granularity inputs_are_mismatched = ( grain_parsed_from_name is not None and time_granularity_name is not None @@ -101,7 +102,7 @@ def create_time_dimension( @staticmethod def create_dimension(dimension_name: str, entity_path: Sequence[str] = ()) -> DimensionCallParameterSet: """Gets called by Jinja when rendering {{ Dimension(...) }}.""" - group_by_item_name = DunderedNameFormatter.parse_name(dimension_name) + group_by_item_name = StructuredDunderedName.parse_name(name=dimension_name, custom_granularity_names=()) if len(group_by_item_name.entity_links) != 1 and not is_metric_time_name(group_by_item_name.element_name): raise ParseWhereFilterException(ParameterSetFactory._exception_message_for_incorrect_format(dimension_name)) @@ -116,7 +117,7 @@ def create_dimension(dimension_name: str, entity_path: Sequence[str] = ()) -> Di @staticmethod def create_entity(entity_name: str, entity_path: Sequence[str] = ()) -> EntityCallParameterSet: """Gets called by Jinja when rendering {{ Entity(...) }}.""" - structured_dundered_name = DunderedNameFormatter.parse_name(entity_name) + structured_dundered_name = StructuredDunderedName.parse_name(name=entity_name, custom_granularity_names=()) if structured_dundered_name.time_granularity is not None: raise ParseWhereFilterException( f"Name is in an incorrect format: {repr(entity_name)}. " f"It should not contain a time grain suffix." diff --git a/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py b/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py index 7d070b84..f8ba9a51 100644 --- a/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py +++ b/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py @@ -39,7 +39,9 @@ def parse_item_descriptions(where_sql_template: str) -> Sequence[ObjectBuilderIt raise ParseWhereFilterException(f"Error while parsing Jinja template:\n{where_sql_template}") from e @staticmethod - def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSets: + def parse_call_parameter_sets( + where_sql_template: str, custom_granularity_names: Sequence[str] + ) -> FilterCallParameterSets: """Return the result of extracting the semantic objects referenced in the where SQL template string.""" descriptions = WhereFilterParser.parse_item_descriptions(where_sql_template) @@ -63,6 +65,7 @@ def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSet time_granularity_name=description.time_granularity_name, entity_path=description.entity_path, date_part_name=description.date_part_name, + custom_granularity_names=custom_granularity_names, ) ) else: @@ -79,6 +82,7 @@ def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSet time_granularity_name=description.time_granularity_name, entity_path=description.entity_path, date_part_name=description.date_part_name, + custom_granularity_names=custom_granularity_names, ) ) elif item_type is QueryItemType.ENTITY: diff --git a/dbt_semantic_interfaces/parsing/where_filter/where_filter_time_dimension.py b/dbt_semantic_interfaces/parsing/where_filter/where_filter_time_dimension.py deleted file mode 100644 index 693c8344..00000000 --- a/dbt_semantic_interfaces/parsing/where_filter/where_filter_time_dimension.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -from typing import List, Optional, Sequence - -from typing_extensions import override - -from dbt_semantic_interfaces.call_parameter_sets import TimeDimensionCallParameterSet -from dbt_semantic_interfaces.errors import InvalidQuerySyntax -from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import ( - ParameterSetFactory, -) -from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint -from dbt_semantic_interfaces.protocols.query_interface import ( - QueryInterfaceTimeDimension, - QueryInterfaceTimeDimensionFactory, -) - - -class TimeDimensionStub(ProtocolHint[QueryInterfaceTimeDimension]): - """A TimeDimension implementation that just satisfies the protocol. - - QueryInterfaceTimeDimension currently has no methods and the parameter set is created in the factory. - So, there is nothing to do here. - """ - - @override - def _implements_protocol(self) -> QueryInterfaceTimeDimension: - return self - - -class WhereFilterTimeDimensionFactory(ProtocolHint[QueryInterfaceTimeDimensionFactory]): - """Executes in the Jinja sandbox to produce parameter sets and append them to a list.""" - - @override - def _implements_protocol(self) -> QueryInterfaceTimeDimensionFactory: - return self - - def __init__(self) -> None: # noqa - self.time_dimension_call_parameter_sets: List[TimeDimensionCallParameterSet] = [] - - def create( - self, - time_dimension_name: str, - time_granularity_name: Optional[str] = None, - entity_path: Sequence[str] = (), - descending: Optional[bool] = None, - date_part_name: Optional[str] = None, - ) -> TimeDimensionStub: - """Gets called by Jinja when rendering {{ TimeDimension(...) }}.""" - if descending is not None: - raise InvalidQuerySyntax("descending is invalid in the where parameter and filter spec") - self.time_dimension_call_parameter_sets.append( - ParameterSetFactory.create_time_dimension( - time_dimension_name, time_granularity_name, entity_path, date_part_name - ) - ) - return TimeDimensionStub() diff --git a/dbt_semantic_interfaces/protocols/where_filter.py b/dbt_semantic_interfaces/protocols/where_filter.py index 7792e006..470e912e 100644 --- a/dbt_semantic_interfaces/protocols/where_filter.py +++ b/dbt_semantic_interfaces/protocols/where_filter.py @@ -13,9 +13,8 @@ def where_sql_template(self) -> str: """A template that describes how to render the SQL for a WHERE clause.""" pass - @property @abstractmethod - def call_parameter_sets(self) -> FilterCallParameterSets: + def call_parameter_sets(self, custom_granularity_names: Sequence[str]) -> FilterCallParameterSets: """Describe calls like 'dimension(...)' in the SQL template.""" pass @@ -41,9 +40,10 @@ def where_filters(self) -> Sequence[WhereFilter]: """The collection of WhereFilters to be applied to the input data set.""" pass - @property @abstractmethod - def filter_expression_parameter_sets(self) -> Sequence[Tuple[str, FilterCallParameterSets]]: + def filter_expression_parameter_sets( + self, custom_granularity_names: Sequence[str] + ) -> Sequence[Tuple[str, FilterCallParameterSets]]: """Mapping from distinct filter expressions to the call parameter sets associated with them. We use a tuple, rather than a Mapping, in case the call parameter sets may vary between diff --git a/dbt_semantic_interfaces/validations/saved_query.py b/dbt_semantic_interfaces/validations/saved_query.py index c805aa8f..cad7562c 100644 --- a/dbt_semantic_interfaces/validations/saved_query.py +++ b/dbt_semantic_interfaces/validations/saved_query.py @@ -49,14 +49,18 @@ class SavedQueryRule(SemanticManifestValidationRule[SemanticManifestT], Generic[ @staticmethod @validate_safely("Validate the group-by field in a saved query.") - def _check_group_bys(valid_group_by_element_names: Set[str], saved_query: SavedQuery) -> Sequence[ValidationIssue]: + def _check_group_bys( + valid_group_by_element_names: Set[str], saved_query: SavedQuery, custom_granularity_names: Sequence[str] + ) -> Sequence[ValidationIssue]: issues: List[ValidationIssue] = [] for group_by_item in saved_query.query_params.group_by: # TODO: Replace with more appropriate abstractions once available. parameter_sets: FilterCallParameterSets try: - parameter_sets = WhereFilterParser.parse_call_parameter_sets("{{" + group_by_item + "}}") + parameter_sets = WhereFilterParser.parse_call_parameter_sets( + where_sql_template="{{" + group_by_item + "}}", custom_granularity_names=custom_granularity_names + ) except Exception as e: issues.append( generate_exception_issue( @@ -245,6 +249,11 @@ def _check_limit(saved_query: SavedQuery) -> Sequence[ValidationIssue]: @validate_safely("Validate all saved queries in a semantic manifest.") def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]: # noqa: D issues: List[ValidationIssue] = [] + custom_granularity_names = [ + granularity.name + for time_spine in semantic_manifest.project_configuration.time_spines + for granularity in time_spine.custom_granularities + ] valid_metric_names = {metric.name for metric in semantic_manifest.metrics} valid_group_by_element_names = valid_metric_names.union({METRIC_TIME_ELEMENT_NAME}) for semantic_model in semantic_manifest.semantic_models: @@ -261,6 +270,7 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati issues += SavedQueryRule._check_group_bys( valid_group_by_element_names=valid_group_by_element_names, saved_query=saved_query, + custom_granularity_names=custom_granularity_names, ) issues += SavedQueryRule._check_order_by(saved_query) issues += SavedQueryRule._check_limit(saved_query) diff --git a/dbt_semantic_interfaces/validations/where_filters.py b/dbt_semantic_interfaces/validations/where_filters.py index 57ebaa30..d01dde39 100644 --- a/dbt_semantic_interfaces/validations/where_filters.py +++ b/dbt_semantic_interfaces/validations/where_filters.py @@ -73,7 +73,9 @@ def _validate_time_granularity_names_for_saved_query( element_type=SavedQueryElementType.WHERE, element_value=where_filter.where_sql_template, ), - filter_call_param_sets=where_filter.call_parameter_sets, + filter_call_param_sets=where_filter.call_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) @@ -104,7 +106,7 @@ def _validate_saved_query(saved_query: SavedQuery, valid_granularity_names: List return issues for where_filter in saved_query.query_params.where.where_filters: try: - where_filter.call_parameter_sets + where_filter.call_parameter_sets(custom_granularity_names=valid_granularity_names) except Exception as e: issues.append( generate_exception_issue( @@ -140,7 +142,7 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ if metric.filter is not None: try: - metric.filter.filter_expression_parameter_sets + metric.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names) except Exception as e: issues.append( generate_exception_issue( @@ -155,7 +157,9 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ else: issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, - filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets, + filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) @@ -163,7 +167,7 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ measure = metric.type_params.measure if measure is not None and measure.filter is not None: try: - measure.filter.filter_expression_parameter_sets + measure.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names) except Exception as e: issues.append( generate_exception_issue( @@ -179,14 +183,16 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ else: issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, - filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets, + filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) numerator = metric.type_params.numerator if numerator is not None and numerator.filter is not None: try: - numerator.filter.filter_expression_parameter_sets + numerator.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names) except Exception as e: issues.append( generate_exception_issue( @@ -201,14 +207,18 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ else: issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, - filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets, + filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) denominator = metric.type_params.denominator if denominator is not None and denominator.filter is not None: try: - denominator.filter.filter_expression_parameter_sets + denominator.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ) except Exception as e: issues.append( generate_exception_issue( @@ -223,14 +233,18 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ else: issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, - filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets, + filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) for input_metric in metric.type_params.metrics or []: if input_metric.filter is not None: try: - input_metric.filter.filter_expression_parameter_sets + input_metric.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ) except Exception as e: issues.append( generate_exception_issue( @@ -246,7 +260,9 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ else: issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, - filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets, + filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) return issues diff --git a/tests/example_project_configuration.py b/tests/example_project_configuration.py index 71f407bf..c107165e 100644 --- a/tests/example_project_configuration.py +++ b/tests/example_project_configuration.py @@ -51,6 +51,14 @@ primary_column: name: ds_day time_granularity: day + - node_relation: + schema_name: stuffs + alias: week_time_spine + primary_column: + name: ds + time_granularity: week + custom_granularities: + - name: martian_week """ ), ) diff --git a/tests/fixtures/semantic_manifest_yamls/simple_semantic_manifest/project_configuration.yaml b/tests/fixtures/semantic_manifest_yamls/simple_semantic_manifest/project_configuration.yaml index 80c6f34a..efce1c6a 100644 --- a/tests/fixtures/semantic_manifest_yamls/simple_semantic_manifest/project_configuration.yaml +++ b/tests/fixtures/semantic_manifest_yamls/simple_semantic_manifest/project_configuration.yaml @@ -11,3 +11,11 @@ project_configuration: primary_column: name: ds_day time_granularity: day + - node_relation: + alias: mf_time_spine + schema_name: stufffs + primary_column: + name: ds + time_granularity: day + custom_granularities: + - name: martian_day diff --git a/tests/implementations/where_filter/test_parse_calls.py b/tests/implementations/where_filter/test_parse_calls.py index 2a7f9e89..02f284f5 100644 --- a/tests/implementations/where_filter/test_parse_calls.py +++ b/tests/implementations/where_filter/test_parse_calls.py @@ -34,7 +34,7 @@ def test_extract_dimension_call_parameter_sets() -> None: # noqa: D AND {{ Dimension('user__country', entity_path=['listing']) }} == 'US'\ """ ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=( @@ -61,7 +61,7 @@ def test_extract_dimension_with_grain_call_parameter_sets() -> None: # noqa: D {{ Dimension('metric_time').grain('WEEK') }} > 2023-09-18 """ ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -81,7 +81,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D where_sql_template=( """{{ TimeDimension('user__created_at', 'month', entity_path=['listing']) }} = '2020-01-01'""" ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -100,7 +100,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D where_sql_template=( """{{ TimeDimension('user__created_at__month', entity_path=['listing']) }} = '2020-01-01'""" ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -119,7 +119,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D def test_extract_metric_time_dimension_call_parameter_sets() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template="""{{ TimeDimension('metric_time', 'month') }} = '2020-01-01'""" - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -137,7 +137,7 @@ def test_extract_entity_call_parameter_sets() -> None: # noqa: D where_sql_template=( """{{ Entity('listing') }} AND {{ Entity('user', entity_path=['listing']) }} == 'TEST_USER_ID'""" ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -157,7 +157,7 @@ def test_extract_entity_call_parameter_sets() -> None: # noqa: D def test_extract_metric_call_parameter_sets() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template=("{{ Metric('bookings', group_by=['listing']) }} > 2") - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -172,7 +172,7 @@ def test_extract_metric_call_parameter_sets() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template=("{{ Metric('bookings', group_by=['listing', 'metric_time']) }} > 2") - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -186,7 +186,9 @@ def test_extract_metric_call_parameter_sets() -> None: # noqa: D ) with pytest.raises(ParseWhereFilterException): - PydanticWhereFilter(where_sql_template=("{{ Metric('bookings') }} > 2")).call_parameter_sets + PydanticWhereFilter(where_sql_template=("{{ Metric('bookings') }} > 2")).call_parameter_sets( + custom_granularity_names=() + ) def test_invalid_entity_name_error() -> None: @@ -194,7 +196,7 @@ def test_invalid_entity_name_error() -> None: bad_entity_filter = PydanticWhereFilter(where_sql_template="{{ Entity('is_food_order__day' )}}") with pytest.raises(ParseWhereFilterException, match="Name is in an incorrect format"): - bad_entity_filter.call_parameter_sets + bad_entity_filter.call_parameter_sets(custom_granularity_names=()) def test_where_filter_interesection_extract_call_parameter_sets() -> None: @@ -209,7 +211,7 @@ def test_where_filter_interesection_extract_call_parameter_sets() -> None: ) filter_intersection = PydanticWhereFilterIntersection(where_filters=[time_filter, entity_filter]) - parse_result = dict(filter_intersection.filter_expression_parameter_sets) + parse_result = dict(filter_intersection.filter_expression_parameter_sets(custom_granularity_names=())) assert parse_result.get(time_filter.where_sql_template) == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -250,7 +252,7 @@ def test_where_filter_intersection_error_collection() -> None: ) with pytest.raises(ParseWhereFilterException) as exc_info: - filter_intersection.filter_expression_parameter_sets + filter_intersection.filter_expression_parameter_sets(custom_granularity_names=()) error_string = str(exc_info.value) # These are a little too implementation-specific, but it demonstrates that we are collecting the errors we find. @@ -261,7 +263,7 @@ def test_where_filter_intersection_error_collection() -> None: def test_time_dimension_without_granularity() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template="{{ TimeDimension('booking__created_at') }} > 2023-09-18" - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -274,3 +276,21 @@ def test_time_dimension_without_granularity() -> None: # noqa: D ), entity_call_parameter_sets=(), ) + + +def test_time_dimension_with_custom_granularity() -> None: # noqa: D + parse_result = PydanticWhereFilter( + where_sql_template="{{ TimeDimension('booking__created_at', 'martian_week') }} > 2023-09-18" + ).call_parameter_sets(custom_granularity_names=("martian_week",)) + + assert parse_result == FilterCallParameterSets( + dimension_call_parameter_sets=(), + time_dimension_call_parameter_sets=( + TimeDimensionCallParameterSet( + entity_path=(EntityReference("booking"),), + time_dimension_reference=TimeDimensionReference(element_name="created_at"), + time_granularity_name="martian_week", + ), + ), + entity_call_parameter_sets=(), + ) diff --git a/tests/parsing/test_saved_query_parsing.py b/tests/parsing/test_saved_query_parsing.py index 20b8f4b3..2bc04d11 100644 --- a/tests/parsing/test_saved_query_parsing.py +++ b/tests/parsing/test_saved_query_parsing.py @@ -120,6 +120,35 @@ def test_saved_query_group_by() -> None: ) +def test_saved_query_group_by_with_custom_grain() -> None: + """Test for parsing group_bys in a saved query.""" + yaml_contents = textwrap.dedent( + """\ + saved_query: + name: test_saved_query_group_bys + query_params: + metrics: + - test_metric_a + group_by: + - TimeDimension('test_entity__metric_time', 'martian_week') + - Dimension('test_entity__metric_time__martian_week') + + """ + ) + file = YamlConfigFile(filepath="test_dir/inline_for_test", contents=yaml_contents) + + build_result = parse_yaml_files_to_semantic_manifest(files=[file, EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE]) + + assert len(build_result.semantic_manifest.saved_queries) == 1 + saved_query = build_result.semantic_manifest.saved_queries[0] + assert len(saved_query.query_params.group_by) == 2 + print(saved_query.query_params.group_by) + assert { + "TimeDimension('test_entity__metric_time', 'martian_week')", + "Dimension('test_entity__metric_time__martian_week')", + } == set(saved_query.query_params.group_by) + + def test_saved_query_where() -> None: """Test for parsing where clause in a saved query.""" where = "Dimension(test_entity__test_dimension) == true" diff --git a/tests/parsing/test_where_filter_parsing.py b/tests/parsing/test_where_filter_parsing.py index 60a37e0a..8764a97a 100644 --- a/tests/parsing/test_where_filter_parsing.py +++ b/tests/parsing/test_where_filter_parsing.py @@ -165,14 +165,14 @@ def test_where_filter_intersection_from_partially_deserialized_list_of_strings() ], ) def test_time_dimension_date_part(where: str) -> None: # noqa - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.time_dimension_call_parameter_sets) == 1 assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR def test_dimension_date_part() -> None: # noqa where = "{{ Dimension('metric_time').grain('DAY').date_part('YEAR') }} > '2023-01-01'" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.time_dimension_call_parameter_sets) == 1 assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR @@ -196,20 +196,36 @@ def test_dimension_date_part() -> None: # noqa time_granularity_name=TimeGranularity.WEEK.value, ), ), + ( + "{{ TimeDimension('metric_time__martian_week') }} > '2023-01-01'", + TimeDimensionCallParameterSet( + time_dimension_reference=TimeDimensionReference("metric_time"), + entity_path=(), + time_granularity_name="martian_week", + ), + ), + ( + "{{ TimeDimension('metric_time', time_granularity_name='martian_week') }} > '2023-01-01'", + TimeDimensionCallParameterSet( + time_dimension_reference=TimeDimensionReference("metric_time"), + entity_path=(), + time_granularity_name="martian_week", + ), + ), ], ) def test_time_dimension_grain( # noqa where_and_expected_call_params: Tuple[str, Union[TimeDimensionCallParameterSet, DimensionCallParameterSet]] ) -> None: where, expected_call_params = where_and_expected_call_params - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=("martian_week",)) assert len(param_sets.time_dimension_call_parameter_sets) == 1 assert param_sets.time_dimension_call_parameter_sets[0] == expected_call_params def test_entity_without_primary_entity_prefix() -> None: # noqa where = "{{ Entity('non_primary_entity') }} = '1'" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.entity_call_parameter_sets) == 1 assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet( entity_path=(), @@ -219,7 +235,7 @@ def test_entity_without_primary_entity_prefix() -> None: # noqa def test_entity() -> None: # noqa where = "{{ Entity('entity_1__entity_2', entity_path=['entity_0']) }} = '1'" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.entity_call_parameter_sets) == 1 assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet( entity_path=( @@ -232,7 +248,7 @@ def test_entity() -> None: # noqa def test_metric() -> None: # noqa where = "{{ Metric('metric', group_by=['dimension']) }} = 10" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.metric_call_parameter_sets) == 1 assert param_sets.metric_call_parameter_sets[0] == MetricCallParameterSet( group_by=(LinkableElementReference(element_name="dimension"),), @@ -241,7 +257,7 @@ def test_metric() -> None: # noqa # Without kwarg syntax where = "{{ Metric('metric', ['dimension']) }} = 10" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.metric_call_parameter_sets) == 1 assert param_sets.metric_call_parameter_sets[0] == MetricCallParameterSet( group_by=(LinkableElementReference(element_name="dimension"),), diff --git a/tests/validations/test_where_filters_are_parseable.py b/tests/validations/test_where_filters_are_parseable.py index ae948a99..3bedd8ee 100644 --- a/tests/validations/test_where_filters_are_parseable.py +++ b/tests/validations/test_where_filters_are_parseable.py @@ -166,6 +166,7 @@ def test_metric_where_filter_validations_invalid_granularity( # noqa: D PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'cool') }}"), PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'month') }}"), PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'MONTH') }}"), + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'martian_day') }}"), ] ) validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) @@ -195,6 +196,7 @@ def test_saved_query_with_happy_filter( # noqa: D where=PydanticWhereFilterIntersection( where_filters=[ PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'hour') }}"), + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'martian_day') }}"), ] ), ),