Skip to content

Commit

Permalink
Support for Dimension(...).grain(...) in where filter (#785)
Browse files Browse the repository at this point in the history
This PR adds support for the Dimension(...).grain(...) syntax for the where parameter and filter spec by using the factory pattern & implementing the Query Interface protocols.
  • Loading branch information
DevonFulcher authored Oct 6, 2023
1 parent acbeeef commit dbbff3e
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 82 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231004-102255.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support for the Dimension(...).grain(...) syntax for the where parameter
time: 2023-10-04T10:22:55.730467-05:00
custom:
Author: DevonFulcher
Issue: None
54 changes: 54 additions & 0 deletions metricflow/specs/dimension_spec_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from typing import Sequence

from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
FilterCallParameterSets,
TimeDimensionCallParameterSet,
)
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.references import DimensionReference, EntityReference, TimeDimensionReference
from dbt_semantic_interfaces.type_enums import TimeGranularity

from metricflow.specs.specs import DimensionSpec, TimeDimensionSpec


class DimensionSpecResolver:
"""Resolves specs for Dimension & TimeDimension given name, grain, & entity path. Utilized in where clause in Jinja syntax."""

def __init__(self, call_parameter_sets: FilterCallParameterSets): # noqa
self._call_parameter_sets = call_parameter_sets

def resolve_dimension_spec(self, name: str, entity_path: Sequence[str]) -> DimensionSpec:
"""Resolve Dimension spec with the call_parameter_sets."""
structured_name = DunderedNameFormatter.parse_name(name)
call_parameter_set = DimensionCallParameterSet(
dimension_reference=DimensionReference(element_name=structured_name.element_name),
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links
),
)
return DimensionSpec(
element_name=call_parameter_set.dimension_reference.element_name,
entity_links=call_parameter_set.entity_path,
)

def resolve_time_dimension_spec(
self, name: str, time_granularity_name: TimeGranularity, entity_path: Sequence[str]
) -> TimeDimensionSpec:
"""Resolve TimeDimension spec with the call_parameter_sets."""
structured_name = DunderedNameFormatter.parse_name(name)
call_parameter_set = TimeDimensionCallParameterSet(
time_dimension_reference=TimeDimensionReference(element_name=structured_name.element_name),
time_granularity=time_granularity_name,
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links
),
)
assert call_parameter_set in self._call_parameter_sets.time_dimension_call_parameter_sets
return TimeDimensionSpec(
element_name=call_parameter_set.time_dimension_reference.element_name,
entity_links=call_parameter_set.entity_path,
time_granularity=call_parameter_set.time_granularity,
)
77 changes: 33 additions & 44 deletions metricflow/specs/where_filter_dimension.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from __future__ import annotations

from typing import List, Sequence
from typing import List, Optional, Sequence

from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
FilterCallParameterSets,
)
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.references import (
DimensionReference,
EntityReference,
)
from dbt_semantic_interfaces.type_enums import TimeGranularity
from typing_extensions import override

from metricflow.errors.errors import InvalidQuerySyntax
Expand All @@ -20,7 +15,8 @@
QueryInterfaceDimensionFactory,
)
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.specs import DimensionSpec
from metricflow.specs.dimension_spec_resolver import DimensionSpecResolver
from metricflow.specs.specs import TimeDimensionSpec


class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]):
Expand All @@ -30,39 +26,49 @@ class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]):
def _implements_protocol(self) -> QueryInterfaceDimension:
return self

def __init__(self, column_name: str) -> None: # noqa
self.column_name = column_name
def __init__( # noqa
self,
name: str,
entity_path: Sequence[str],
call_parameter_sets: FilterCallParameterSets,
column_association_resolver: ColumnAssociationResolver,
) -> None:
self._dimension_spec_resolver = DimensionSpecResolver(call_parameter_sets)
self._column_association_resolver = column_association_resolver
self._name = name
self._entity_path = entity_path
self.dimension_spec = self._dimension_spec_resolver.resolve_dimension_spec(name, entity_path)
self.time_dimension_spec: Optional[TimeDimensionSpec] = None

def grain(self, _grain: str) -> QueryInterfaceDimension:
def grain(self, time_granularity_name: str) -> QueryInterfaceDimension:
"""The time granularity."""
raise NotImplementedError
self.time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec(
self._name, TimeGranularity(time_granularity_name), self._entity_path
)
return self

def date_part(self, _date_part: str) -> QueryInterfaceDimension:
"""The date_part requested to extract."""
raise NotImplementedError

def alias(self, _alias: str) -> QueryInterfaceDimension:
"""Renaming the column."""
raise NotImplementedError
raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter")

def descending(self, _is_descending: bool) -> QueryInterfaceDimension:
"""Set the sort order for order-by."""
raise InvalidQuerySyntax(
"Can't set descending in the where clause. Try setting descending in the order_by clause instead"
)
raise InvalidQuerySyntax("descending is invalid in the where parameter")

def __str__(self) -> str:
"""Returns the column name.
Important in the Jinja sandbox.
"""
return self.column_name
return self._column_association_resolver.resolve_spec(
self.time_dimension_spec or self.dimension_spec
).column_name


class WhereFilterDimensionFactory(ProtocolHint[QueryInterfaceDimensionFactory]):
"""Creates a WhereFilterDimension.
Each call to `create` adds a DimensionSpec to dimension_specs.
Each call to `create` adds a WhereFilterDimension to created.
"""

@override
Expand All @@ -76,29 +82,12 @@ def __init__( # noqa
):
self._call_parameter_sets = call_parameter_sets
self._column_association_resolver = column_association_resolver
self.dimension_specs: List[DimensionSpec] = []
self.created: List[WhereFilterDimension] = []

def create(self, name: str, entity_path: Sequence[str] = ()) -> WhereFilterDimension:
"""Create a WhereFilterDimension."""
structured_name = DunderedNameFormatter.parse_name(name)
call_parameter_set = DimensionCallParameterSet(
dimension_reference=DimensionReference(element_name=structured_name.element_name),
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links
),
)
assert call_parameter_set in self._call_parameter_sets.dimension_call_parameter_sets

dimension_spec = self._convert_to_dimension_spec(call_parameter_set)
self.dimension_specs.append(dimension_spec)
column_name = self._column_association_resolver.resolve_spec(dimension_spec).column_name
return WhereFilterDimension(column_name)

def _convert_to_dimension_spec(
self,
parameter_set: DimensionCallParameterSet,
) -> DimensionSpec: # noqa: D
return DimensionSpec(
element_name=parameter_set.dimension_reference.element_name,
entity_links=parameter_set.entity_path,
dimension = WhereFilterDimension(
name, entity_path, self._call_parameter_sets, self._column_association_resolver
)
self.created.append(dimension)
return dimension
34 changes: 8 additions & 26 deletions metricflow/specs/where_filter_time_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

from typing import List, Optional, Sequence

from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets, TimeDimensionCallParameterSet
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.references import EntityReference, TimeDimensionReference
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from dbt_semantic_interfaces.type_enums import TimeGranularity
from typing_extensions import override

from metricflow.errors.errors import InvalidQuerySyntax
from metricflow.protocols.query_interface import QueryInterfaceTimeDimension, QueryInterfaceTimeDimensionFactory
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.dimension_spec_resolver import DimensionSpecResolver
from metricflow.specs.specs import TimeDimensionSpec


Expand Down Expand Up @@ -50,6 +49,7 @@ def __init__( # noqa
):
self._call_parameter_sets = call_parameter_sets
self._column_association_resolver = column_association_resolver
self._dimension_spec_resolver = DimensionSpecResolver(call_parameter_sets)
self.time_dimension_specs: List[TimeDimensionSpec] = []

def create(
Expand All @@ -67,27 +67,9 @@ def create(
)
if date_part_name:
raise InvalidQuerySyntax("date_part_name isn't currently supported in the where parameter")
structured_name = DunderedNameFormatter.parse_name(time_dimension_name)
call_parameter_set = TimeDimensionCallParameterSet(
time_dimension_reference=TimeDimensionReference(element_name=structured_name.element_name),
time_granularity=TimeGranularity(time_granularity_name),
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links
),
time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec(
time_dimension_name, TimeGranularity(time_granularity_name), entity_path
)
assert call_parameter_set in self._call_parameter_sets.time_dimension_call_parameter_sets

time_dimension_spec = self._convert_to_time_dimension_spec(call_parameter_set)
self.time_dimension_specs.append(time_dimension_spec)
column_names = self._column_association_resolver.resolve_spec(time_dimension_spec).column_name
return WhereFilterTimeDimension(column_names)

def _convert_to_time_dimension_spec(
self,
parameter_set: TimeDimensionCallParameterSet,
) -> TimeDimensionSpec: # noqa: D
return TimeDimensionSpec(
element_name=parameter_set.time_dimension_reference.element_name,
entity_links=parameter_set.entity_path,
time_granularity=parameter_set.time_granularity,
)
column_name = self._column_association_resolver.resolve_spec(time_dimension_spec).column_name
return WhereFilterTimeDimension(column_name)
13 changes: 12 additions & 1 deletion metricflow/specs/where_filter_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,22 @@ def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec
f"Error while rendering Jinja template:\n{where_filter.where_sql_template}"
) from e

"""
Dimensions that are created with a grain parameter, Dimension(...).grain(...), are
added to dimension_specs otherwise they are add to time_dimension_factory.time_dimension_specs
"""
dimension_specs = []
for dimension in dimension_factory.created:
if dimension.time_dimension_spec:
time_dimension_factory.time_dimension_specs.append(dimension.time_dimension_spec)
else:
dimension_specs.append(dimension.dimension_spec)

return WhereFilterSpec(
where_sql=rendered_sql_template,
bind_parameters=self._bind_parameters,
linkable_spec_set=LinkableSpecSet(
dimension_specs=tuple(dimension_factory.dimension_specs),
dimension_specs=tuple(dimension_specs),
time_dimension_specs=tuple(time_dimension_factory.time_dimension_specs),
entity_specs=tuple(entity_factory.entity_specs),
),
Expand Down
25 changes: 25 additions & 0 deletions metricflow/test/model/test_where_filter_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,31 @@ def test_dimension_in_filter( # noqa: D
)


def test_dimension_in_filter_with_grain( # noqa: D
column_association_resolver: ColumnAssociationResolver,
) -> None:
where_filter = PydanticWhereFilter(
where_sql_template="{{ Dimension('listing__country_latest').grain('WEEK') }} = 'US'"
)

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(where_filter)

assert where_filter_spec.where_sql == "listing__country_latest__week = 'US'"
assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
dimension_specs=(),
time_dimension_specs=(
TimeDimensionSpec(
element_name="country_latest",
entity_links=(EntityReference(element_name="listing"),),
time_granularity=TimeGranularity.WEEK,
),
),
entity_specs=(),
)


def test_time_dimension_in_filter( # noqa: D
column_association_resolver: ColumnAssociationResolver,
) -> None:
Expand Down
11 changes: 0 additions & 11 deletions metricflow/test/specs/test_where_filter_dimension.py

This file was deleted.

0 comments on commit dbbff3e

Please sign in to comment.