Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Dimension(...).grain(...) in where filter #785

Merged
merged 11 commits into from
Oct 6, 2023
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231004-102255.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support for the Dimension(...).grain(...) syntax for the where parameter
time: 2023-10-04T10:22:55.730467-05:00
custom:
Author: DevonFulcher
Issue: None
54 changes: 54 additions & 0 deletions metricflow/specs/dimension_spec_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from typing import Sequence

from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
FilterCallParameterSets,
TimeDimensionCallParameterSet,
)
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.references import DimensionReference, EntityReference, TimeDimensionReference
from dbt_semantic_interfaces.type_enums import TimeGranularity

from metricflow.specs.specs import DimensionSpec, TimeDimensionSpec


class DimensionSpecResolver:
"""Resolves specs for Dimension & TimeDimension given name, grain, & entity path. Utilized in where clause in Jinja syntax."""

def __init__(self, call_parameter_sets: FilterCallParameterSets): # noqa
self._call_parameter_sets = call_parameter_sets

def resolve_dimension_spec(self, name: str, entity_path: Sequence[str]) -> DimensionSpec:
"""Resolve Dimension spec with the call_parameter_sets."""
structured_name = DunderedNameFormatter.parse_name(name)
call_parameter_set = DimensionCallParameterSet(
dimension_reference=DimensionReference(element_name=structured_name.element_name),
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links
),
)
return DimensionSpec(
element_name=call_parameter_set.dimension_reference.element_name,
entity_links=call_parameter_set.entity_path,
)

def resolve_time_dimension_spec(
self, name: str, time_granularity_name: TimeGranularity, entity_path: Sequence[str]
) -> TimeDimensionSpec:
"""Resolve TimeDimension spec with the call_parameter_sets."""
structured_name = DunderedNameFormatter.parse_name(name)
call_parameter_set = TimeDimensionCallParameterSet(
time_dimension_reference=TimeDimensionReference(element_name=structured_name.element_name),
time_granularity=time_granularity_name,
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links
),
)
assert call_parameter_set in self._call_parameter_sets.time_dimension_call_parameter_sets
return TimeDimensionSpec(
element_name=call_parameter_set.time_dimension_reference.element_name,
entity_links=call_parameter_set.entity_path,
time_granularity=call_parameter_set.time_granularity,
Comment on lines +23 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this logic already exist somewhere in MetricFlow? Or am I confusing that with this logic existing in some other repo?

I might be thinking of the deleted code below, too. 🤷

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure! This code was moved here from elsewhere in this PR, which was also moved there from somewhere that @plypaul put it. So, Paul is the original author, and I just moved it here. He may know if this is duplicative of somewhere else in MF.

)
77 changes: 33 additions & 44 deletions metricflow/specs/where_filter_dimension.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from __future__ import annotations

from typing import List, Sequence
from typing import List, Optional, Sequence

from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
FilterCallParameterSets,
)
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.references import (
DimensionReference,
EntityReference,
)
from dbt_semantic_interfaces.type_enums import TimeGranularity
from typing_extensions import override

from metricflow.errors.errors import InvalidQuerySyntax
Expand All @@ -20,7 +15,8 @@
QueryInterfaceDimensionFactory,
)
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.specs import DimensionSpec
from metricflow.specs.dimension_spec_resolver import DimensionSpecResolver
from metricflow.specs.specs import TimeDimensionSpec


class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]):
Expand All @@ -30,39 +26,49 @@ class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]):
def _implements_protocol(self) -> QueryInterfaceDimension:
return self

def __init__(self, column_name: str) -> None: # noqa
self.column_name = column_name
def __init__( # noqa
self,
name: str,
entity_path: Sequence[str],
call_parameter_sets: FilterCallParameterSets,
column_association_resolver: ColumnAssociationResolver,
) -> None:
self._dimension_spec_resolver = DimensionSpecResolver(call_parameter_sets)
self._column_association_resolver = column_association_resolver
self._name = name
self._entity_path = entity_path
self.dimension_spec = self._dimension_spec_resolver.resolve_dimension_spec(name, entity_path)
self.time_dimension_spec: Optional[TimeDimensionSpec] = None
Comment on lines +40 to +41
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be using the type system to enforce that exactly one of these is set? Like we could make this a union type and keep the property itself internal, and then we get some protection against someone using a time dimension as a non-time dimension, which could get odd.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be using the type system to enforce that exactly one of these is set?

Do you mean something like:

self.spec: Union[DimensionSpec, TimeDimensionSpec] = self._dimension_spec_resolver.resolve_dimension_spec(name, entity_path)

I had started with that, but I think that would require the use of isinstance to disambiguate between the two types later when they are added to the lists. Paul steered me away from the use of isinstance.

protection against someone using a time dimension as a non-time dimension

Do you mean specifying grain on a non-time dimension? Or not specifying grain on a time dimension? I had thought that is validated after the specs are returned from create_from_where_filter. Is that not the case? I agree that should be checked somewhere


def grain(self, _grain: str) -> QueryInterfaceDimension:
def grain(self, time_granularity_name: str) -> QueryInterfaceDimension:
"""The time granularity."""
raise NotImplementedError
self.time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec(
self._name, TimeGranularity(time_granularity_name), self._entity_path
)
return self
Comment on lines +41 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, right, this has to be here because it's the input protocol. The output QueryParameter or whatever it's called has these organized a bit better. Do I have that right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify what you mean by this and these 😃?


def date_part(self, _date_part: str) -> QueryInterfaceDimension:
"""The date_part requested to extract."""
raise NotImplementedError

def alias(self, _alias: str) -> QueryInterfaceDimension:
"""Renaming the column."""
raise NotImplementedError
raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter")

def descending(self, _is_descending: bool) -> QueryInterfaceDimension:
"""Set the sort order for order-by."""
raise InvalidQuerySyntax(
"Can't set descending in the where clause. Try setting descending in the order_by clause instead"
)
raise InvalidQuerySyntax("descending is invalid in the where parameter")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm. Seems like the protocol specs need more refinement, but this'll have to do for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the spec is fine. We don't have an LSP for the Jinja syntax, so I think providing a more specific error message will help our users more than a generic type error.


def __str__(self) -> str:
"""Returns the column name.

Important in the Jinja sandbox.
"""
return self.column_name
return self._column_association_resolver.resolve_spec(
self.time_dimension_spec or self.dimension_spec
).column_name


class WhereFilterDimensionFactory(ProtocolHint[QueryInterfaceDimensionFactory]):
"""Creates a WhereFilterDimension.

Each call to `create` adds a DimensionSpec to dimension_specs.
Each call to `create` adds a WhereFilterDimension to created.
"""

@override
Expand All @@ -76,29 +82,12 @@ def __init__( # noqa
):
self._call_parameter_sets = call_parameter_sets
self._column_association_resolver = column_association_resolver
self.dimension_specs: List[DimensionSpec] = []
self.created: List[WhereFilterDimension] = []

def create(self, name: str, entity_path: Sequence[str] = ()) -> WhereFilterDimension:
"""Create a WhereFilterDimension."""
structured_name = DunderedNameFormatter.parse_name(name)
call_parameter_set = DimensionCallParameterSet(
dimension_reference=DimensionReference(element_name=structured_name.element_name),
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links
),
)
assert call_parameter_set in self._call_parameter_sets.dimension_call_parameter_sets

dimension_spec = self._convert_to_dimension_spec(call_parameter_set)
self.dimension_specs.append(dimension_spec)
column_name = self._column_association_resolver.resolve_spec(dimension_spec).column_name
return WhereFilterDimension(column_name)

def _convert_to_dimension_spec(
self,
parameter_set: DimensionCallParameterSet,
) -> DimensionSpec: # noqa: D
return DimensionSpec(
element_name=parameter_set.dimension_reference.element_name,
entity_links=parameter_set.entity_path,
dimension = WhereFilterDimension(
name, entity_path, self._call_parameter_sets, self._column_association_resolver
)
self.created.append(dimension)
return dimension
34 changes: 8 additions & 26 deletions metricflow/specs/where_filter_time_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

from typing import List, Optional, Sequence

from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets, TimeDimensionCallParameterSet
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.references import EntityReference, TimeDimensionReference
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from dbt_semantic_interfaces.type_enums import TimeGranularity
from typing_extensions import override

from metricflow.errors.errors import InvalidQuerySyntax
from metricflow.protocols.query_interface import QueryInterfaceTimeDimension, QueryInterfaceTimeDimensionFactory
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.dimension_spec_resolver import DimensionSpecResolver
from metricflow.specs.specs import TimeDimensionSpec


Expand Down Expand Up @@ -50,6 +49,7 @@ def __init__( # noqa
):
self._call_parameter_sets = call_parameter_sets
self._column_association_resolver = column_association_resolver
self._dimension_spec_resolver = DimensionSpecResolver(call_parameter_sets)
self.time_dimension_specs: List[TimeDimensionSpec] = []

def create(
Expand All @@ -67,27 +67,9 @@ def create(
)
if date_part_name:
raise InvalidQuerySyntax("date_part_name isn't currently supported in the where parameter")
structured_name = DunderedNameFormatter.parse_name(time_dimension_name)
call_parameter_set = TimeDimensionCallParameterSet(
time_dimension_reference=TimeDimensionReference(element_name=structured_name.element_name),
time_granularity=TimeGranularity(time_granularity_name),
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links
),
time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec(
time_dimension_name, TimeGranularity(time_granularity_name), entity_path
)
assert call_parameter_set in self._call_parameter_sets.time_dimension_call_parameter_sets

time_dimension_spec = self._convert_to_time_dimension_spec(call_parameter_set)
self.time_dimension_specs.append(time_dimension_spec)
column_names = self._column_association_resolver.resolve_spec(time_dimension_spec).column_name
return WhereFilterTimeDimension(column_names)

def _convert_to_time_dimension_spec(
self,
parameter_set: TimeDimensionCallParameterSet,
) -> TimeDimensionSpec: # noqa: D
return TimeDimensionSpec(
element_name=parameter_set.time_dimension_reference.element_name,
entity_links=parameter_set.entity_path,
time_granularity=parameter_set.time_granularity,
)
column_name = self._column_association_resolver.resolve_spec(time_dimension_spec).column_name
return WhereFilterTimeDimension(column_name)
13 changes: 12 additions & 1 deletion metricflow/specs/where_filter_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,22 @@ def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec
f"Error while rendering Jinja template:\n{where_filter.where_sql_template}"
) from e

"""
Dimensions that are created with a grain parameter, Dimension(...).grain(...), are
added to dimension_specs otherwise they are add to time_dimension_factory.time_dimension_specs
"""
dimension_specs = []
for dimension in dimension_factory.created:
if dimension.time_dimension_spec:
time_dimension_factory.time_dimension_specs.append(dimension.time_dimension_spec)
else:
dimension_specs.append(dimension.dimension_spec)

return WhereFilterSpec(
where_sql=rendered_sql_template,
bind_parameters=self._bind_parameters,
linkable_spec_set=LinkableSpecSet(
dimension_specs=tuple(dimension_factory.dimension_specs),
dimension_specs=tuple(dimension_specs),
time_dimension_specs=tuple(time_dimension_factory.time_dimension_specs),
entity_specs=tuple(entity_factory.entity_specs),
),
Expand Down
25 changes: 25 additions & 0 deletions metricflow/test/model/test_where_filter_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,31 @@ def test_dimension_in_filter( # noqa: D
)


def test_dimension_in_filter_with_grain( # noqa: D
column_association_resolver: ColumnAssociationResolver,
) -> None:
where_filter = PydanticWhereFilter(
where_sql_template="{{ Dimension('listing__country_latest').grain('WEEK') }} = 'US'"
)

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(where_filter)

assert where_filter_spec.where_sql == "listing__country_latest__week = 'US'"
assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
dimension_specs=(),
time_dimension_specs=(
TimeDimensionSpec(
element_name="country_latest",
entity_links=(EntityReference(element_name="listing"),),
time_granularity=TimeGranularity.WEEK,
),
),
entity_specs=(),
)


def test_time_dimension_in_filter( # noqa: D
column_association_resolver: ColumnAssociationResolver,
) -> None:
Expand Down
11 changes: 0 additions & 11 deletions metricflow/test/specs/test_where_filter_dimension.py

This file was deleted.