From beb6f76c4b8cd95a48894f737b36933267e28425 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Mon, 25 Sep 2023 16:33:01 -0700 Subject: [PATCH] Checkpoint. --- metricflow/query/query_issues.py | 58 ++++ metricflow/query/query_resolver.py | 298 ++++++++++++++++++ metricflow/specs/patterns/__init__.py | 0 metricflow/specs/patterns/dunder_scheme.py | 159 ++++++++++ .../specs/patterns/entity_path_pattern.py | 226 +++++++++++++ metricflow/specs/patterns/metric_pattern.py | 65 ++++ metricflow/specs/patterns/python_scheme.py | 133 ++++++++ metricflow/specs/patterns/similarity.py | 48 +++ metricflow/specs/patterns/spec_pattern.py | 126 ++++++-- .../test/specs/test_entity_path_pattern.py | 223 +++++++++++++ 10 files changed, 1307 insertions(+), 29 deletions(-) create mode 100644 metricflow/query/query_resolver.py create mode 100644 metricflow/specs/patterns/__init__.py create mode 100644 metricflow/specs/patterns/entity_path_pattern.py create mode 100644 metricflow/specs/patterns/metric_pattern.py create mode 100644 metricflow/specs/patterns/python_scheme.py create mode 100644 metricflow/specs/patterns/similarity.py create mode 100644 metricflow/test/specs/test_entity_path_pattern.py diff --git a/metricflow/query/query_issues.py b/metricflow/query/query_issues.py index e69de29bb2..2eb0356a0e 100644 --- a/metricflow/query/query_issues.py +++ b/metricflow/query/query_issues.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Sequence + +from dbt_semantic_interfaces.references import MetricReference +from typing_extensions import override + + +class MetricFlowQueryIssueType(Enum): + """Errors prevent the query from running, where warnings do not.""" + + WARNING = "WARNING" + ERROR = "ERROR" + + +@dataclass(frozen=True) +class MetricFlowQueryResolutionIssue: + """An issue in the query that should be resolved.""" + + issue_type: MetricFlowQueryIssueType + message: str + metric_path: 
Sequence[MetricReference] = () + + @override + def __str__(self) -> str: + if not self.metric_path: + return f"{self.issue_type.value} - {self.message}" + readable_metric_path = str([metric_reference.element_name for metric_reference in self.metric_path]) + return f"{self.issue_type.value} - {readable_metric_path} - {self.message}" + + def with_additional_metric_path_prefix( # noqa: D + self, metric_reference: MetricReference + ) -> MetricFlowQueryResolutionIssue: + return MetricFlowQueryResolutionIssue( + issue_type=self.issue_type, message=self.message, metric_path=(metric_reference,) + tuple(self.metric_path) + ) + + +@dataclass(frozen=True) +class MetricFlowQueryIssueSet: + """The result of resolving query inputs to specs.""" + + issues: Sequence[MetricFlowQueryResolutionIssue] = () + + def with_additional_metric_path_prefix(self, metric_reference: MetricReference) -> MetricFlowQueryIssueSet: + """Return a new issue set where the existing metric paths are prefixed with the given metric.""" + return MetricFlowQueryIssueSet( + issues=tuple(issue.with_additional_metric_path_prefix(metric_reference) for issue in self.issues) + ) + + def merge(self, other: MetricFlowQueryIssueSet) -> MetricFlowQueryIssueSet: # noqa: D + return MetricFlowQueryIssueSet(issues=tuple(self.issues) + tuple(other.issues)) + + @property + def errors(self) -> Sequence[MetricFlowQueryResolutionIssue]: # noqa: D + return tuple(issue for issue in self.issues if issue.issue_type is MetricFlowQueryIssueType.ERROR) diff --git a/metricflow/query/query_resolver.py b/metricflow/query/query_resolver.py new file mode 100644 index 0000000000..beac7802ff --- /dev/null +++ b/metricflow/query/query_resolver.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import textwrap +from dataclasses import dataclass +from typing import Optional, Sequence, Dict + +import rapidfuzz +from dbt_semantic_interfaces.enum_extension import assert_values_exhausted +from dbt_semantic_interfaces.pretty_print import 
pformat_big_objects +from dbt_semantic_interfaces.protocols import WhereFilter +from dbt_semantic_interfaces.references import MetricReference +from dbt_semantic_interfaces.type_enums import MetricType +from typing_extensions import override + +from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup +from metricflow.query.query_issues import ( + MetricFlowQueryIssueSet, + MetricFlowQueryIssueType, + MetricFlowQueryResolutionIssue, +) +from metricflow.specs.column_assoc import ColumnAssociationResolver +from metricflow.specs.merge_builder import Mergeable, MergeBuilder +from metricflow.specs.patterns.metric_pattern import MetricNamePattern, MetricNamingScheme +from metricflow.specs.patterns.spec_pattern import ScoredSpec, ScoringResults, SpecPattern +from metricflow.specs.specs import MetricFlowQuerySpec, MetricSpec, OrderBySpec, LinkableInstanceSpec, WhereFilterSpec + + +@dataclass(frozen=True) +class MetricFlowQueryResolution(Mergeable): + """The result of resolving query inputs to specs.""" + + # Can be None if there were errors. 
+ query_spec: Optional[MetricFlowQuerySpec] = None + + metric_matches: Dict[SpecPattern[MetricSpec], Sequence[MetricSpec]] + group_by_item_matches: Dict[SpecPattern[LinkableInstanceSpec], Sequence[LinkableInstanceSpec]] + where_filter_spec: Optional[WhereFilterSpec] + + issue_set: MetricFlowQueryIssueSet = MetricFlowQueryIssueSet() + + def checked_query_spec(self) -> MetricFlowQuerySpec: + """Returns the query_spec, but if MetricFlowQueryResolution.has_errors was True, raise a RuntimeError.""" + if self.has_errors: + raise RuntimeError( + f"Can't get the query spec because errors were present in the resolution:" + f"\n{pformat_big_objects(self.issue_set.errors)}" + ) + if self.query_spec is None: + raise RuntimeError("If there were no errors, query_spec should have been populated.") + return self.query_spec + + def with_additional_issues( + self, + issue_set: MetricFlowQueryIssueSet, + ) -> MetricFlowQueryResolution: + """Return a new resolution with those issues added.""" + return MetricFlowQueryResolution(query_spec=self.query_spec, issue_set=self.issue_set.merge(issue_set)) + + @property + def has_errors(self) -> bool: # noqa: D + return len(self.issue_set.errors) > 0 + + @override + def merge(self, other: MetricFlowQueryResolution) -> MetricFlowQueryResolution: + merge_builder = MergeBuilder(MetricFlowQuerySpec()) + if self.query_spec is not None: + merge_builder.add(self.query_spec) + if other.query_spec is not None: + merge_builder.add(other.query_spec) + + return MetricFlowQueryResolution( + query_spec=merge_builder.build_result, issue_set=self.issue_set.merge(other.issue_set) + ) + + def with_additional_metric_path_prefix(self, metric_reference: MetricReference) -> MetricFlowQueryResolution: + """Return a resolution where the metric is added as the first element in the issues' metric path.""" + return MetricFlowQueryResolution( + query_spec=self.query_spec, issue_set=self.issue_set.with_additional_metric_path_prefix(metric_reference) + ) + + 
+@dataclass(frozen=True) +class MetricFlowQueryOrderByItem: + """Describes the order direction for one of the metrics or group by items.""" + + spec_pattern: SpecPattern + descending: bool + + +class MetricFlowQueryResolver: + """Given spec patterns that define the query, resolve them into specs. + + TODO: WhereSpecFactory should not need to depend on the column association resolver. + """ + + _INDENT = " " + + def __init__( # noqa: D + self, + manifest_lookup: SemanticManifestLookup, + column_association_resolver: ColumnAssociationResolver, + ) -> None: + self._manifest_lookup = manifest_lookup + self._column_association_resolver = column_association_resolver + self._known_metric_specs = [ + MetricSpec.from_reference(metric_reference) + for metric_reference in self._manifest_lookup.metric_lookup.metric_references + ] + + @staticmethod + def _create_error_resolution_for_unmatched_pattern( + plural_item_name: str, + spec_pattern: SpecPattern, + scoring_results: ScoringResults, + top_n_suggestions: int = 6, + ) -> MetricFlowQueryResolution: + ranked_specs: Sequence[ScoredSpec] = sorted( + scoring_results.scored_specs, key=lambda scored_spec: scored_spec.score + ) + + suggestions = pformat_big_objects( + [spec_pattern.naming_scheme.input_str(scored_spec.spec) for scored_spec in ranked_specs[:top_n_suggestions]] + ) + return MetricFlowQueryResolution( + issue_set=MetricFlowQueryIssueSet( + issues=( + MetricFlowQueryResolutionIssue( + issue_type=MetricFlowQueryIssueType.ERROR, + message=( + f"`{spec_pattern}` does not match exactly to one of the available " + f"{plural_item_name}. 
Suggestions:\n" + f"{textwrap.indent(suggestions, prefix=MetricFlowQueryResolver._INDENT)}" + ), + ), + ) + ) + ) + + def resolve_query( # noqa: D + self, + metric_patterns: Sequence[SpecPattern], + group_by_item_patterns: Sequence[SpecPattern], + order_by_items: Sequence[MetricFlowQueryOrderByItem], + limit: Optional[int], + where_filter: Optional[WhereFilter], + ) -> MetricFlowQueryResolution: + query_resolution_builder = MergeBuilder(MetricFlowQueryResolution()) + + for metric_pattern in metric_patterns: + query_resolution_builder.add( + self._resolve_query_for_one_metric( + metric_pattern=metric_pattern, + group_by_item_patterns=group_by_item_patterns, + where_spec=where_filter, + ) + ) + + # Check the order by if there is a query spec with metrics and group by items. + query_spec = query_resolution_builder.build_result.query_spec + if query_spec is not None: + query_item_specs = query_spec.metric_specs + query_spec.linkable_specs.as_tuple + + # Check that the patterns in the order by match with one of the specs specified in the query. + order_by_specs = [] + for order_by_item in order_by_items: + scoring_results = order_by_item.spec_pattern.score(query_item_specs) + if not scoring_results.has_exactly_one_match: + query_resolution_builder.add( + self._create_error_resolution_for_unmatched_pattern( + plural_item_name="query items", + spec_pattern=order_by_item.spec_pattern, + scoring_results=scoring_results, + ) + ) + else: + order_by_specs.append( + OrderBySpec( + instance_spec=scoring_results.matching_spec, + descending=order_by_item.descending, + ) + ) + + query_resolution_builder.add( + MetricFlowQueryResolution(query_spec=MetricFlowQuerySpec(order_by_specs=tuple(order_by_specs))) + ) + + # Add an issue if the limit is negative. 
+ if limit is not None and limit < 0: + query_resolution_builder.add( + MetricFlowQueryResolution( + issue_set=MetricFlowQueryIssueSet( + issues=( + MetricFlowQueryResolutionIssue( + issue_type=MetricFlowQueryIssueType.ERROR, + message=f"The limit was specified as {limit}, but it must be >= 0.", + ), + ) + ) + ) + ) + + query_resolution: MetricFlowQueryResolution = query_resolution_builder.build_result + + # Return the resolution without the query spec if there are errors. + if query_resolution.has_errors: + return MetricFlowQueryResolution(issue_set=query_resolution.issue_set) + + # If there are no errors, there should be a spec and the resolution should be ready to return. + if query_resolution.query_spec is None: + raise RuntimeError("The query spec is missing even though there are no errors in the query.") + + return query_resolution + + def _resolve_query_for_one_metric( # noqa: D + self, + metric_pattern: SpecPattern, + group_by_item_patterns: Sequence[SpecPattern], + where_spec: Optional[WhereFilter], + ) -> MetricFlowQueryResolution: + query_resolution_builder = MergeBuilder(MetricFlowQueryResolution()) + + # Check if the metric patterns match with known metrics. 
+ metric_scoring_results = metric_pattern.score(self._known_metric_specs) + + if not metric_scoring_results.has_exactly_one_match: + query_resolution_builder.add( + MetricFlowQueryResolver._create_error_resolution_for_unmatched_pattern( + plural_item_name="metrics", + spec_pattern=metric_pattern, + scoring_results=metric_scoring_results, + ) + ) + return query_resolution_builder.build_result + + metric_spec_set = metric_scoring_results.matching_spec.as_spec_set + assert ( + len(metric_spec_set.metric_specs) == 1 + ), f"Did not get exactly 1 metric spec: {metric_spec_set.metric_specs}" + + metric_spec = metric_spec_set.metric_specs[0] + + metric = self._manifest_lookup.metric_lookup.get_metric(metric_spec.reference) + + if metric.type is MetricType.SIMPLE or metric.type is MetricType.CUMULATIVE: + raise NotImplementedError + elif metric.type is MetricType.RATIO or metric.type is MetricType.DERIVED: + parent_metrics = [input_metric for input_metric in metric.input_metrics] + + for parent_metric in parent_metrics: + metric_reference = parent_metric.as_reference + parent_metric_query_resolution = self._resolve_query_for_one_metric( + metric_pattern=MetricNamePattern( + naming_scheme=MetricNamingScheme(), + target_spec=MetricSpec.from_reference(metric_reference), + input_str=metric_reference.element_name, + ), + group_by_item_patterns=group_by_item_patterns, + where_spec=where_spec, + ) + query_resolution_builder.add( + parent_metric_query_resolution.with_additional_metric_path_prefix(metric_reference), + ) + else: + assert_values_exhausted(metric.type) + + # Return early as it's difficult to resolve group by items without correct metrics. + if query_resolution_builder.build_result.has_errors > 0: + return query_resolution_builder.build_result + + # Check that the group by items match to one that's available. 
+ possible_group_by_specs = self._manifest_lookup.metric_lookup.element_specs_for_metrics( + metric_references=tuple(metric_spec.reference for metric_spec in metric_spec_set.metric_specs) + ) + + # Build a spec set by matching the patterns for the group by items to the available group by item specs + # for the queried metrics. + for group_by_item_pattern in group_by_item_patterns: + group_by_item_scoring_results = group_by_item_pattern.score(possible_group_by_specs) + if not group_by_item_scoring_results.has_exactly_one_match: + query_resolution_builder.add( + MetricFlowQueryResolver._create_error_resolution_for_unmatched_pattern( + plural_item_name="group by items", + spec_pattern=group_by_item_pattern, + scoring_results=group_by_item_scoring_results, + ) + ) + + else: + matching_spec = group_by_item_scoring_results.matching_spec + query_resolution_builder.add( + MetricFlowQueryResolution(query_spec=MetricFlowQuerySpec.from_spec_set(matching_spec.as_spec_set)) + ) + + # Check the filter for the metric + if metric.filter is not None: + raise NotImplementedError + + return query_resolution_builder.build_result diff --git a/metricflow/specs/patterns/__init__.py b/metricflow/specs/patterns/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/metricflow/specs/patterns/dunder_scheme.py b/metricflow/specs/patterns/dunder_scheme.py index e69de29bb2..8f25dabedc 100644 --- a/metricflow/specs/patterns/dunder_scheme.py +++ b/metricflow/specs/patterns/dunder_scheme.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from typing import Optional + +from dbt_semantic_interfaces.naming.keywords import DUNDER +from dbt_semantic_interfaces.references import EntityReference +from dbt_semantic_interfaces.type_enums import TimeGranularity +from typing_extensions import override + +from metricflow.specs.patterns.entity_path_pattern import EntityPathPattern, EntityPathPatternParameterSet +from metricflow.specs.patterns.spec_pattern import 
QueryItemNamingScheme, SpecPattern +from metricflow.specs.specs import InstanceSpecSet, InstanceSpecSetTransform, LinkableInstanceSpec +from metricflow.time.date_part import DatePart + + +class DunderNamingScheme(QueryItemNamingScheme[LinkableInstanceSpec]): + """A naming scheme that mirrors the behavior of StructuredLinkableSpecName. + + TODO: Replace StructuredLinkableSpecName with this. + + * See input_str_description(). + * The behavior in StructuredLinkableSpecName when a TimeDimensionSpec has a date part is nuanced. When a + TimeDimensionSpec has a date_part, a column name can be formed. However, an input string cannot contain a + date part. + """ + + @staticmethod + def _date_part_suffix(date_part: DatePart) -> str: + """Suffix used for names with a date_part.""" + return f"extract_{date_part.value}" + + @override + def input_str(self, instance_spec: LinkableInstanceSpec) -> Optional[str]: + spec_set = instance_spec.as_spec_set + + for time_dimension_spec in spec_set.time_dimension_specs: + # From existing comment in StructuredLinkableSpecName: + # + # Dunder syntax not supported for querying date_part + # + if time_dimension_spec.date_part is not None: + return None + return self.output_column_name(instance_spec) + + @override + def output_column_name(self, instance_spec: LinkableInstanceSpec) -> str: + return _DunderNameTransform().transform(instance_spec.as_spec_set) + + @override + def spec_pattern(self, input_str: str) -> SpecPattern[LinkableInstanceSpec]: + if not self.input_str_follows_scheme(input_str): + raise RuntimeError(f"`{input_str}` does not follow this scheme.") + input_str_parts = input_str.split(DUNDER) + + # No dunder, e.g. 
"ds" + if len(input_str_parts) == 1: + return EntityPathPattern( + parameter_set=EntityPathPatternParameterSet( + element_name=input_str_parts[0], + entity_links=(), + time_granularity=None, + date_part=None, + input_string=input_str, + naming_scheme=self, + ) + ) + + associated_granularity = None + for granularity in TimeGranularity: + if input_str_parts[-1] == granularity.value: + associated_granularity = granularity + + # Has a time granularity + if associated_granularity is not None: + # e.g. "ds__month" + if len(input_str_parts) == 2: + return EntityPathPattern( + parameter_set=EntityPathPatternParameterSet( + element_name=input_str_parts[0], + entity_links=(), + time_granularity=associated_granularity, + date_part=None, + input_string=input_str, + naming_scheme=self, + ) + ) + # e.g. "messages__ds__month" + return EntityPathPattern( + parameter_set=EntityPathPatternParameterSet( + element_name=input_str_parts[-2], + entity_links=tuple(EntityReference(entity_name) for entity_name in input_str_parts[:-2]), + time_granularity=associated_granularity, + date_part=None, + input_string=input_str, + naming_scheme=self, + ) + ) + + # e.g. "messages__ds" + else: + return EntityPathPattern( + parameter_set=EntityPathPatternParameterSet( + element_name=input_str_parts[-1], + entity_links=tuple(EntityReference(entity_name) for entity_name in input_str_parts[:-1]), + time_granularity=None, + date_part=None, + input_string=input_str, + naming_scheme=self, + ) + ) + + @override + def input_str_follows_scheme(self, input_str: str) -> bool: + input_str_parts = input_str.split(DUNDER) + + for date_part in DatePart: + if input_str_parts[-1] == DunderNamingScheme._date_part_suffix(date_part=date_part): + # From existing message in StructuredLinkableSpecName: "Dunder syntax not supported for querying + # date_part". 
+ return False + + return True + + @property + @override + def input_str_description(self) -> str: + return ( + "The input string should be a sequence of strings consisting of the entity links, the name of the " + "dimension or entity, and a time granularity (if applicable), joined by a double underscore. e.g. " + "listing__user__country or metric_time__day." + ) + + +class _DunderNameTransform(InstanceSpecSetTransform[str]): + """Transforms group by item specs into the appropriate string.""" + + @override + def transform(self, spec_set: InstanceSpecSet) -> str: + assert len(spec_set.measure_specs) == 0 + assert len(spec_set.metric_specs) == 0 + assert len(spec_set.metadata_specs) == 0 + + for time_dimension_spec in spec_set.time_dimension_specs: + items = list(entity_link.element_name for entity_link in time_dimension_spec.entity_links) + [ + time_dimension_spec.element_name + ] + if time_dimension_spec.date_part is not None: + items.append(DunderNamingScheme._date_part_suffix(date_part=time_dimension_spec.date_part)) + else: + items.append(time_dimension_spec.time_granularity.value) + return DUNDER.join(items) + + for other_group_by_item_specs in spec_set.entity_specs + spec_set.dimension_specs: + items = list(entity_link.element_name for entity_link in other_group_by_item_specs.entity_links) + [ + other_group_by_item_specs.element_name + ] + return DUNDER.join(items) + + raise RuntimeError(f"Did not find any appropriate specs in {spec_set}") diff --git a/metricflow/specs/patterns/entity_path_pattern.py b/metricflow/specs/patterns/entity_path_pattern.py new file mode 100644 index 0000000000..970c04a68f --- /dev/null +++ b/metricflow/specs/patterns/entity_path_pattern.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Sequence + +from dbt_semantic_interfaces.references import EntityReference +from dbt_semantic_interfaces.type_enums import TimeGranularity +from typing_extensions import 
override + +from metricflow.specs.merge_builder import MergeBuilder +from metricflow.specs.patterns.spec_pattern import QueryItemNamingScheme, ScoringResults, SpecPattern +from metricflow.specs.specs import InstanceSpecSet, InstanceSpecSetTransform, LinkableInstanceSpec +from metricflow.time.date_part import DatePart + + +@dataclass(frozen=True) +class EntityPathPatternParameterSet: + """See EntityPathPattern for more details.""" + + # The name of the element in the semantic model + element_name: str + # The entities used for joining semantic models. + entity_links: Sequence[EntityReference] + # If specified, match only time dimensions with the following properties. + time_granularity: Optional[TimeGranularity] + date_part: Optional[DatePart] + + # The string that was used to specify this parameter set, and the naming scheme that was used. This is needed for + # generating suggestions in case there are no matches. + input_string: str + naming_scheme: QueryItemNamingScheme[LinkableInstanceSpec] + + +@dataclass(frozen=True) +class EntityPathPattern(SpecPattern[LinkableInstanceSpec]): + """A pattern that matches group by items using the entity link path specification. + + The generic parameter LinkableInstanceSpecT determines the types of specs that this match / score. + + The entity link path specifies how a group by item for a metric query should be constructed. The group by item + is obtained by joining the semantic model containing the measure to a semantic model containing the group by + item using a specified entity. Additional semantic models can be joined using additional entities to obtain the + group by item. The series of entities that are used is the entity path. Since the entity path does not specify + which semantic models need to be used, additional resolution is done in later stages to generate the necessary SQL. + + The logic for matching / scoring a set of specs is: + + * Look for specs that match all entity path parameters. 
If there are any such matches, return those and score the + rest by edit distance of the name as defined by the naming scheme. + * If the entity path parameters does not specify the time granularity / date part, but there are time dimension + specs that match the entity path and the element name, consider the spec with the finest granularity as the only + match and score the rest as above. + + The logic above follows the MF query interface. + """ + + parameter_set: EntityPathPatternParameterSet + + @override + def score(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> ScoringResults: + spec_set = MergeBuilder.merge_iterable( + initial_item=InstanceSpecSet(), + other_iterable=tuple(candidate_spec.as_spec_set for candidate_spec in candidate_specs), + ) + + return _ScoreByEntityPathTransform(self.parameter_set).transform(spec_set) + + @property + @override + def naming_scheme(self) -> QueryItemNamingScheme[LinkableInstanceSpec]: + return self.parameter_set.naming_scheme + + +@dataclass(frozen=True) +class DimensionMatchingEntityPathPattern(SpecPattern[LinkableInstanceSpec]): + """Similar to EntityPathPattern but only matches dimensions.""" + + parameter_set: EntityPathPatternParameterSet + + @override + def score(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> ScoringResults: + spec_set = MergeBuilder.merge_iterable( + initial_item=InstanceSpecSet(), + other_iterable=tuple( + InstanceSpecSet(dimension_specs=candidate_spec.as_spec_set.dimension_specs) + for candidate_spec in candidate_specs + ), + ) + + return _ScoreByEntityPathTransform(self.parameter_set).transform(spec_set) + + @property + @override + def naming_scheme(self) -> QueryItemNamingScheme[LinkableInstanceSpec]: + return self.parameter_set.naming_scheme + + +@dataclass(frozen=True) +class TimeDimensionMatchingEntityPathPattern(SpecPattern[LinkableInstanceSpec]): + """Similar to EntityPathPattern but only matches time dimensions.""" + + parameter_set: EntityPathPatternParameterSet + + 
@override + def score(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> ScoringResults: + spec_set = MergeBuilder.merge_iterable( + initial_item=InstanceSpecSet(), + other_iterable=tuple( + InstanceSpecSet(time_dimension_specs=candidate_spec.as_spec_set.time_dimension_specs) + for candidate_spec in candidate_specs + ), + ) + + return _ScoreByEntityPathTransform(self.parameter_set).transform(spec_set) + + @property + @override + def naming_scheme(self) -> QueryItemNamingScheme[LinkableInstanceSpec]: + return self.parameter_set.naming_scheme + + +@dataclass(frozen=True) +class EntityMatchingEntityPathPattern(SpecPattern[LinkableInstanceSpec]): + """Similar to EntityPathPattern but only matches entities.""" + + parameter_set: EntityPathPatternParameterSet + + @override + def score(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> ScoringResults: + spec_set = MergeBuilder.merge_iterable( + initial_item=InstanceSpecSet(), + other_iterable=tuple( + InstanceSpecSet(entity_specs=candidate_spec.as_spec_set.entity_specs) + for candidate_spec in candidate_specs + ), + ) + + return _ScoreByEntityPathTransform(self.parameter_set).transform(spec_set) + + @property + @override + def naming_scheme(self) -> QueryItemNamingScheme[LinkableInstanceSpec]: + return self.parameter_set.naming_scheme + + +class _ScoreByEntityPathTransform(InstanceSpecSetTransform[ScoringResults]): + """Scores specs according to the description in EntityPathPattern.""" + + def __init__( + self, + parameter_set: EntityPathPatternParameterSet, + ) -> None: # noqa: D + self._parameter_set = parameter_set + + @override + def transform(self, spec_set: InstanceSpecSet) -> ScoringResults: + assert len(spec_set.metadata_specs) == 0 + assert len(spec_set.metric_specs) == 0 + + # Check for matches. + parameter_set = self._parameter_set + matching_specs: List[LinkableInstanceSpec] = [] + all_specs = spec_set.dimension_specs + spec_set.time_dimension_specs + spec_set.entity_specs + + # TODO: Remove me. 
+ # # If time_granularity or date_part, the only ones that can match are time dimensions. + # if parameter_set.time_granularity is not None or parameter_set.date_part is not None: + # for time_dimension_spec in spec_set.time_dimension_specs: + # if ( + # time_dimension_spec.element_name == parameter_set.element_name + # and time_dimension_spec.entity_links == parameter_set.entity_links + # and time_dimension_spec.time_granularity == parameter_set.time_granularity + # and time_dimension_spec.date_part == parameter_set.date_part + # ): + # matching_specs.append(time_dimension_spec) + # + # return SpecPattern.make_scoring_results( + # matching_specs=matching_specs, + # non_matching_specs=tuple(spec for spec in all_specs if spec not in matching_specs), + # input_str=parameter_set.input_string, + # naming_scheme=parameter_set.naming_scheme, + # ) + + # # At this point, we know the time granularity / date part was not specified. See if there's a time dimension + # # spec that could match. + # time_dimension_specs_that_could_match = [ + # time_dimension_spec + # for time_dimension_spec in spec_set.time_dimension_specs + # if time_dimension_spec.element_name == parameter_set.element_name + # and time_dimension_spec.entity_links == parameter_set.entity_links + # ] + # if len(time_dimension_specs_that_could_match) > 0: + # return SpecPattern.make_scoring_results( + # matching_specs=( + # min( + # time_dimension_specs_that_could_match, + # key=lambda candidate_spec: candidate_spec.time_granularity, + # ), + # ), + # non_matching_specs=tuple(spec for spec in all_specs if spec not in matching_specs), + # input_str=parameter_set.input_string, + # naming_scheme=parameter_set.naming_scheme, + # ) + + for spec in spec_set.time_dimension_specs: + # If the time granularity was specified but it doesn't match the granularity of the spec, then it can't + # be an exact match. 
+ if parameter_set.time_granularity is not None and spec.time_granularity != parameter_set.time_granularity: + continue + # Likewise for the date part. + if parameter_set.date_part is not None and spec.date_part != parameter_set.date_part: + continue + + if + + # At this point, we know that no time dimension spec matches, so check the other types. + for spec in spec_set.dimension_specs + spec_set.entity_specs: + if spec.element_name == parameter_set.element_name and spec.entity_links == parameter_set.entity_links: + matching_specs.append(spec) + + return SpecPattern.make_scoring_results( + matching_specs=matching_specs, + non_matching_specs=tuple(spec for spec in all_specs if spec not in matching_specs), + input_str=parameter_set.input_string, + naming_scheme=parameter_set.naming_scheme, + ) diff --git a/metricflow/specs/patterns/metric_pattern.py b/metricflow/specs/patterns/metric_pattern.py new file mode 100644 index 0000000000..58062e3e5a --- /dev/null +++ b/metricflow/specs/patterns/metric_pattern.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence + +from typing_extensions import override + +from metricflow.specs.patterns.spec_pattern import QueryItemNamingScheme, ScoringResults, SpecPattern +from metricflow.specs.specs import MetricSpec + + +@dataclass(frozen=True) +class MetricNamePattern(SpecPattern[MetricSpec]): + """A pattern that matches specs on the element name.""" + + naming_scheme: QueryItemNamingScheme[MetricSpec] + target_spec: MetricSpec + input_str: str + + @override + def score(self, candidate_specs: Sequence[MetricSpec]) -> ScoringResults: + matching_specs = [] + + for candidate_spec in candidate_specs: + if candidate_spec.element_name == self.target_spec.element_name: + matching_specs.append(candidate_spec) + + return self.make_scoring_results( + matching_specs=matching_specs, + non_matching_specs=tuple(spec for spec in candidate_specs if spec not in matching_specs), + 
input_str=self.input_str, + naming_scheme=MetricNamingScheme(), + ) + + +class MetricNamingScheme(QueryItemNamingScheme[MetricSpec]): + """A naming scheme for metric specs using the element name.""" + + @override + def input_str(self, instance_spec: MetricSpec) -> str: + return instance_spec.element_name + + @override + def output_column_name(self, instance_spec: MetricSpec) -> str: + return instance_spec.element_name + + @override + def spec_pattern(self, input_str: str) -> SpecPattern[MetricSpec]: + if not self.input_str_follows_scheme(input_str): + raise ValueError("Can't create a pattern as the input string does not follow this scheme.") + return MetricNamePattern( + naming_scheme=self, + target_spec=MetricSpec(element_name=input_str), + input_str=input_str, + ) + + @override + def input_str_follows_scheme(self, input_str: str) -> bool: + # Could use UniqueAndValidNameRule, but needs some modifications to that class. + return True + + @property + @override + def input_str_description(self) -> str: + return "The metric input string should follow the convention for defining metric names in the configuration." 
diff --git a/metricflow/specs/patterns/python_scheme.py b/metricflow/specs/patterns/python_scheme.py new file mode 100644 index 0000000000..3075236558 --- /dev/null +++ b/metricflow/specs/patterns/python_scheme.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from typing import Optional + +from dbt_semantic_interfaces.call_parameter_sets import ParseWhereFilterException +from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter +from dbt_semantic_interfaces.naming.keywords import DUNDER +from typing_extensions import override + +from metricflow.specs.patterns.entity_path_pattern import EntityPathPattern, EntityPathPatternParameterSet +from metricflow.specs.patterns.spec_pattern import QueryItemNamingScheme, SpecPattern +from metricflow.specs.specs import InstanceSpecSet, InstanceSpecSetTransform, LinkableInstanceSpec + + +class PythonObjectNamingScheme(QueryItemNamingScheme[LinkableInstanceSpec]): + """A naming scheme using Python object syntax like TimeDimension('metric_time', time_granularity_name='day').""" + + @override + def input_str(self, instance_spec: LinkableInstanceSpec) -> Optional[str]: + return _PythonObjectNameTransform().transform(instance_spec.as_spec_set) + + @override + def output_column_name(self, instance_spec: LinkableInstanceSpec) -> Optional[str]: + raise NotImplementedError("Using this naming scheme for naming output columns is not yet supported.") + + @override + def spec_pattern(self, input_str: str) -> SpecPattern[LinkableInstanceSpec]: + try: + # TODO: Update when more appropriate parsing libraries are available. 
+ call_parameter_sets = PydanticWhereFilter(where_sql_template="{{ " + input_str + " }}").call_parameter_sets + except ParseWhereFilterException as e: + raise RuntimeError(f"A spec pattern can't be generated from the input string `{input_str}`") from e + + num_parameter_sets = ( + len(call_parameter_sets.dimension_call_parameter_sets) + + len(call_parameter_sets.time_dimension_call_parameter_sets) + + len(call_parameter_sets.entity_call_parameter_sets) + ) + if num_parameter_sets != 1: + raise RuntimeError(f"Did not find exactly 1 call parameter set. Got: {num_parameter_sets}") + + for dimension_call_parameter_set in call_parameter_sets.dimension_call_parameter_sets: + return EntityPathPattern( + EntityPathPatternParameterSet( + element_name=dimension_call_parameter_set.dimension_reference.element_name, + entity_links=dimension_call_parameter_set.entity_path, + time_granularity=None, + date_part=None, + input_string=input_str, + naming_scheme=self, + ) + ) + + for time_dimension_call_parameter_set in call_parameter_sets.time_dimension_call_parameter_sets: + # TODO: Temporary heuristic check until date_part support is fully in. 
+ if "date_part" in input_str: + raise NotImplementedError("date_part support is blocked on appropriate parsing in DSI") + return EntityPathPattern( + EntityPathPatternParameterSet( + element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name, + entity_links=time_dimension_call_parameter_set.entity_path, + time_granularity=time_dimension_call_parameter_set.time_granularity, + date_part=None, + input_string=input_str, + naming_scheme=self, + ) + ) + + for entity_call_parameter_set in call_parameter_sets.entity_call_parameter_sets: + return EntityPathPattern( + EntityPathPatternParameterSet( + element_name=entity_call_parameter_set.entity_reference.element_name, + entity_links=entity_call_parameter_set.entity_path, + time_granularity=None, + date_part=None, + input_string=input_str, + naming_scheme=self, + ) + ) + + raise RuntimeError("There should have been a return associated with one of the CallParameterSets.") + + @override + def input_str_follows_scheme(self, input_str: str) -> bool: + try: + PydanticWhereFilter(where_sql_template="{{ " + input_str + " }}").call_parameter_sets, + return True + except ParseWhereFilterException: + return False + + @property + @override + def input_str_description(self) -> str: + return ( + "The input string should follow the conventions for specifying group by items in the Python object format." + ) + + +class _PythonObjectNameTransform(InstanceSpecSetTransform[str]): + """Transforms specs into inputs following the Python object scheme. + + The input set should have exactly one group by item spec. 
+ """ + + @override + def transform(self, spec_set: InstanceSpecSet) -> str: + assert len(spec_set.metric_specs) == 0 + assert len(spec_set.metadata_specs) == 0 + assert len(spec_set.measure_specs) == 0 + assert len(spec_set.entity_specs) + len(spec_set.dimension_specs) + len(spec_set.time_dimension_specs) == 1 + + for instance_spec in spec_set.time_dimension_specs + spec_set.entity_specs + spec_set.dimension_specs: + if len(instance_spec.entity_links) == 0: + raise RuntimeError( + "The Python object naming scheme should have only been applied to specs with entity links." + ) + + for time_dimension_spec in spec_set.time_dimension_specs: + primary_entity_name = time_dimension_spec.entity_links[-1].element_name + other_entity_names = tuple( + entity_link.element_name for entity_link in time_dimension_spec.entity_links[:-1] + ) + initializer_parameters = [ + f"'{primary_entity_name}{DUNDER}{time_dimension_spec.element_name}'", + f"time_granularity_name='{time_dimension_spec.time_granularity.value}'", + ] + if time_dimension_spec.date_part is not None: + initializer_parameters.append(f"date_part_name='{time_dimension_spec.date_part.value}'") + initializer_parameters.append(f"entity_path={str(other_entity_names)}") + initializer_parameter_str = ", ".join(initializer_parameters) + return f"TimeDimension({initializer_parameter_str})" + + raise RuntimeError(f"There should have been a return associated with one of the specs in {spec_set}") diff --git a/metricflow/specs/patterns/similarity.py b/metricflow/specs/patterns/similarity.py new file mode 100644 index 0000000000..9228d5febe --- /dev/null +++ b/metricflow/specs/patterns/similarity.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence + +import rapidfuzz + + +@dataclass(frozen=True) +class ScoredItem: # noqa: D + item_str: str + score: float + + +def top_fuzzy_matches( + item: str, + candidate_items: Sequence[str], + max_suggestions: int = 6, +) -> 
Sequence[ScoredItem]: + """Return the top items (by edit distance) in candidate_items that fuzzy match the given item. + + Return scores from -1 -> 0 inclusive. + """ + normalized_scored_items = [] + + # Rank choices by edit distance score. + # extract() returns a tuple like (name, score) + top_ranked_suggestions = sorted( + rapidfuzz.process.extract( + # This scorer seems to return the best results. + item, + list(candidate_items), + limit=max_suggestions, + scorer=rapidfuzz.fuzz.token_set_ratio, + ), + # Put the highest scoring item at the top of the list. + key=lambda x: x[1], + reverse=True, + ) + + for fuzz_tuple in top_ranked_suggestions: + value = fuzz_tuple[0] + score = fuzz_tuple[1] + + # fuzz scores from 0..100 so normalize non-exact matches to -1..0 + normalized_scored_items.append(ScoredItem(item_str=value, score=-(100.0 - score) / 100.0)) + + return normalized_scored_items diff --git a/metricflow/specs/patterns/spec_pattern.py b/metricflow/specs/patterns/spec_pattern.py index eea43d7af8..0883f321c7 100644 --- a/metricflow/specs/patterns/spec_pattern.py +++ b/metricflow/specs/patterns/spec_pattern.py @@ -2,39 +2,73 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Sequence, TypeVar +from typing import Generic, Optional, Sequence, TypeVar from dbt_semantic_interfaces.pretty_print import pformat_big_objects +from metricflow.instances import SpecT +from metricflow.specs.patterns.similarity import top_fuzzy_matches from metricflow.specs.specs import InstanceSpec -class QueryInterfaceItemNamingScheme(ABC): - """Describes how to name items in the inputs and outputs of a query. - - For example, a user needs to input strings that specify the metrics and group by items. 
These can be in different - formats like 'user__country' or "TimeDimension('metric_time', 'DAY')" - """ +class SpecPattern(ABC, Generic[SpecT]): + """A pattern is used to select a spec from a group of specs based on class-defined criteria.""" @abstractmethod - def input_str(self, instance_spec: InstanceSpec) -> Optional[str]: - """Following this scheme, return the string that can be used as an input that would specify the given spec.""" + def score(self, candidate_specs: Sequence[SpecT]) -> ScoringResults: + """Given a sequence of instance specs, try to match them to this pattern and return the associated scores.""" pass + @property @abstractmethod - def output_column_str(self, instance_spec: InstanceSpec) -> str: - """Following this scheme, return the name of the column containing the item with the given spec.""" + def naming_scheme(self) -> QueryItemNamingScheme[SpecT]: + """The naming scheme used for this pattern. Used to generate suggestions in error messages.""" pass - @abstractmethod - def spec_pattern(self, input_str: str) -> SpecPattern: - """Given that the input follows this scheme, return a spec pattern that can be used to resolve a query.""" - pass + @staticmethod + def make_scoring_results( + matching_specs: Sequence[InstanceSpec], + non_matching_specs: Sequence[InstanceSpec], + input_str: str, + naming_scheme: QueryItemNamingScheme, + ) -> ScoringResults: + """Creates a result where matching specs are given a score of 1.0 and the rest are scored by edit distance. + + The edit distance is between input_str and the input strings generated by the naming scheme for the non-matching + specs. This is useful for generating suggestions in error messages. + + The result is returned in order with the highest scores first. 
+ """ + scored_specs = [] + for matching_spec in matching_specs: + scored_specs.append( + ScoredSpec( + spec=matching_spec, + score=1.0, + ) + ) - @abstractmethod - def is_valid_input_str(self, input_str: str) -> bool: - """Returns true if the given input string follows this naming scheme.""" - pass + # For all non-matching specs, score them based on the edit distance. + non_matching_spec_name_to_spec = {} + for spec in non_matching_specs: + input_str_for_spec = naming_scheme.input_str(spec) + if input_str_for_spec is not None: + non_matching_spec_name_to_spec[input_str_for_spec] = spec + + top_scored_items = sorted( + top_fuzzy_matches( + item=input_str, + candidate_items=tuple(non_matching_spec_name_to_spec.keys()), + ), + key=lambda item: -item.score, + ) + + for scored_item in top_scored_items: + scored_specs.append( + ScoredSpec(spec=non_matching_spec_name_to_spec[scored_item.item_str], score=scored_item.score) + ) + + return ScoringResults(scored_specs=tuple(scored_specs)) @dataclass @@ -53,7 +87,7 @@ class ScoredSpec: @property def matches(self) -> bool: # noqa: D - return self.score > 0 + return self.score > 0.0 SelfTypeT = TypeVar("SelfTypeT", bound="SpecPattern") @@ -70,24 +104,58 @@ def matched_specs(self) -> Sequence[InstanceSpec]: # noqa: D return tuple(scored_spec.spec for scored_spec in self.scored_specs if scored_spec.matches) @property - def has_one_match(self) -> bool: # noqa: D + def has_exactly_one_match(self) -> bool: # noqa: D return len(self.matched_specs) == 1 @property def matching_spec(self) -> InstanceSpec: """If there is exactly one spec that matched, return it. Otherwise, raise a RuntimeError.""" matched_specs = self.matched_specs - if len(matched_specs) == 1: - raise RuntimeError( - f"This result not contain a spec that matches. Got:\n{pformat_big_objects(self.scored_specs)}" - ) + if len(matched_specs) != 1: + raise RuntimeError(f"This result not contain exactly 1 match. 
Got:\n{pformat_big_objects(matched_specs)}") return matched_specs[0] -class SpecPattern(ABC): - """A pattern is used to select a spec from a group of specs based on class-defined criteria.""" +class QueryItemNamingScheme(ABC, Generic[SpecT]): + """Describes how to name items that are involved in a MetricFlow query. + + Items in a query can be anything that is associated with a spec, for example metrics and group by items. + These items can be described in different string representations like "user__country" or + "TimeDimension('metric_time', 'DAY')". + + The generic parameter SpecT determines the types of instance specs that this naming scheme applies to. + """ @abstractmethod - def score(self, candidate_specs: Sequence[InstanceSpec]) -> ScoringResults: - """Given a group of instance specs, try to match them to this pattern and return the associated scores.""" + def input_str(self, instance_spec: SpecT) -> Optional[str]: + """Following this scheme, return the string that can be used as an input that would specify the given spec. + + If this scheme cannot accommodate the spec, return None. This is needed to handle a case with DatePart in + DunderNamingScheme, but naming schemes should otherwise be complete. + """ + pass + + @abstractmethod + def output_column_name(self, instance_spec: SpecT) -> Optional[str]: + """Following this scheme, return the name of the column containing the item with the given spec. + + Returns None if this naming scheme can't generate a column name for the given spec. 
(PythonObjectNamingScheme + does not yet support use in generating column names) + """ + pass + + @abstractmethod + def spec_pattern(self, input_str: str) -> SpecPattern[SpecT]: + """Given that the input follows this scheme, return a spec pattern that matches the defined behavior.""" + pass + + @abstractmethod + def input_str_follows_scheme(self, input_str: str) -> bool: + """Returns true if the given input string follows this naming scheme.""" + pass + + @property + @abstractmethod + def input_str_description(self) -> str: + """A description of this naming scheme used in error messages.""" pass diff --git a/metricflow/test/specs/test_entity_path_pattern.py b/metricflow/test/specs/test_entity_path_pattern.py new file mode 100644 index 0000000000..26cb314406 --- /dev/null +++ b/metricflow/test/specs/test_entity_path_pattern.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import logging +from typing import Sequence, Optional + +import pytest +from dbt_semantic_interfaces.pretty_print import pformat_big_objects +from dbt_semantic_interfaces.references import EntityReference +from dbt_semantic_interfaces.type_enums import TimeGranularity + +from metricflow.specs.patterns.dunder_scheme import DunderNamingScheme +from metricflow.specs.patterns.entity_path_pattern import EntityPathPattern, EntityPathPatternParameterSet +from metricflow.specs.patterns.spec_pattern import QueryItemNamingScheme, SpecPattern +from metricflow.specs.specs import DimensionSpec, EntitySpec, LinkableInstanceSpec +from metricflow.test.time.metric_time_dimension import MTD_SPEC_MONTH, MTD_SPEC_WEEK, MTD_SPEC_YEAR + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="session") +def specs() -> Sequence[LinkableInstanceSpec]: # noqa: D + return ( + # Time dimensions + MTD_SPEC_WEEK, + MTD_SPEC_MONTH, + MTD_SPEC_YEAR, + # Dimensions + DimensionSpec( + element_name="country", + entity_links=( + EntityReference(element_name="listing"), + EntityReference(element_name="user"), + ), + 
), + DimensionSpec(element_name="is_instant", entity_links=(EntityReference(element_name="booking"),)), + # Entities + EntitySpec( + element_name="listing", + entity_links=(EntityReference(element_name="booking"),), + ), + EntitySpec( + element_name="host", + entity_links=(EntityReference(element_name="booking"),), + ), + ) + + +def compare_scoring_results( # noqa: D + specs: Sequence[LinkableInstanceSpec], + naming_scheme: QueryItemNamingScheme[LinkableInstanceSpec], + pattern: SpecPattern[LinkableInstanceSpec], + expected_spec_strs: Sequence[str], + expected_matching_spec: Optional[LinkableInstanceSpec], +) -> None: + result = pattern.score(specs) + + if expected_matching_spec is not None: + assert result.has_exactly_one_match + assert result.matching_spec == expected_matching_spec + else: + assert not result.has_exactly_one_match + + # TODO: Remove + logger.error( + f"Result is:\n" + pformat_big_objects( + tuple(naming_scheme.input_str(scored_spec.spec) for scored_spec in result.scored_specs) + ) + ) + + actual_spec_strs = tuple(naming_scheme.input_str(scored_spec.spec) for scored_spec in result.scored_specs) + assert actual_spec_strs == expected_spec_strs + + +def test_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None: # noqa: D + naming_scheme = DunderNamingScheme() + compare_scoring_results( + specs=specs, + naming_scheme=DunderNamingScheme(), + pattern=EntityPathPattern( + EntityPathPatternParameterSet( + element_name="is_instant", + entity_links=(EntityReference(element_name="booking"),), + time_granularity=None, + date_part=None, + input_string="booking__is_instant", + naming_scheme=naming_scheme, + ) + ), + expected_spec_strs=( + 'booking__is_instant', + 'booking__listing', + 'booking__host', + 'listing__user__country', + 'metric_time__month', + 'metric_time__year', + 'metric_time__week' + ), + expected_matching_spec=DimensionSpec( + element_name="is_instant", entity_links=(EntityReference(element_name="booking"),) + ) + ) + + +def 
test_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None: # noqa: D + naming_scheme = DunderNamingScheme() + pattern = EntityPathPattern( + EntityPathPatternParameterSet( + element_name="is_instant", + entity_links=(EntityReference(element_name="booking"),), + time_granularity=None, + date_part=None, + input_string="booking__is_instant", + naming_scheme=naming_scheme, + ) + ) + + result = pattern.score(specs) + assert result.has_exactly_one_match + assert result.matching_spec == DimensionSpec( + element_name="is_instant", entity_links=(EntityReference(element_name="booking"),) + ) + + # TODO: Remove + logger.error( + f"Result is:\n" + pformat_big_objects( + tuple(naming_scheme.input_str(scored_spec.spec) for scored_spec in result.scored_specs) + ) + ) + + assert tuple(naming_scheme.input_str(scored_spec.spec) for scored_spec in result.scored_specs) == ( + 'booking__is_instant', + 'booking__listing', + 'booking__host', + 'listing__user__country', + 'metric_time__month', + 'metric_time__year', + 'metric_time__week' + ) + + +def test_entity_match(specs: Sequence[LinkableInstanceSpec]) -> None: # noqa: D + naming_scheme = DunderNamingScheme() + compare_scoring_results( + specs=specs, + naming_scheme=DunderNamingScheme(), + pattern=EntityPathPattern( + EntityPathPatternParameterSet( + element_name="listing", + entity_links=(EntityReference(element_name="booking"),), + time_granularity=None, + date_part=None, + input_string="booking__listing", + naming_scheme=naming_scheme, + ) + ), + expected_spec_strs=( + 'booking__listing', + 'booking__host', + 'booking__is_instant', + 'listing__user__country', + 'metric_time__month', + 'metric_time__week', + 'metric_time__year'), + expected_matching_spec=EntitySpec( + element_name="listing", entity_links=(EntityReference(element_name="booking"),) + ) + ) + + +def test_time_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None: # noqa: D + naming_scheme = DunderNamingScheme() + compare_scoring_results( + 
specs=specs, + naming_scheme=DunderNamingScheme(), + pattern=EntityPathPattern( + EntityPathPatternParameterSet( + element_name="metric_time", + entity_links=(), + time_granularity=TimeGranularity.WEEK, + date_part=None, + input_string="metric_time__week", + naming_scheme=naming_scheme, + ) + ), + expected_spec_strs=( + 'metric_time__week', + 'metric_time__year', + 'metric_time__month', + 'listing__user__country', + 'booking__listing', + 'booking__is_instant', + 'booking__host' + ), + expected_matching_spec=MTD_SPEC_WEEK, + ) + + +def test_time_dimension_match_without_specified(specs: Sequence[LinkableInstanceSpec]) -> None: # noqa: D + naming_scheme = DunderNamingScheme() + compare_scoring_results( + specs=specs, + naming_scheme=DunderNamingScheme(), + pattern=EntityPathPattern( + EntityPathPatternParameterSet( + element_name="metric_time", + entity_links=(), + time_granularity=TimeGranularity.WEEK, + date_part=None, + input_string="metric_time__week", + naming_scheme=naming_scheme, + ) + ), + expected_spec_strs=( + 'metric_time__week', + 'metric_time__year', + 'metric_time__month', + 'listing__user__country', + 'booking__listing', + 'booking__is_instant', + 'booking__host' + ), + expected_matching_spec=MTD_SPEC_WEEK, + ) \ No newline at end of file