diff --git a/.changes/unreleased/Features-20231009-210737.yaml b/.changes/unreleased/Features-20231009-210737.yaml new file mode 100644 index 00000000..31b530db --- /dev/null +++ b/.changes/unreleased/Features-20231009-210737.yaml @@ -0,0 +1,7 @@ +kind: Features +body: Allow metric filters and saved query where properties to accept lists of filter + expressions +time: 2023-10-09T21:07:37.978465-07:00 +custom: + Author: tlento + Issue: "147" diff --git a/dbt_semantic_interfaces/implementations/filters/where_filter.py b/dbt_semantic_interfaces/implementations/filters/where_filter.py index 437d7f78..0be414e6 100644 --- a/dbt_semantic_interfaces/implementations/filters/where_filter.py +++ b/dbt_semantic_interfaces/implementations/filters/where_filter.py @@ -1,6 +1,13 @@ from __future__ import annotations -from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets +from typing import Callable, Generator, List, Tuple + +from typing_extensions import Self + +from dbt_semantic_interfaces.call_parameter_sets import ( + FilterCallParameterSets, + ParseWhereFilterException, +) from dbt_semantic_interfaces.implementations.base import ( HashableBaseModel, PydanticCustomInputParser, @@ -9,17 +16,21 @@ from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import ( WhereFilterParser, ) +from dbt_semantic_interfaces.pretty_print import pformat_big_objects class PydanticWhereFilter(PydanticCustomInputParser, HashableBaseModel): - """A filter applied to the data set containing measures, dimensions, identifiers relevant to the query. + """Pydantic implementation of a WhereFilter. - TODO: Clarify whether the filter applies to aggregated or un-aggregated data sets. + This specifies a templated SQl where expression, with templates allowing for extraction of dimensions and + entities (and, eventually, measures and metrics) to include in the filter itself. This filter will then + be applied to an input data set, either from an original input source or an intermediate subquery output. - The data set will contain dimensions as required by the query and the dimensions that a referenced in any of the - filters that are used in the definition of metrics. + The data set will contain entities and dimensions as referenced in the query along with the entities and dimensions + that are referenced in any of these filters, whether they are part of the query request or metric definition. """ + # The where_sql_template field is used in PydanticWhereFilterIntersection.convert_legacy_input. Remove with caution. where_sql_template: str @classmethod @@ -40,3 +51,85 @@ def _from_yaml_value( @property def call_parameter_sets(self) -> FilterCallParameterSets: # noqa: D return WhereFilterParser.parse_call_parameter_sets(self.where_sql_template) + + +class PydanticWhereFilterIntersection(HashableBaseModel): + """Pydantic implementation of a WhereFilterIntersection.""" + + # This class can not have a property named `where_sql_template` without a parsing logic update + __WHERE_SQL_TEMPLATE_FIELD__ = "where_sql_template" + __WHERE_FILTERS_FIELD__ = "where_filters" + + where_filters: List[PydanticWhereFilter] + + @classmethod + def __get_validators__(cls) -> Generator[Callable[[PydanticParseableValueType], Self], None, None]: + """Pydantic magic method for allowing handling of arbitrary input on parse_obj invocation. + + This class requires more subtle handling of input deserialized object types (dicts), and so it cannot + extend the common interface via _from_yaml_values. + """ + yield cls._convert_legacy_and_yaml_input + + @classmethod + def _convert_legacy_and_yaml_input(cls, input: PydanticParseableValueType) -> Self: + """Specifies raw input conversion rules to ensure serialized semantic manifests will parse correctly. + + The original spec for where filters relied on a raw WhereFilter object, but this has now been updated to + expect an object containing a collection of WhereFilters. + + The inputs for the original PydanticWhereFilter could have been either a bare string, a PydanticWhereFilter, + or a partially deserialized json object (i.e., dict) representation of the PydanticWhereFilter. + + Consequently, we must support a variety of inputs and coerce them into the appropriate form, which is in general + a List[valid_where_filter_input] with valid_where_filter_input being one of the types described above. Here + are the operations: + + Sequence transforms: + 1. str -> {"where_filters": [input]} + 2. PydanticWhereFilter -> {"where_filters": [input]} + 3. {"where_sql_template": str} -> {"where_filters": [input]} + + Object initializations (inputs requiring standard initialization, validated via the next pydantic operation): + 1. List -> PydanticWhereFilterIntersection(where_filters=input) + 2. other dicts -> PydanticWhereFilterIntersection(**input) + + Identity transforms (no-ops, as these represent PydanticWhereFilterIntersection objects): + 1. PydanticWhereFilterIntersection + """ + has_legacy_keys = isinstance(input, dict) and cls.__WHERE_SQL_TEMPLATE_FIELD__ in input.keys() + is_legacy_where_filter = isinstance(input, str) or isinstance(input, PydanticWhereFilter) or has_legacy_keys + + if is_legacy_where_filter: + return cls(where_filters=[input]) + elif isinstance(input, list): + return cls(where_filters=input) + elif isinstance(input, dict): + return cls(**input) + elif isinstance(input, cls): + return input + else: + raise ValueError( + f"Expected input to be of type string, list, PydanticWhereFilter, PydanticWhereFilterIntersection, " + f"or dict but got {type(input)} with value {input}" + ) + + @property + def filter_expression_parameter_sets(self) -> List[Tuple[str, FilterCallParameterSets]]: + """Gets the call parameter sets for each filter expression.""" + filter_parameter_sets: List[Tuple[str, FilterCallParameterSets]] = [] + invalid_filter_expressions: List[Tuple[str, Exception]] = [] + for where_filter in self.where_filters: + try: + filter_parameter_sets.append((where_filter.where_sql_template, where_filter.call_parameter_sets)) + except Exception as e: + invalid_filter_expressions.append((where_filter.where_sql_template, e)) + + if invalid_filter_expressions: + raise ParseWhereFilterException( + f"Encountered one or more errors when parsing the set of filter expressions " + f"{pformat_big_objects(self.where_filters)}! Invalid expressions: \n " + f"{pformat_big_objects(invalid_filter_expressions)}" + ) + + return filter_parameter_sets diff --git a/dbt_semantic_interfaces/implementations/metric.py b/dbt_semantic_interfaces/implementations/metric.py index 111d9bb7..656e09c5 100644 --- a/dbt_semantic_interfaces/implementations/metric.py +++ b/dbt_semantic_interfaces/implementations/metric.py @@ -13,7 +13,7 @@ PydanticParseableValueType, ) from dbt_semantic_interfaces.implementations.filters.where_filter import ( - PydanticWhereFilter, + PydanticWhereFilterIntersection, ) from dbt_semantic_interfaces.implementations.metadata import PydanticMetadata from dbt_semantic_interfaces.references import MeasureReference, MetricReference @@ -28,7 +28,7 @@ class PydanticMetricInputMeasure(PydanticCustomInputParser, HashableBaseModel): """ name: str - filter: Optional[PydanticWhereFilter] + filter: Optional[PydanticWhereFilterIntersection] alias: Optional[str] join_to_timespine: bool = False fill_nulls_with: Optional[int] = None @@ -118,7 +118,7 @@ class PydanticMetricInput(HashableBaseModel): """Provides a pointer to a metric along with the additional properties used on that metric.""" name: str - filter: Optional[PydanticWhereFilter] + filter: Optional[PydanticWhereFilterIntersection] alias: Optional[str] offset_window: Optional[PydanticMetricTimeWindow] offset_to_grain: Optional[TimeGranularity] @@ -155,7 +155,7 @@ class PydanticMetric(HashableBaseModel, ModelWithMetadataParsing): description: Optional[str] type: MetricType type_params: PydanticMetricTypeParams - filter: Optional[PydanticWhereFilter] + filter: Optional[PydanticWhereFilterIntersection] metadata: Optional[PydanticMetadata] label: Optional[str] = None diff --git a/dbt_semantic_interfaces/implementations/saved_query.py b/dbt_semantic_interfaces/implementations/saved_query.py index 6ff709b9..53de5038 100644 --- a/dbt_semantic_interfaces/implementations/saved_query.py +++ b/dbt_semantic_interfaces/implementations/saved_query.py @@ -9,7 +9,7 @@ ModelWithMetadataParsing, ) from dbt_semantic_interfaces.implementations.filters.where_filter import ( - PydanticWhereFilter, + PydanticWhereFilterIntersection, ) from dbt_semantic_interfaces.implementations.metadata import PydanticMetadata from dbt_semantic_interfaces.protocols import ProtocolHint @@ -26,7 +26,7 @@ def _implements_protocol(self) -> SavedQuery: name: str metrics: List[str] group_bys: List[str] = [] - where: List[PydanticWhereFilter] = [] + where: Optional[PydanticWhereFilterIntersection] = None description: Optional[str] = None metadata: Optional[PydanticMetadata] = None diff --git a/dbt_semantic_interfaces/parsing/generated_json_schemas/default_explicit_schema.json b/dbt_semantic_interfaces/parsing/generated_json_schemas/default_explicit_schema.json index 784ba5b0..7d0c3253 100644 --- a/dbt_semantic_interfaces/parsing/generated_json_schemas/default_explicit_schema.json +++ b/dbt_semantic_interfaces/parsing/generated_json_schemas/default_explicit_schema.json @@ -139,6 +139,20 @@ ], "type": "object" }, + "filter_schema": { + "$id": "filter_schema", + "oneOf": [ + { + "type": "string" + }, + { + "items": { + "type": "string" + }, + "type": "array" + } + ] + }, "is-time-dimension": { "properties": { "type": { @@ -234,7 +248,7 @@ "type": "integer" }, "filter": { - "type": "string" + "$ref": "#/definitions/filter_schema" }, "join_to_timespine": { "type": "boolean" @@ -255,7 +269,7 @@ "type": "string" }, "filter": { - "type": "string" + "$ref": "#/definitions/filter_schema" }, "name": { "type": "string" @@ -277,7 +291,7 @@ "type": "string" }, "filter": { - "type": "string" + "$ref": "#/definitions/filter_schema" }, "label": { "type": "string" @@ -435,10 +449,7 @@ "type": "string" }, "where": { - "items": { - "type": "string" - }, - "type": "array" + "$ref": "#/definitions/filter_schema" } }, "required": [ diff --git a/dbt_semantic_interfaces/parsing/schemas.py b/dbt_semantic_interfaces/parsing/schemas.py index a0256b07..9144c2b2 100644 --- a/dbt_semantic_interfaces/parsing/schemas.py +++ b/dbt_semantic_interfaces/parsing/schemas.py @@ -39,6 +39,17 @@ time_dimension_type_values = ["TIME", "time"] +filter_schema = { + "$id": "filter_schema", + "oneOf": [ + {"type": "string"}, + { + "type": "array", + "items": {"type": "string"}, + }, + ], +} + metric_input_measure_schema = { "$id": "metric_input_measure_schema", "oneOf": [ @@ -47,7 +58,7 @@ "type": "object", "properties": { "name": {"type": "string"}, - "filter": {"type": "string"}, + "filter": {"$ref": "filter_schema"}, "alias": {"type": "string"}, "join_to_timespine": {"type": "boolean"}, "fill_nulls_with": {"type": "integer"}, @@ -62,7 +73,7 @@ "type": "object", "properties": { "name": {"type": "string"}, - "filter": {"type": "string"}, + "filter": {"$ref": "filter_schema"}, "alias": {"type": "string"}, "offset_window": {"type": "string"}, "offset_to_grain": {"type": "string"}, @@ -218,7 +229,7 @@ }, "type": {"enum": metric_types_enum_values}, "type_params": {"$ref": "metric_type_params"}, - "filter": {"type": "string"}, + "filter": {"$ref": "filter_schema"}, "description": {"type": "string"}, "label": {"type": "string"}, }, @@ -292,10 +303,7 @@ "type": "array", "items": {"type": "string"}, }, - "where": { - "type": "array", - "items": {"type": "string"}, - }, + "where": {"$ref": "filter_schema"}, "label": {"type": "string"}, }, "required": ["name", "metrics"], @@ -333,6 +341,7 @@ project_configuration_schema["$id"]: project_configuration_schema, saved_query_schema["$id"]: saved_query_schema, # Sub-object schemas + filter_schema["$id"]: filter_schema, metric_input_measure_schema["$id"]: metric_input_measure_schema, metric_type_params_schema["$id"]: metric_type_params_schema, entity_schema["$id"]: entity_schema, diff --git a/dbt_semantic_interfaces/protocols/__init__.py b/dbt_semantic_interfaces/protocols/__init__.py index 45f4917e..239aa83b 100644 --- a/dbt_semantic_interfaces/protocols/__init__.py +++ b/dbt_semantic_interfaces/protocols/__init__.py @@ -28,4 +28,7 @@ SemanticModelDefaults, SemanticModelT, ) -from dbt_semantic_interfaces.protocols.where_filter import WhereFilter # noqa:F401 +from dbt_semantic_interfaces.protocols.where_filter import ( # noqa:F401 + WhereFilter, + WhereFilterIntersection, +) diff --git a/dbt_semantic_interfaces/protocols/metric.py b/dbt_semantic_interfaces/protocols/metric.py index 3f09a29d..31f21682 100644 --- a/dbt_semantic_interfaces/protocols/metric.py +++ b/dbt_semantic_interfaces/protocols/metric.py @@ -4,7 +4,7 @@ from typing import Optional, Protocol, Sequence from dbt_semantic_interfaces.protocols.metadata import Metadata -from dbt_semantic_interfaces.protocols.where_filter import WhereFilter +from dbt_semantic_interfaces.protocols.where_filter import WhereFilterIntersection from dbt_semantic_interfaces.references import MeasureReference, MetricReference from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity @@ -23,7 +23,8 @@ def name(self) -> str: # noqa: D @property @abstractmethod - def filter(self) -> Optional[WhereFilter]: # noqa: D + def filter(self) -> Optional[WhereFilterIntersection]: + """Return the set of filters to apply prior to aggregating this input measure.""" pass @property @@ -80,7 +81,8 @@ def name(self) -> str: # noqa: D @property @abstractmethod - def filter(self) -> Optional[WhereFilter]: # noqa: D + def filter(self) -> Optional[WhereFilterIntersection]: + """Return the set of filters to apply prior to calculating this input metric.""" pass @property @@ -181,7 +183,8 @@ def type_params(self) -> MetricTypeParams: # noqa: D @property @abstractmethod - def filter(self) -> Optional[WhereFilter]: # noqa: D + def filter(self) -> Optional[WhereFilterIntersection]: + """Return the set of filters to apply prior to calculating this metric.""" pass @property diff --git a/dbt_semantic_interfaces/protocols/saved_query.py b/dbt_semantic_interfaces/protocols/saved_query.py index 2018b164..3bd739d9 100644 --- a/dbt_semantic_interfaces/protocols/saved_query.py +++ b/dbt_semantic_interfaces/protocols/saved_query.py @@ -2,7 +2,7 @@ from typing import Optional, Protocol, Sequence from dbt_semantic_interfaces.protocols.metadata import Metadata -from dbt_semantic_interfaces.protocols.where_filter import WhereFilter +from dbt_semantic_interfaces.protocols.where_filter import WhereFilterIntersection class SavedQuery(Protocol): @@ -35,7 +35,8 @@ def group_bys(self) -> Sequence[str]: # noqa: D @property @abstractmethod - def where(self) -> Sequence[WhereFilter]: # noqa: D + def where(self) -> Optional[WhereFilterIntersection]: + """Returns the intersection class containing any where filters specified in the saved query.""" pass @property diff --git a/dbt_semantic_interfaces/protocols/where_filter.py b/dbt_semantic_interfaces/protocols/where_filter.py index 5f9b5642..7792e006 100644 --- a/dbt_semantic_interfaces/protocols/where_filter.py +++ b/dbt_semantic_interfaces/protocols/where_filter.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Protocol +from typing import Protocol, Sequence, Tuple from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets @@ -18,3 +18,35 @@ def where_sql_template(self) -> str: def call_parameter_sets(self) -> FilterCallParameterSets: """Describe calls like 'dimension(...)' in the SQL template.""" pass + + +class WhereFilterIntersection(Protocol): + """A collection of filters to be applied to an input dataset. + + This is an intersection, meaning each input row must pass all filters to be included in the output. It is the + equivalent of using an " AND " expression to join each filter expression in the input set into a single SQL + statement. + + Although there is no formal contract around this, the expectation is these filters will be applied in a manner + that will produce output equivalent to running the WHERE clause, after dimensional joins but before measure + aggregations. + + We use a protocol class here, instead of a simple Sequence, partly to centralize any custom parsing and processing + logic and partly because it is more descriptive as to the relationship between the filter elements in the set. + """ + + @property + @abstractmethod + def where_filters(self) -> Sequence[WhereFilter]: + """The collection of WhereFilters to be applied to the input data set.""" + pass + + @property + @abstractmethod + def filter_expression_parameter_sets(self) -> Sequence[Tuple[str, FilterCallParameterSets]]: + """Mapping from distinct filter expressions to the call parameter sets associated with them. + + We use a tuple, rather than a Mapping, in case the call parameter sets may vary between + filter expression specifications. + """ + pass diff --git a/dbt_semantic_interfaces/test_utils.py b/dbt_semantic_interfaces/test_utils.py index 9051bcc2..addd1b93 100644 --- a/dbt_semantic_interfaces/test_utils.py +++ b/dbt_semantic_interfaces/test_utils.py @@ -8,9 +8,6 @@ from dbt_semantic_interfaces.implementations.elements.dimension import PydanticDimension from dbt_semantic_interfaces.implementations.elements.entity import PydanticEntity from dbt_semantic_interfaces.implementations.elements.measure import PydanticMeasure -from dbt_semantic_interfaces.implementations.filters.where_filter import ( - PydanticWhereFilter, -) from dbt_semantic_interfaces.implementations.metadata import ( PydanticFileSlice, PydanticMetadata, @@ -124,7 +121,6 @@ def metric_with_guaranteed_meta( name: str, type: MetricType, type_params: PydanticMetricTypeParams, - where_filter: Optional[PydanticWhereFilter] = None, metadata: PydanticMetadata = default_meta(), description: str = "adhoc metric", ) -> PydanticMetric: @@ -137,7 +133,7 @@ def metric_with_guaranteed_meta( description=description, type=type, type_params=type_params, - filter=where_filter, + filter=None, metadata=metadata, ) diff --git a/dbt_semantic_interfaces/validations/metrics.py b/dbt_semantic_interfaces/validations/metrics.py index ffe97cba..6c213de9 100644 --- a/dbt_semantic_interfaces/validations/metrics.py +++ b/dbt_semantic_interfaces/validations/metrics.py @@ -172,7 +172,7 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D if metric.filter is not None: try: - metric.filter.call_parameter_sets + metric.filter.filter_expression_parameter_sets except Exception as e: issues.append( generate_exception_issue( @@ -181,7 +181,6 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D context=context, extras={ "traceback": "".join(traceback.format_tb(e.__traceback__)), - "filter": metric.filter.where_sql_template, }, ) ) @@ -190,7 +189,7 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D measure = metric.type_params.measure if measure is not None and measure.filter is not None: try: - measure.filter.call_parameter_sets + measure.filter.filter_expression_parameter_sets except Exception as e: issues.append( generate_exception_issue( @@ -200,7 +199,6 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D context=context, extras={ "traceback": "".join(traceback.format_tb(e.__traceback__)), - "filter": measure.filter.where_sql_template, }, ) ) @@ -208,7 +206,7 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D numerator = metric.type_params.numerator if numerator is not None and numerator.filter is not None: try: - numerator.filter.call_parameter_sets + numerator.filter.filter_expression_parameter_sets except Exception as e: issues.append( generate_exception_issue( @@ -217,7 +215,6 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D context=context, extras={ "traceback": "".join(traceback.format_tb(e.__traceback__)), - "filter": numerator.filter.where_sql_template, }, ) ) @@ -225,7 +222,7 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D denominator = metric.type_params.denominator if denominator is not None and denominator.filter is not None: try: - denominator.filter.call_parameter_sets + denominator.filter.filter_expression_parameter_sets except Exception as e: issues.append( generate_exception_issue( @@ -234,7 +231,6 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D context=context, extras={ "traceback": "".join(traceback.format_tb(e.__traceback__)), - "filter": denominator.filter.where_sql_template, }, ) ) @@ -242,7 +238,7 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D for input_metric in metric.type_params.metrics or []: if input_metric.filter is not None: try: - input_metric.filter.call_parameter_sets + input_metric.filter.filter_expression_parameter_sets except Exception as e: issues.append( generate_exception_issue( @@ -252,7 +248,6 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D context=context, extras={ "traceback": "".join(traceback.format_tb(e.__traceback__)), - "filter": input_metric.filter.where_sql_template, }, ) ) diff --git a/dbt_semantic_interfaces/validations/saved_query.py b/dbt_semantic_interfaces/validations/saved_query.py index 0b2ecd4c..f9abd7f4 100644 --- a/dbt_semantic_interfaces/validations/saved_query.py +++ b/dbt_semantic_interfaces/validations/saved_query.py @@ -101,7 +101,9 @@ def _check_metrics(valid_metric_names: Set[str], saved_query: SavedQuery) -> Seq @validate_safely("Validate the where field in a saved query.") def _check_where(saved_query: SavedQuery) -> Sequence[ValidationIssue]: issues: List[ValidationIssue] = [] - for where_filter in saved_query.where: + if saved_query.where is None: + return issues + for where_filter in saved_query.where.where_filters: try: where_filter.call_parameter_sets except Exception as e: diff --git a/tests/implementations/where_filter/test_parse_calls.py b/tests/implementations/where_filter/test_parse_calls.py index 4ad77c5b..f3e069b8 100644 --- a/tests/implementations/where_filter/test_parse_calls.py +++ b/tests/implementations/where_filter/test_parse_calls.py @@ -11,6 +11,10 @@ ) from dbt_semantic_interfaces.implementations.filters.where_filter import ( PydanticWhereFilter, + PydanticWhereFilterIntersection, +) +from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import ( + ParameterSetFactory, ) from dbt_semantic_interfaces.references import ( DimensionReference, @@ -145,3 +149,69 @@ def test_invalid_entity_name_error() -> None: with pytest.raises(ParseWhereFilterException, match="Entity name is in an incorrect format"): bad_entity_filter.call_parameter_sets + + +def test_where_filter_interesection_extract_call_parameter_sets() -> None: + """Tests the collection of call parameter sets for a set of where filters.""" + time_filter = PydanticWhereFilter( + where_sql_template=("""{{ TimeDimension('metric_time', 'month') }} = '2020-01-01'""") + ) + entity_filter = PydanticWhereFilter( + where_sql_template=( + """{{ Entity('listing') }} AND {{ Entity('user', entity_path=['listing']) }} == 'TEST_USER_ID'""" + ) + ) + filter_intersection = PydanticWhereFilterIntersection(where_filters=[time_filter, entity_filter]) + + parse_result = dict(filter_intersection.filter_expression_parameter_sets) + + assert parse_result.get(time_filter.where_sql_template) == FilterCallParameterSets( + time_dimension_call_parameter_sets=( + TimeDimensionCallParameterSet( + time_dimension_reference=TimeDimensionReference(element_name="metric_time"), + entity_path=(), + time_granularity=TimeGranularity.MONTH, + ), + ) + ) + assert parse_result.get(entity_filter.where_sql_template) == FilterCallParameterSets( + dimension_call_parameter_sets=(), + entity_call_parameter_sets=( + EntityCallParameterSet( + entity_path=(), + entity_reference=EntityReference("listing"), + ), + EntityCallParameterSet( + entity_path=(EntityReference("listing"),), + entity_reference=EntityReference("user"), + ), + ), + ) + + +def test_where_filter_intersection_error_collection() -> None: + """Tests the error behaviors when parsing where filters and collecting the call parameter sets for each. + + This should result in a single exception with all broken filters represented. + """ + metric_time_in_dimension_error = PydanticWhereFilter( + where_sql_template="{{ TimeDimension('order_id__order_time__month', 'week') }} > '2020-01-01'" + ) + valid_dimension = PydanticWhereFilter(where_sql_template=" {Dimension('customer__has_delivery_address')} ") + entity_format_error = PydanticWhereFilter(where_sql_template="{{ Entity('order_id__is_food_order') }}") + filter_intersection = PydanticWhereFilterIntersection( + where_filters=[metric_time_in_dimension_error, valid_dimension, entity_format_error] + ) + + with pytest.raises(ParseWhereFilterException) as exc_info: + filter_intersection.filter_expression_parameter_sets + + error_string = str(exc_info.value) + # These are a little too implementation-specific, but it demonstrates that we are collecting the errors we find. + assert ParameterSetFactory._exception_message_for_incorrect_format("order_id__order_time__month") in error_string + assert "Entity name is in an incorrect format: 'order_id__is_food_order'" in error_string + # We cannot simply scan for name because the error message contains the filter list, so we assert against the error + assert ( + ParameterSetFactory._exception_message_for_incorrect_format("customer__has_delivery_address") + not in error_string + ) diff --git a/tests/parsing/test_metric_parsing.py b/tests/parsing/test_metric_parsing.py index d0c07d71..dba46209 100644 --- a/tests/parsing/test_metric_parsing.py +++ b/tests/parsing/test_metric_parsing.py @@ -2,6 +2,7 @@ from dbt_semantic_interfaces.implementations.filters.where_filter import ( PydanticWhereFilter, + PydanticWhereFilterIntersection, ) from dbt_semantic_interfaces.implementations.metric import ( PydanticMetricInput, @@ -66,7 +67,9 @@ def test_legacy_metric_input_measure_object_parsing() -> None: metric = build_result.semantic_manifest.metrics[0] assert metric.type_params.measure == PydanticMetricInputMeasure( name="legacy_measure_from_object", - filter=PydanticWhereFilter(where_sql_template="""{{ dimension('some_bool') }}"""), + filter=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="""{{ dimension('some_bool') }}""")] + ), join_to_timespine=True, fill_nulls_with=1, ) @@ -181,8 +184,12 @@ def test_ratio_metric_input_measure_object_parsing() -> None: metric = build_result.semantic_manifest.metrics[0] assert metric.type_params.numerator == PydanticMetricInput( name="numerator_metric_from_object", - filter=PydanticWhereFilter( - where_sql_template="some_number > 5", + filter=PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter( + where_sql_template="some_number > 5", + ) + ], ), ) assert metric.type_params.denominator == PydanticMetricInput(name="denominator_metric_from_object") @@ -328,8 +335,41 @@ def test_constraint_metric_parsing() -> None: metric = build_result.semantic_manifest.metrics[0] assert metric.name == "constraint_test" assert metric.type is MetricType.SIMPLE - assert metric.filter == PydanticWhereFilter( - where_sql_template="{{ dimension('some_dimension') }} IN ('value1', 'value2')" + assert metric.filter == PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('some_dimension') }} IN ('value1', 'value2')") + ] + ) + + +def test_constraint_list_metric_parsing() -> None: + """Test for parsing a metric specification with a list of constraints included.""" + yaml_contents = textwrap.dedent( + """\ + metric: + name: constraint_test + type: simple + type_params: + measure: + name: input_measure + filter: + - "{{ dimension('some_dimension') }} IN ('value1', 'value2')" + - "1 > 0" + """ + ) + file = YamlConfigFile(filepath="inline_for_test", contents=yaml_contents) + + build_result = parse_yaml_files_to_semantic_manifest(files=[file, EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE]) + + assert len(build_result.semantic_manifest.metrics) == 1 + metric = build_result.semantic_manifest.metrics[0] + assert metric.name == "constraint_test" + assert metric.type is MetricType.SIMPLE + assert metric.filter == PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('some_dimension') }} IN ('value1', 'value2')"), + PydanticWhereFilter(where_sql_template="1 > 0"), + ] ) @@ -364,7 +404,9 @@ def test_derived_metric_input_parsing() -> None: assert metric.type_params.metrics[1] == PydanticMetricInput( name="input_metric", alias="constrained_input_metric", - filter=PydanticWhereFilter(where_sql_template="input_metric < 10"), + filter=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="input_metric < 10")] + ), ) diff --git a/tests/parsing/test_saved_query_parsing.py b/tests/parsing/test_saved_query_parsing.py index 596ee66a..95b0e6aa 100644 --- a/tests/parsing/test_saved_query_parsing.py +++ b/tests/parsing/test_saved_query_parsing.py @@ -131,5 +131,6 @@ def test_saved_query_where() -> None: build_result = parse_yaml_files_to_semantic_manifest(files=[file, EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE]) assert len(build_result.semantic_manifest.saved_queries) == 1 saved_query = build_result.semantic_manifest.saved_queries[0] - assert len(saved_query.where) == 1 - assert where == saved_query.where[0].where_sql_template + assert saved_query.where is not None + assert len(saved_query.where.where_filters) == 1 + assert where == saved_query.where.where_filters[0].where_sql_template diff --git a/tests/parsing/test_where_filter_parsing.py b/tests/parsing/test_where_filter_parsing.py new file mode 100644 index 00000000..11eed4d3 --- /dev/null +++ b/tests/parsing/test_where_filter_parsing.py @@ -0,0 +1,137 @@ +"""Tests various where filter parsing conditions. + +WhereFilter parsing operations can be fairly complex, as they must be able to accept input that is +either a bare string filter expression or some partially or fully deserialized filter object type. + +In addition, due to the migration from WhereFilter to WhereFilterIntersection types, this tests the +various conversion operations we will need to perform on semantic manifests defined out in the world. + +This module tests the various combinations we might encounter in the wild, with a particular focus +on inputs to parse_obj or parse_raw, as that is what the pydantic models will generally encounter. +""" + + +from dbt_semantic_interfaces.implementations.base import HashableBaseModel +from dbt_semantic_interfaces.implementations.filters.where_filter import ( + PydanticWhereFilter, + PydanticWhereFilterIntersection, +) + +__BOOLEAN_EXPRESSION__ = "1 > 0" + + +class ModelWithWhereFilter(HashableBaseModel): + """Defines a test model to allow for evaluation of different parsing modes for where filter expressions.""" + + where_filter: PydanticWhereFilter + + +class ModelWithWhereFilterIntersection(HashableBaseModel): + """Defines a test model to allow for evaluation of different parsing modes for where filter intersections. + + This has the same schema, apart from the filter type, as the ModelWithWhereFilter in order to allow for + testing conversion from a WhereFilter to a WhereFilterIntersection. + """ + + where_filter: PydanticWhereFilterIntersection + + +def test_partially_deserialized_object_string_parsing() -> None: + """Tests parsing a where filter specified as a string within partially deserialized json object.""" + obj = {"where_filter": __BOOLEAN_EXPRESSION__} + + parsed_model = ModelWithWhereFilter.parse_obj(obj) + + assert parsed_model.where_filter == PydanticWhereFilter(where_sql_template=__BOOLEAN_EXPRESSION__) + + +def test_partially_deserialized_object_parsing() -> None: + """Tests parsing a where filter that was serialized and then json decoded, but not fully parsed.""" + obj = {"where_filter": {"where_sql_template": __BOOLEAN_EXPRESSION__}} + + parsed_model = ModelWithWhereFilter.parse_obj(obj) + + assert parsed_model.where_filter == PydanticWhereFilter(where_sql_template=__BOOLEAN_EXPRESSION__) + + +def test_injected_object_parsing() -> None: + """Tests parsing where, for some reason, a PydanticWhereFilter has been injected into the object. + + This covers the (hopefully vanishingly rare) cases where some raw validator in a pydantic implementation + is updating the input object to convert something to a PydanticWhereFilter. + """ + obj = {"where_filter": PydanticWhereFilter(where_sql_template=__BOOLEAN_EXPRESSION__)} + + parsed_model = ModelWithWhereFilter.parse_obj(obj) + + assert parsed_model.where_filter == PydanticWhereFilter(where_sql_template=__BOOLEAN_EXPRESSION__) + + +def test_serialize_deserialize_operations() -> None: + """Tests serializing and deserializing an object with a WhereFilter. + + This should cover the most common scenarios, where we need to parse a serialized SemanticManifest. + """ + base_obj = ModelWithWhereFilter(where_filter=PydanticWhereFilter(where_sql_template=__BOOLEAN_EXPRESSION__)) + + serialized = base_obj.json() + deserialized = ModelWithWhereFilter.parse_raw(serialized) + + assert deserialized == base_obj + + +def test_conversion_from_partially_deserialized_where_filter_string() -> None: + """Tests converting a partially deserialized ModelWithWhereFilter into a ModelWithWhereFilterIntersection. + + This covers the case where the input is still a bare string, such as might happen in a raw YAML read. + """ + obj = {"where_filter": __BOOLEAN_EXPRESSION__} + expected_conversion_output = PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template=__BOOLEAN_EXPRESSION__)] + ) + + parsed_model = ModelWithWhereFilterIntersection.parse_obj(obj) + + assert parsed_model.where_filter == expected_conversion_output + + +def test_conversion_from_partially_deserialized_where_filter_object() -> None: + """Tests converting a partially deserialized WhereFilter into a WhereFilterIntersection.""" + obj = {"where_filter": {"where_sql_template": __BOOLEAN_EXPRESSION__}} + expected_conversion_output = PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template=__BOOLEAN_EXPRESSION__)] + ) + + parsed_model = ModelWithWhereFilterIntersection.parse_obj(obj) + + assert parsed_model.where_filter == expected_conversion_output + + +def test_conversion_from_injected_where_filter_object() -> None: + """Tests conversion from a PydanticWhereFilter instance, such as one inserted via a raw validator.""" + obj = {"where_filter": PydanticWhereFilter(where_sql_template=__BOOLEAN_EXPRESSION__)} + expected_conversion_output = PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template=__BOOLEAN_EXPRESSION__)] + ) + + parsed_model = ModelWithWhereFilterIntersection.parse_obj(obj) + + assert parsed_model.where_filter == expected_conversion_output + + +def test_where_filter_intersection_from_partially_deserialized_list_of_strings() -> None: + """Tests parsing a PydanticWhereFilterIntersection when the input is a list of strings. + + This simulates handling YAML input, which may be a list or other sequence of filters. + """ + obj = {"where_filter": [__BOOLEAN_EXPRESSION__, "0 < 1"]} + expected_parsed_output = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template=__BOOLEAN_EXPRESSION__), + PydanticWhereFilter(where_sql_template="0 < 1"), + ] + ) + + parsed_model = ModelWithWhereFilterIntersection.parse_obj(obj) + + assert parsed_model.where_filter == expected_parsed_output diff --git a/tests/validations/test_metrics.py b/tests/validations/test_metrics.py index 6db78b0b..9b8fbedf 100644 --- a/tests/validations/test_metrics.py +++ b/tests/validations/test_metrics.py @@ -8,12 +8,15 @@ ) from dbt_semantic_interfaces.implementations.elements.entity import PydanticEntity from dbt_semantic_interfaces.implementations.elements.measure import PydanticMeasure +from dbt_semantic_interfaces.implementations.filters.where_filter import ( + PydanticWhereFilter, + PydanticWhereFilterIntersection, +) from dbt_semantic_interfaces.implementations.metric import ( PydanticMetricInput, PydanticMetricInputMeasure, PydanticMetricTimeWindow, PydanticMetricTypeParams, - PydanticWhereFilter, ) from dbt_semantic_interfaces.implementations.semantic_manifest import ( PydanticSemanticManifest, @@ -323,7 +326,8 @@ def test_where_filter_validations_bad_base_filter( # noqa: D metric, _ = find_metric_with(manifest, lambda metric: metric.filter is not None) assert metric.filter is not None - metric.filter.where_sql_template = "{{ dimension('too', 'many', 'variables', 'to', 'handle') }}" + assert len(metric.filter.where_filters) > 0 + metric.filter.where_filters[0].where_sql_template = "{{ dimension('too', 'many', 'variables', 'to', 'handle') }}" validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) with pytest.raises(SemanticManifestValidationException, match=f"trying to parse filter of metric `{metric.name}`"): validator.checked_validations(manifest) @@ -338,8 +342,10 @@ def test_where_filter_validations_bad_measure_filter( # noqa: D manifest, lambda metric: metric.type_params is not None and metric.type_params.measure is not None ) assert metric.type_params.measure is not None - metric.type_params.measure.filter = PydanticWhereFilter( - where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}" + metric.type_params.measure.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] ) validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) with pytest.raises( @@ -358,8 +364,10 @@ def test_where_filter_validations_bad_numerator_filter( # noqa: D manifest, lambda metric: metric.type_params is not None and metric.type_params.numerator is not None ) assert metric.type_params.numerator is not None - metric.type_params.numerator.filter = PydanticWhereFilter( - where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}" + metric.type_params.numerator.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] ) validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) with pytest.raises( @@ -377,8 +385,10 @@ def test_where_filter_validations_bad_denominator_filter( # noqa: D manifest, lambda metric: metric.type_params is not None and metric.type_params.denominator is not None ) assert metric.type_params.denominator is not None - metric.type_params.denominator.filter = PydanticWhereFilter( - where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}" + metric.type_params.denominator.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] ) validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) with pytest.raises( @@ -400,8 +410,10 @@ def test_where_filter_validations_bad_input_metric_filter( # noqa: D ) assert metric.type_params.metrics is not None input_metric = metric.type_params.metrics[0] - input_metric.filter = PydanticWhereFilter( - where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}" + input_metric.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] ) validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) with pytest.raises( diff --git a/tests/validations/test_saved_query.py b/tests/validations/test_saved_query.py index c6ae46f4..89ba8289 100644 --- a/tests/validations/test_saved_query.py +++ b/tests/validations/test_saved_query.py @@ -3,6 +3,7 @@ from dbt_semantic_interfaces.implementations.filters.where_filter import ( PydanticWhereFilter, + PydanticWhereFilterIntersection, ) from dbt_semantic_interfaces.implementations.saved_query import PydanticSavedQuery from dbt_semantic_interfaces.implementations.semantic_manifest import ( @@ -44,7 +45,9 @@ def test_invalid_metric_in_saved_query( # noqa: D description="Example description.", metrics=["invalid_metric"], group_bys=["Dimension('booking__is_instant')"], - where=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + ), ), ] @@ -64,7 +67,9 @@ def test_invalid_where_in_saved_query( # noqa: D description="Example description.", metrics=["bookings"], group_bys=["Dimension('booking__is_instant')"], - where=[PydanticWhereFilter(where_sql_template="{{ invalid_jinja }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ invalid_jinja }}")], + ), ), ] @@ -85,7 +90,9 @@ def test_invalid_group_by_element_in_saved_query( # noqa: D description="Example description.", metrics=["bookings"], group_bys=["Dimension('booking__invalid_dimension')"], - where=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + ), ), ] @@ -106,7 +113,9 @@ def test_invalid_group_by_format_in_saved_query( # noqa: D description="Example description.", metrics=["bookings"], group_bys=["invalid_format"], - where=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + ), ), ]