Skip to content

Commit

Permalink
Add accessor for collected filter call parameter sets to WhereFilterI…
Browse files Browse the repository at this point in the history
…ntersection

The call_parameter_sets for each of the WhereFilters contained in a
WhereFilterIntersection currently have to be accessed one at a time in a list.
In addition to making it harder to run sensible validations against an implementation
of the WhereFilterIntersection, this also complicates runtime processing for any
implementation (e.g., MetricFlow) that needs to access these parameter sets as
a collection.

This adds a property to the protocol spec for getting a sequence of pairs between
the filter expression sql and the call parameter sets it contains, which allows
for downstream flexibility for managing the WhereFilter components of a
WhereFilterIntersection.
  • Loading branch information
tlento committed Oct 10, 2023
1 parent 8a3cdd0 commit aa4fcda
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 3 deletions.
28 changes: 26 additions & 2 deletions dbt_semantic_interfaces/implementations/filters/where_filter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

from typing import Callable, Generator, List
from typing import Callable, Generator, List, Tuple

from typing_extensions import Self

from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets
from dbt_semantic_interfaces.call_parameter_sets import (
FilterCallParameterSets,
ParseWhereFilterException,
)
from dbt_semantic_interfaces.implementations.base import (
HashableBaseModel,
PydanticCustomInputParser,
Expand All @@ -13,6 +16,7 @@
from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import (
WhereFilterParser,
)
from dbt_semantic_interfaces.pretty_print import pformat_big_objects


class PydanticWhereFilter(PydanticCustomInputParser, HashableBaseModel):
Expand Down Expand Up @@ -109,3 +113,23 @@ def _convert_legacy_and_yaml_input(cls, input: PydanticParseableValueType) -> Se
f"Expected input to be of type string, list, PydanticWhereFilter, PydanticWhereFilterIntersection, "
f"or dict but got {type(input)} with value {input}"
)

@property
def filter_expression_parameter_sets(self) -> List[Tuple[str, FilterCallParameterSets]]:
"""Gets the call parameter sets for each filter expression."""
filter_parameter_sets: List[Tuple[str, FilterCallParameterSets]] = []
invalid_filter_expressions: List[Tuple[str, Exception]] = []
for where_filter in self.where_filters:
try:
filter_parameter_sets.append((where_filter.where_sql_template, where_filter.call_parameter_sets))
except Exception as e:
invalid_filter_expressions.append((where_filter.where_sql_template, e))

if invalid_filter_expressions:
raise ParseWhereFilterException(
f"Encountered one or more errors when parsing the set of filter expressions "
f"{pformat_big_objects(self.where_filters)}! Invalid expressions: \n "
f"{pformat_big_objects(invalid_filter_expressions)}"
)

return filter_parameter_sets
12 changes: 11 additions & 1 deletion dbt_semantic_interfaces/protocols/where_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Protocol, Sequence
from typing import Protocol, Sequence, Tuple

from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets

Expand Down Expand Up @@ -40,3 +40,13 @@ class WhereFilterIntersection(Protocol):
def where_filters(self) -> Sequence[WhereFilter]:
"""The collection of WhereFilters to be applied to the input data set."""
pass

@property
@abstractmethod
def filter_expression_parameter_sets(self) -> Sequence[Tuple[str, FilterCallParameterSets]]:
"""Mapping from distinct filter expressions to the call parameter sets associated with them.
We use a tuple, rather than a Mapping, in case the call parameter sets may vary between
filter expression specifications.
"""
pass
70 changes: 70 additions & 0 deletions tests/implementations/where_filter/test_parse_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
)
from dbt_semantic_interfaces.implementations.filters.where_filter import (
PydanticWhereFilter,
PydanticWhereFilterIntersection,
)
from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import (
ParameterSetFactory,
)
from dbt_semantic_interfaces.references import (
DimensionReference,
Expand Down Expand Up @@ -145,3 +149,69 @@ def test_invalid_entity_name_error() -> None:

with pytest.raises(ParseWhereFilterException, match="Entity name is in an incorrect format"):
bad_entity_filter.call_parameter_sets


def test_where_filter_interesection_extract_call_parameter_sets() -> None:
"""Tests the collection of call parameter sets for a set of where filters."""
time_filter = PydanticWhereFilter(
where_sql_template=("""{{ TimeDimension('metric_time', 'month') }} = '2020-01-01'""")
)
entity_filter = PydanticWhereFilter(
where_sql_template=(
"""{{ Entity('listing') }} AND {{ Entity('user', entity_path=['listing']) }} == 'TEST_USER_ID'"""
)
)
filter_intersection = PydanticWhereFilterIntersection(where_filters=[time_filter, entity_filter])

parse_result = dict(filter_intersection.filter_expression_parameter_sets)

assert parse_result.get(time_filter.where_sql_template) == FilterCallParameterSets(
time_dimension_call_parameter_sets=(
TimeDimensionCallParameterSet(
time_dimension_reference=TimeDimensionReference(element_name="metric_time"),
entity_path=(),
time_granularity=TimeGranularity.MONTH,
),
)
)
assert parse_result.get(entity_filter.where_sql_template) == FilterCallParameterSets(
dimension_call_parameter_sets=(),
entity_call_parameter_sets=(
EntityCallParameterSet(
entity_path=(),
entity_reference=EntityReference("listing"),
),
EntityCallParameterSet(
entity_path=(EntityReference("listing"),),
entity_reference=EntityReference("user"),
),
),
)


def test_where_filter_intersection_error_collection() -> None:
"""Tests the error behaviors when parsing where filters and collecting the call parameter sets for each.
This should result in a single exception with all broken filters represented.
"""
metric_time_in_dimension_error = PydanticWhereFilter(
where_sql_template="{{ TimeDimension('order_id__order_time__month', 'week') }} > '2020-01-01'"
)
valid_dimension = PydanticWhereFilter(where_sql_template=" {Dimension('customer__has_delivery_address')} ")
entity_format_error = PydanticWhereFilter(where_sql_template="{{ Entity('order_id__is_food_order') }}")
filter_intersection = PydanticWhereFilterIntersection(
where_filters=[metric_time_in_dimension_error, valid_dimension, entity_format_error]
)

with pytest.raises(ParseWhereFilterException) as exc_info:
filter_intersection.filter_expression_parameter_sets

error_string = str(exc_info.value)
# These are a little too implementation-specific, but it demonstrates that we are collecting the errors we find.
assert ParameterSetFactory._exception_message_for_incorrect_format("order_id__order_time__month") in error_string
assert "Entity name is in an incorrect format: 'order_id__is_food_order'" in error_string
# We cannot simply scan for name because the error message contains the filter list, so we assert against the error
assert (
ParameterSetFactory._exception_message_for_incorrect_format("customer__has_delivery_address")
not in error_string
)

0 comments on commit aa4fcda

Please sign in to comment.