From bd01bf03d02f1fa9a140a23ea7838df756b613ef Mon Sep 17 00:00:00 2001 From: tlento Date: Mon, 9 Oct 2023 20:46:41 -0700 Subject: [PATCH] Update SavedQuery objects to WhereFilterIntersection For the sake of consistency, and everybody's sanity, we should have the new things use the same constructs as the old things. While the original Sequence[WhereFilter] type is nominally fine, it would become rather annoying to have to deal with lists in some places and objects containing lists in others, especially when collecting call parameter sets. --- .../implementations/saved_query.py | 4 ++-- .../protocols/saved_query.py | 5 +++-- .../validations/saved_query.py | 4 +++- tests/parsing/test_saved_query_parsing.py | 5 +++-- tests/validations/test_saved_query.py | 17 +++++++++++++---- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/dbt_semantic_interfaces/implementations/saved_query.py b/dbt_semantic_interfaces/implementations/saved_query.py index 6ff709b9..53de5038 100644 --- a/dbt_semantic_interfaces/implementations/saved_query.py +++ b/dbt_semantic_interfaces/implementations/saved_query.py @@ -9,7 +9,7 @@ ModelWithMetadataParsing, ) from dbt_semantic_interfaces.implementations.filters.where_filter import ( - PydanticWhereFilter, + PydanticWhereFilterIntersection, ) from dbt_semantic_interfaces.implementations.metadata import PydanticMetadata from dbt_semantic_interfaces.protocols import ProtocolHint @@ -26,7 +26,7 @@ def _implements_protocol(self) -> SavedQuery: name: str metrics: List[str] group_bys: List[str] = [] - where: List[PydanticWhereFilter] = [] + where: Optional[PydanticWhereFilterIntersection] = None description: Optional[str] = None metadata: Optional[PydanticMetadata] = None diff --git a/dbt_semantic_interfaces/protocols/saved_query.py b/dbt_semantic_interfaces/protocols/saved_query.py index 2018b164..3bd739d9 100644 --- a/dbt_semantic_interfaces/protocols/saved_query.py +++ b/dbt_semantic_interfaces/protocols/saved_query.py @@ -2,7 +2,7 @@ from typing import Optional, Protocol, Sequence from dbt_semantic_interfaces.protocols.metadata import Metadata -from dbt_semantic_interfaces.protocols.where_filter import WhereFilter +from dbt_semantic_interfaces.protocols.where_filter import WhereFilterIntersection class SavedQuery(Protocol): @@ -35,7 +35,8 @@ def group_bys(self) -> Sequence[str]: # noqa: D @property @abstractmethod - def where(self) -> Sequence[WhereFilter]: # noqa: D + def where(self) -> Optional[WhereFilterIntersection]: + """Returns the intersection class containing any where filters specified in the saved query.""" pass @property diff --git a/dbt_semantic_interfaces/validations/saved_query.py b/dbt_semantic_interfaces/validations/saved_query.py index 0b2ecd4c..f9abd7f4 100644 --- a/dbt_semantic_interfaces/validations/saved_query.py +++ b/dbt_semantic_interfaces/validations/saved_query.py @@ -101,7 +101,9 @@ def _check_metrics(valid_metric_names: Set[str], saved_query: SavedQuery) -> Seq @validate_safely("Validate the where field in a saved query.") def _check_where(saved_query: SavedQuery) -> Sequence[ValidationIssue]: issues: List[ValidationIssue] = [] - for where_filter in saved_query.where: + if saved_query.where is None: + return issues + for where_filter in saved_query.where.where_filters: try: where_filter.call_parameter_sets except Exception as e: diff --git a/tests/parsing/test_saved_query_parsing.py b/tests/parsing/test_saved_query_parsing.py index 596ee66a..95b0e6aa 100644 --- a/tests/parsing/test_saved_query_parsing.py +++ b/tests/parsing/test_saved_query_parsing.py @@ -131,5 +131,6 @@ def test_saved_query_where() -> None: build_result = parse_yaml_files_to_semantic_manifest(files=[file, EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE]) assert len(build_result.semantic_manifest.saved_queries) == 1 saved_query = build_result.semantic_manifest.saved_queries[0] - assert len(saved_query.where) == 1 - assert where == saved_query.where[0].where_sql_template + assert saved_query.where is not None + assert len(saved_query.where.where_filters) == 1 + assert where == saved_query.where.where_filters[0].where_sql_template diff --git a/tests/validations/test_saved_query.py b/tests/validations/test_saved_query.py index c6ae46f4..89ba8289 100644 --- a/tests/validations/test_saved_query.py +++ b/tests/validations/test_saved_query.py @@ -3,6 +3,7 @@ from dbt_semantic_interfaces.implementations.filters.where_filter import ( PydanticWhereFilter, + PydanticWhereFilterIntersection, ) from dbt_semantic_interfaces.implementations.saved_query import PydanticSavedQuery from dbt_semantic_interfaces.implementations.semantic_manifest import ( @@ -44,7 +45,9 @@ def test_invalid_metric_in_saved_query( # noqa: D description="Example description.", metrics=["invalid_metric"], group_bys=["Dimension('booking__is_instant')"], - where=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + ), ), ] @@ -64,7 +67,9 @@ def test_invalid_where_in_saved_query( # noqa: D description="Example description.", metrics=["bookings"], group_bys=["Dimension('booking__is_instant')"], - where=[PydanticWhereFilter(where_sql_template="{{ invalid_jinja }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ invalid_jinja }}")], + ), ), ] @@ -85,7 +90,9 @@ def test_invalid_group_by_element_in_saved_query( # noqa: D description="Example description.", metrics=["bookings"], group_bys=["Dimension('booking__invalid_dimension')"], - where=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + ), ), ] @@ -106,7 +113,9 @@ def test_invalid_group_by_format_in_saved_query( # noqa: D description="Example description.", metrics=["bookings"], group_bys=["invalid_format"], - where=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + ), ), ]