Skip to content

Commit

Permalink
Update MF to support breaking changes in DSI for custom calendar (#1522)
Browse files Browse the repository at this point in the history
## Context

We merged 2 breaking changes in DSI
dbt-labs/dbt-semantic-interfaces#363 and
dbt-labs/dbt-semantic-interfaces#365 which
changed most spec typing that used time granularity to be a `str`
instead of `TimeGranularity` to enable support for custom granularity.
Similarly, there were additional breaking changes to the objects that
requires passing in `custom_granularity_names`. This PR updates all
those callsites to be compatible with the new version of DSI (to be
released)

Resolves SL-3097
  • Loading branch information
WilliamDee authored Dec 5, 2024
1 parent 5bb1351 commit 5cd5d40
Show file tree
Hide file tree
Showing 45 changed files with 597 additions and 461 deletions.
2 changes: 1 addition & 1 deletion extra-hatch-configuration/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Jinja2>=3.1.3
dbt-semantic-interfaces==0.7.2
dbt-semantic-interfaces==0.8.3
more-itertools>=8.10.0, <10.2.0
pydantic>=1.10.0, <3.0
tabulate>=0.8.9
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Always support a range of production DSI versions capped at the next breaking version in metricflow-semantics.
# This allows us to sync new, non-breaking changes to dbt-core without getting a version mismatch in dbt-mantle,
# which depends on a specific commit of DSI.
dbt-semantic-interfaces>=0.7.2, <0.8.0
dbt-semantic-interfaces>=0.8.3, <0.9.0
graphviz>=0.18.2, <0.21
python-dateutil>=2.9.0, <2.10.0
rapidfuzz>=3.0, <4.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from typing import Optional

from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity


def error_if_not_standard_grain(input_granularity: str, context: Optional[str] = None) -> TimeGranularity:
"""Cast input grainularity string to TimeGranularity, otherwise error.
TODO: Not needed once, custom grain is supported for most things.
"""
try:
time_grain = TimeGranularity(input_granularity)
except ValueError:
error_msg = f"Received a non-standard time granularity, which is not supported at the moment, received: {input_granularity}."
if context:
error_msg += f"\nContext: {context}"
raise ValueError(error_msg)
return time_grain
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class DunderNamingScheme(QueryItemNamingScheme):
"""A naming scheme using the dundered name syntax.
TODO: Consolidate with StructuredLinkableSpecName / DunderedNameFormatter.
TODO: Consolidate with StructuredLinkableSpecName.
"""

_INPUT_REGEX = re.compile(r"\A[a-z]([a-z0-9_])*[a-z0-9]\Z")
Expand Down Expand Up @@ -52,7 +52,7 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]:

@override
def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> EntityLinkPattern:
if not self.input_str_follows_scheme(input_str):
if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup):
raise ValueError(f"{repr(input_str)} does not follow this scheme.")

input_str = input_str.lower()
Expand Down Expand Up @@ -119,7 +119,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
)

@override
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
# This naming scheme is case-insensitive.
input_str = input_str.lower()
if DunderNamingScheme._INPUT_REGEX.match(input_str) is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import lru_cache
from typing import Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

Expand Down Expand Up @@ -106,6 +107,11 @@ def date_part_suffix(date_part: DatePart) -> str:
"""Suffix used for names with a date_part."""
return f"extract_{date_part.value}"

@property
def entity_links(self) -> Tuple[EntityReference, ...]:
"""Returns the entity link references."""
return tuple(EntityReference(entity_link_name.lower()) for entity_link_name in self.entity_link_names)

@property
def granularity_free_qualified_name(self) -> str:
"""Renders the qualified name without the granularity suffix.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]:
@override
def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> MetricSpecPattern:
input_str = input_str.lower()
if not self.input_str_follows_scheme(input_str):
if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup):
raise RuntimeError(f"{repr(input_str)} does not follow this scheme.")
return MetricSpecPattern(metric_reference=MetricReference(element_name=input_str))

@override
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
# TODO: Use regex.
return True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
pass

@abstractmethod
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
"""Returns true if the given input string follows this naming scheme.
Consider adding a structured result that indicates why it does not match the scheme.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]:

@override
def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> SpecPattern:
if not self.input_str_follows_scheme(input_str):
if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup):
raise ValueError(
f"The specified input {repr(input_str)} does not match the input described by the object builder "
f"pattern."
)
try:
# TODO: Update when more appropriate parsing libraries are available.
call_parameter_sets = PydanticWhereFilter(where_sql_template="{{ " + input_str + " }}").call_parameter_sets
call_parameter_sets = PydanticWhereFilter(where_sql_template="{{ " + input_str + " }}").call_parameter_sets(
custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names
)
except ParseWhereFilterException as e:
raise ValueError(f"A spec pattern can't be generated from the input string {repr(input_str)}") from e

Expand Down Expand Up @@ -121,11 +123,14 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
raise RuntimeError("There should have been a return associated with one of the CallParameterSets.")

@override
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
if ObjectBuilderNamingScheme._NAME_REGEX.match(input_str) is None:
return False
try:
call_parameter_sets = WhereFilterParser.parse_call_parameter_sets("{{ " + input_str + " }}")
call_parameter_sets = WhereFilterParser.parse_call_parameter_sets(
where_sql_template="{{ " + input_str + " }}",
custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names,
)
return_value = (
len(call_parameter_sets.dimension_call_parameter_sets)
+ len(call_parameter_sets.time_dimension_call_parameter_sets)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow_semantics.errors.custom_grain_not_supported import error_if_not_standard_grain
from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.mf_logging.pretty_print import mf_pformat, mf_pformat_dict
Expand Down Expand Up @@ -401,7 +402,13 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> PushDownRe

# If time granularity is not set for the metric, defaults to DAY if available, else the smallest available granularity.
# Note: ignores any granularity set on input metrics.
metric_default_time_granularity = metric_to_use_for_time_granularity_resolution.time_granularity or max(
metric_time_granularity: Optional[TimeGranularity] = None
if metric_to_use_for_time_granularity_resolution.time_granularity is not None:
metric_time_granularity = error_if_not_standard_grain(
context=f"Metric({metric_to_use_for_time_granularity_resolution}).time_granularity",
input_granularity=metric_to_use_for_time_granularity_resolution.time_granularity,
)
metric_default_time_granularity = metric_time_granularity or max(
TimeGranularity.DAY,
self._semantic_manifest_lookup.metric_lookup.get_min_queryable_time_granularity(
MetricReference(metric_to_use_for_time_granularity_resolution.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,9 @@ def _resolve_specs_for_where_filters(
for location, where_filters in where_filters_and_locations.items():
for where_filter in where_filters:
try:
filter_call_parameter_sets = where_filter.call_parameter_sets
filter_call_parameter_sets = where_filter.call_parameter_sets(
custom_granularity_names=self._manifest_lookup.semantic_model_lookup.custom_granularity_names
)
except Exception as e:
non_parsable_resolutions.append(
NonParsableFilterResolution(
Expand Down
16 changes: 12 additions & 4 deletions metricflow-semantics/metricflow_semantics/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def _parse_order_by_names(
order_by_name_without_prefix = order_by_name

for group_by_item_naming_scheme in self._group_by_item_naming_schemes:
if group_by_item_naming_scheme.input_str_follows_scheme(order_by_name_without_prefix):
if group_by_item_naming_scheme.input_str_follows_scheme(
order_by_name_without_prefix, semantic_manifest_lookup=self._manifest_lookup
):
possible_inputs.append(
ResolverInputForGroupByItem(
input_obj=order_by_name,
Expand All @@ -223,7 +225,9 @@ def _parse_order_by_names(
break

for metric_naming_scheme in self._metric_naming_schemes:
if metric_naming_scheme.input_str_follows_scheme(order_by_name_without_prefix):
if metric_naming_scheme.input_str_follows_scheme(
order_by_name_without_prefix, semantic_manifest_lookup=self._manifest_lookup
):
possible_inputs.append(
ResolverInputForMetric(
input_obj=order_by_name,
Expand Down Expand Up @@ -373,7 +377,9 @@ def _parse_and_validate_query(
for metric_name in metric_names:
resolver_input_for_metric: Optional[MetricFlowQueryResolverInput] = None
for metric_naming_scheme in self._metric_naming_schemes:
if metric_naming_scheme.input_str_follows_scheme(metric_name):
if metric_naming_scheme.input_str_follows_scheme(
metric_name, semantic_manifest_lookup=self._manifest_lookup
):
resolver_input_for_metric = ResolverInputForMetric(
input_obj=metric_name,
naming_scheme=metric_naming_scheme,
Expand Down Expand Up @@ -405,7 +411,9 @@ def _parse_and_validate_query(
for group_by_name in group_by_names:
resolver_input_for_group_by_item: Optional[MetricFlowQueryResolverInput] = None
for group_by_item_naming_scheme in self._group_by_item_naming_schemes:
if group_by_item_naming_scheme.input_str_follows_scheme(group_by_name):
if group_by_item_naming_scheme.input_str_follows_scheme(
group_by_name, semantic_manifest_lookup=self._manifest_lookup
):
spec_pattern = group_by_item_naming_scheme.spec_pattern(
group_by_name, semantic_manifest_lookup=self._manifest_lookup
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
DimensionCallParameterSet,
TimeDimensionCallParameterSet,
)
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.protocols.query_interface import (
QueryInterfaceDimension,
Expand All @@ -18,6 +17,7 @@
from typing_extensions import override

from metricflow_semantics.errors.error_classes import InvalidQuerySyntax
from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolutionLookUp,
Expand Down Expand Up @@ -134,15 +134,19 @@ def __init__( # noqa
spec_resolution_lookup: FilterSpecResolutionLookUp,
where_filter_location: WhereFilterLocation,
rendered_spec_tracker: RenderedSpecTracker,
custom_granularity_names: Sequence[str],
):
self._column_association_resolver = column_association_resolver
self._resolved_spec_lookup = spec_resolution_lookup
self._where_filter_location = where_filter_location
self._rendered_spec_tracker = rendered_spec_tracker
self._custom_granularity_names = custom_granularity_names

def create(self, name: str, entity_path: Sequence[str] = ()) -> WhereFilterDimension:
"""Create a WhereFilterDimension."""
structured_name = DunderedNameFormatter.parse_name(name.lower())
structured_name = StructuredLinkableSpecName.from_name(
name.lower(), custom_granularity_names=self._custom_granularity_names
)

return WhereFilterDimension(
column_association_resolver=self._column_association_resolver,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from dbt_semantic_interfaces.call_parameter_sets import (
EntityCallParameterSet,
)
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.protocols.query_interface import QueryInterfaceEntity, QueryInterfaceEntityFactory
from dbt_semantic_interfaces.references import EntityReference
Expand All @@ -14,6 +13,7 @@
from typing_extensions import override

from metricflow_semantics.errors.error_classes import InvalidQuerySyntax
from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow_semantics.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolutionLookUp,
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__( # noqa

def create(self, entity_name: str, entity_path: Sequence[str] = ()) -> WhereFilterEntity:
"""Create a WhereFilterEntity."""
structured_name = DunderedNameFormatter.parse_name(entity_name.lower())
structured_name = StructuredLinkableSpecName.from_name(entity_name.lower(), custom_granularity_names=())

return WhereFilterEntity(
column_association_resolver=self._column_association_resolver,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def create_from_where_filter_intersection( # noqa: D102
spec_resolution_lookup=self._spec_resolution_lookup,
where_filter_location=filter_location,
rendered_spec_tracker=rendered_spec_tracker,
custom_granularity_names=self._semantic_model_lookup.custom_granularity_names,
)
time_dimension_factory = WhereFilterTimeDimensionFactory(
column_association_resolver=self._column_association_resolver,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,16 @@ metric:
- name: bookings_offset_once
offset_window: 2 days
---
metric:
name: "bookings_offset_martian_day"
description: bookings metric offset by a martian day.
type: derived
type_params:
expr: 2 * bookings
metrics:
- name: bookings
offset_window: 1 martian_day
---
metric:
name: bookings_at_start_of_month
description: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging

from dbt_semantic_interfaces.implementations.metric import PydanticMetricTimeWindow
from dbt_semantic_interfaces.references import MetricReference
from dbt_semantic_interfaces.type_enums import TimeGranularity
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
Expand All @@ -21,3 +22,18 @@ def test_min_queryable_time_granularity_for_different_agg_time_grains( # noqa:
# Since `monthly_bookings_to_daily_bookings` is based on metrics with DAY and MONTH aggregation time grains,
# the minimum queryable grain should be MONTH.
assert min_queryable_grain == TimeGranularity.MONTH


def test_custom_offset_window_for_metric(
simple_semantic_manifest_lookup: SemanticManifestLookup,
) -> None:
"""Test offset window with custom grain supplied.
TODO: As of now, the functionality of an offset window with a custom grain is not supported in MF.
This test is added to show that at least the parsing is successful using a custom grain offset window.
Once support for that is added in MF + relevant tests, this test can be removed.
"""
metric = simple_semantic_manifest_lookup.metric_lookup.get_metric(MetricReference("bookings_offset_martian_day"))

assert len(metric.input_metrics) == 1
assert metric.input_metrics[0].offset_window == PydanticMetricTimeWindow(count=1, granularity="martian_day")
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,28 @@ def test_input_str(dunder_naming_scheme: DunderNamingScheme) -> None: # noqa: D
)


def test_input_follows_scheme(dunder_naming_scheme: DunderNamingScheme) -> None: # noqa: D103
assert dunder_naming_scheme.input_str_follows_scheme("listing__country")
assert dunder_naming_scheme.input_str_follows_scheme("listing__creation_time__month")
assert dunder_naming_scheme.input_str_follows_scheme("booking__listing")
assert not dunder_naming_scheme.input_str_follows_scheme("listing__creation_time__extract_month")
assert not dunder_naming_scheme.input_str_follows_scheme("123")
assert not dunder_naming_scheme.input_str_follows_scheme("TimeDimension('metric_time')")
def test_input_follows_scheme( # noqa: D103
dunder_naming_scheme: DunderNamingScheme,
simple_semantic_manifest_lookup: SemanticManifestLookup,
) -> None:
assert dunder_naming_scheme.input_str_follows_scheme(
"listing__country", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert dunder_naming_scheme.input_str_follows_scheme(
"listing__creation_time__month", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert dunder_naming_scheme.input_str_follows_scheme(
"booking__listing", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not dunder_naming_scheme.input_str_follows_scheme(
"listing__creation_time__extract_month", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not dunder_naming_scheme.input_str_follows_scheme(
"123", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not dunder_naming_scheme.input_str_follows_scheme(
"TimeDimension('metric_time')", semantic_manifest_lookup=simple_semantic_manifest_lookup
)


def test_spec_pattern( # noqa: D103
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ def test_input_str(metric_naming_scheme: MetricNamingScheme) -> None: # noqa: D
assert metric_naming_scheme.input_str(MetricSpec(element_name="bookings")) == "bookings"


def test_input_follows_scheme(metric_naming_scheme: MetricNamingScheme) -> None: # noqa: D103
assert metric_naming_scheme.input_str_follows_scheme("listings")
def test_input_follows_scheme( # noqa: D103
metric_naming_scheme: MetricNamingScheme, simple_semantic_manifest_lookup: SemanticManifestLookup
) -> None:
assert metric_naming_scheme.input_str_follows_scheme(
"listings", semantic_manifest_lookup=simple_semantic_manifest_lookup
)


def test_spec_pattern( # noqa: D103
Expand Down
Loading

0 comments on commit 5cd5d40

Please sign in to comment.