diff --git a/.changes/unreleased/Features-20231004-102255.yaml b/.changes/unreleased/Features-20231004-102255.yaml new file mode 100644 index 0000000000..ebde51d332 --- /dev/null +++ b/.changes/unreleased/Features-20231004-102255.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support for the Dimension(...).grain(...) syntax for the where parameter +time: 2023-10-04T10:22:55.730467-05:00 +custom: + Author: DevonFulcher + Issue: None diff --git a/metricflow/specs/dimension_spec_resolver.py b/metricflow/specs/dimension_spec_resolver.py new file mode 100644 index 0000000000..b337afbaed --- /dev/null +++ b/metricflow/specs/dimension_spec_resolver.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import Sequence + +from dbt_semantic_interfaces.call_parameter_sets import ( + DimensionCallParameterSet, + FilterCallParameterSets, + TimeDimensionCallParameterSet, +) +from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter +from dbt_semantic_interfaces.references import DimensionReference, EntityReference, TimeDimensionReference +from dbt_semantic_interfaces.type_enums import TimeGranularity + +from metricflow.specs.specs import DimensionSpec, TimeDimensionSpec + + +class DimensionSpecResolver: + """Resolves specs for Dimension & TimeDimension given name, grain, & entity path. 
Utilized in where clause in Jinja syntax.""" + + def __init__(self, call_parameter_sets: FilterCallParameterSets): # noqa + self._call_parameter_sets = call_parameter_sets + + def resolve_dimension_spec(self, name: str, entity_path: Sequence[str]) -> DimensionSpec: + """Resolve Dimension spec with the call_parameter_sets.""" + structured_name = DunderedNameFormatter.parse_name(name) + call_parameter_set = DimensionCallParameterSet( + dimension_reference=DimensionReference(element_name=structured_name.element_name), + entity_path=( + tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links + ), + ) + return DimensionSpec( + element_name=call_parameter_set.dimension_reference.element_name, + entity_links=call_parameter_set.entity_path, + ) + + def resolve_time_dimension_spec( + self, name: str, time_granularity_name: TimeGranularity, entity_path: Sequence[str] + ) -> TimeDimensionSpec: + """Resolve TimeDimension spec with the call_parameter_sets.""" + structured_name = DunderedNameFormatter.parse_name(name) + call_parameter_set = TimeDimensionCallParameterSet( + time_dimension_reference=TimeDimensionReference(element_name=structured_name.element_name), + time_granularity=time_granularity_name, + entity_path=( + tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links + ), + ) + assert call_parameter_set in self._call_parameter_sets.time_dimension_call_parameter_sets + return TimeDimensionSpec( + element_name=call_parameter_set.time_dimension_reference.element_name, + entity_links=call_parameter_set.entity_path, + time_granularity=call_parameter_set.time_granularity, + ) diff --git a/metricflow/specs/where_filter_dimension.py b/metricflow/specs/where_filter_dimension.py index 655fac2053..52720870a9 100644 --- a/metricflow/specs/where_filter_dimension.py +++ b/metricflow/specs/where_filter_dimension.py @@ -1,17 +1,12 @@ from __future__ import annotations -from typing import List, Sequence +from 
typing import List, Optional, Sequence from dbt_semantic_interfaces.call_parameter_sets import ( - DimensionCallParameterSet, FilterCallParameterSets, ) -from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint -from dbt_semantic_interfaces.references import ( - DimensionReference, - EntityReference, -) +from dbt_semantic_interfaces.type_enums import TimeGranularity from typing_extensions import override from metricflow.errors.errors import InvalidQuerySyntax @@ -20,7 +15,8 @@ QueryInterfaceDimensionFactory, ) from metricflow.specs.column_assoc import ColumnAssociationResolver -from metricflow.specs.specs import DimensionSpec +from metricflow.specs.dimension_spec_resolver import DimensionSpecResolver +from metricflow.specs.specs import TimeDimensionSpec class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]): @@ -30,39 +26,49 @@ class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]): def _implements_protocol(self) -> QueryInterfaceDimension: return self - def __init__(self, column_name: str) -> None: # noqa - self.column_name = column_name + def __init__( # noqa + self, + name: str, + entity_path: Sequence[str], + call_parameter_sets: FilterCallParameterSets, + column_association_resolver: ColumnAssociationResolver, + ) -> None: + self._dimension_spec_resolver = DimensionSpecResolver(call_parameter_sets) + self._column_association_resolver = column_association_resolver + self._name = name + self._entity_path = entity_path + self.dimension_spec = self._dimension_spec_resolver.resolve_dimension_spec(name, entity_path) + self.time_dimension_spec: Optional[TimeDimensionSpec] = None - def grain(self, _grain: str) -> QueryInterfaceDimension: + def grain(self, time_granularity_name: str) -> QueryInterfaceDimension: """The time granularity.""" - raise NotImplementedError + self.time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec( + 
self._name, TimeGranularity(time_granularity_name), self._entity_path + ) + return self def date_part(self, _date_part: str) -> QueryInterfaceDimension: """The date_part requested to extract.""" - raise NotImplementedError - - def alias(self, _alias: str) -> QueryInterfaceDimension: - """Renaming the column.""" - raise NotImplementedError + raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter") def descending(self, _is_descending: bool) -> QueryInterfaceDimension: """Set the sort order for order-by.""" - raise InvalidQuerySyntax( - "Can't set descending in the where clause. Try setting descending in the order_by clause instead" - ) + raise InvalidQuerySyntax("descending is invalid in the where parameter") def __str__(self) -> str: """Returns the column name. Important in the Jinja sandbox. """ - return self.column_name + return self._column_association_resolver.resolve_spec( + self.time_dimension_spec or self.dimension_spec + ).column_name class WhereFilterDimensionFactory(ProtocolHint[QueryInterfaceDimensionFactory]): """Creates a WhereFilterDimension. - Each call to `create` adds a DimensionSpec to dimension_specs. + Each call to `create` adds a WhereFilterDimension to created. 
""" @override @@ -76,29 +82,12 @@ def __init__( # noqa ): self._call_parameter_sets = call_parameter_sets self._column_association_resolver = column_association_resolver - self.dimension_specs: List[DimensionSpec] = [] + self.created: List[WhereFilterDimension] = [] def create(self, name: str, entity_path: Sequence[str] = ()) -> WhereFilterDimension: """Create a WhereFilterDimension.""" - structured_name = DunderedNameFormatter.parse_name(name) - call_parameter_set = DimensionCallParameterSet( - dimension_reference=DimensionReference(element_name=structured_name.element_name), - entity_path=( - tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links - ), - ) - assert call_parameter_set in self._call_parameter_sets.dimension_call_parameter_sets - - dimension_spec = self._convert_to_dimension_spec(call_parameter_set) - self.dimension_specs.append(dimension_spec) - column_name = self._column_association_resolver.resolve_spec(dimension_spec).column_name - return WhereFilterDimension(column_name) - - def _convert_to_dimension_spec( - self, - parameter_set: DimensionCallParameterSet, - ) -> DimensionSpec: # noqa: D - return DimensionSpec( - element_name=parameter_set.dimension_reference.element_name, - entity_links=parameter_set.entity_path, + dimension = WhereFilterDimension( + name, entity_path, self._call_parameter_sets, self._column_association_resolver ) + self.created.append(dimension) + return dimension diff --git a/metricflow/specs/where_filter_time_dimension.py b/metricflow/specs/where_filter_time_dimension.py index 0ed4fceea5..7537d4b3f4 100644 --- a/metricflow/specs/where_filter_time_dimension.py +++ b/metricflow/specs/where_filter_time_dimension.py @@ -2,16 +2,15 @@ from typing import List, Optional, Sequence -from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets, TimeDimensionCallParameterSet -from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter +from 
dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint -from dbt_semantic_interfaces.references import EntityReference, TimeDimensionReference -from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity +from dbt_semantic_interfaces.type_enums import TimeGranularity from typing_extensions import override from metricflow.errors.errors import InvalidQuerySyntax from metricflow.protocols.query_interface import QueryInterfaceTimeDimension, QueryInterfaceTimeDimensionFactory from metricflow.specs.column_assoc import ColumnAssociationResolver +from metricflow.specs.dimension_spec_resolver import DimensionSpecResolver from metricflow.specs.specs import TimeDimensionSpec @@ -50,6 +49,7 @@ def __init__( # noqa ): self._call_parameter_sets = call_parameter_sets self._column_association_resolver = column_association_resolver + self._dimension_spec_resolver = DimensionSpecResolver(call_parameter_sets) self.time_dimension_specs: List[TimeDimensionSpec] = [] def create( @@ -67,27 +67,9 @@ def create( ) if date_part_name: raise InvalidQuerySyntax("date_part_name isn't currently supported in the where parameter") - structured_name = DunderedNameFormatter.parse_name(time_dimension_name) - call_parameter_set = TimeDimensionCallParameterSet( - time_dimension_reference=TimeDimensionReference(element_name=structured_name.element_name), - time_granularity=TimeGranularity(time_granularity_name), - entity_path=( - tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links - ), + time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec( + time_dimension_name, TimeGranularity(time_granularity_name), entity_path ) - assert call_parameter_set in self._call_parameter_sets.time_dimension_call_parameter_sets - - time_dimension_spec = self._convert_to_time_dimension_spec(call_parameter_set) 
self.time_dimension_specs.append(time_dimension_spec) - column_names = self._column_association_resolver.resolve_spec(time_dimension_spec).column_name - return WhereFilterTimeDimension(column_names) - - def _convert_to_time_dimension_spec( - self, - parameter_set: TimeDimensionCallParameterSet, - ) -> TimeDimensionSpec: # noqa: D - return TimeDimensionSpec( - element_name=parameter_set.time_dimension_reference.element_name, - entity_links=parameter_set.entity_path, - time_granularity=parameter_set.time_granularity, - ) + column_name = self._column_association_resolver.resolve_spec(time_dimension_spec).column_name + return WhereFilterTimeDimension(column_name) diff --git a/metricflow/specs/where_filter_transform.py b/metricflow/specs/where_filter_transform.py index 0e2937bcd6..3997d31f7d 100644 --- a/metricflow/specs/where_filter_transform.py +++ b/metricflow/specs/where_filter_transform.py @@ -52,11 +52,22 @@ def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec f"Error while rendering Jinja template:\n{where_filter.where_sql_template}" ) from e + """ + Dimensions that are created with a grain parameter, Dimension(...).grain(...), are + added to time_dimension_factory.time_dimension_specs; otherwise they are added to dimension_specs + """ + dimension_specs = [] + for dimension in dimension_factory.created: + if dimension.time_dimension_spec: + time_dimension_factory.time_dimension_specs.append(dimension.time_dimension_spec) + else: + dimension_specs.append(dimension.dimension_spec) + return WhereFilterSpec( where_sql=rendered_sql_template, bind_parameters=self._bind_parameters, linkable_spec_set=LinkableSpecSet( - dimension_specs=tuple(dimension_factory.dimension_specs), + dimension_specs=tuple(dimension_specs), time_dimension_specs=tuple(time_dimension_factory.time_dimension_specs), entity_specs=tuple(entity_factory.entity_specs), ), diff --git a/metricflow/test/model/test_where_filter_spec.py 
b/metricflow/test/model/test_where_filter_spec.py index d338bc634e..428fa17c21 100644 --- a/metricflow/test/model/test_where_filter_spec.py +++ b/metricflow/test/model/test_where_filter_spec.py @@ -37,6 +37,31 @@ def test_dimension_in_filter( # noqa: D ) +def test_dimension_in_filter_with_grain( # noqa: D + column_association_resolver: ColumnAssociationResolver, +) -> None: + where_filter = PydanticWhereFilter( + where_sql_template="{{ Dimension('listing__country_latest').grain('WEEK') }} = 'US'" + ) + + where_filter_spec = WhereSpecFactory( + column_association_resolver=column_association_resolver, + ).create_from_where_filter(where_filter) + + assert where_filter_spec.where_sql == "listing__country_latest__week = 'US'" + assert where_filter_spec.linkable_spec_set == LinkableSpecSet( + dimension_specs=(), + time_dimension_specs=( + TimeDimensionSpec( + element_name="country_latest", + entity_links=(EntityReference(element_name="listing"),), + time_granularity=TimeGranularity.WEEK, + ), + ), + entity_specs=(), + ) + + def test_time_dimension_in_filter( # noqa: D column_association_resolver: ColumnAssociationResolver, ) -> None: diff --git a/metricflow/test/specs/test_where_filter_dimension.py b/metricflow/test/specs/test_where_filter_dimension.py deleted file mode 100644 index a3c673daae..0000000000 --- a/metricflow/test/specs/test_where_filter_dimension.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -import pytest - -from metricflow.errors.errors import InvalidQuerySyntax -from metricflow.specs.where_filter_dimension import WhereFilterDimension - - -def test_descending_cannot_be_set() -> None: # noqa - with pytest.raises(InvalidQuerySyntax): - WhereFilterDimension("bookings").descending(True)