From df26a6b7891b33c885e70902e028e488fca0a907 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Wed, 2 Oct 2024 12:48:57 -0700 Subject: [PATCH] Reduce recursive-call overhead in `MetricTimeQueryValidationRule` (#1440) The check in `MetricTimeQueryValidationRule` is called for every metric in a derived metric's ancestors, so this moves the expensive parts of the check to only where they are needed and caches results when possible to reduce runtimes. The signatures of the validation rule classes were changed, so there are a number of diff lines related to that. --- .../query/query_resolver.py | 10 +++++-- .../validation_rules/base_validation_rule.py | 7 +++-- .../validation_rules/duplicate_metric.py | 7 ----- .../metric_time_requirements.py | 28 ++++++++++--------- .../query/validation_rules/query_validator.py | 19 ++++--------- 5 files changed, 31 insertions(+), 40 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/query/query_resolver.py b/metricflow-semantics/metricflow_semantics/query/query_resolver.py index 5181ebf642..c12b95f1c1 100644 --- a/metricflow-semantics/metricflow_semantics/query/query_resolver.py +++ b/metricflow-semantics/metricflow_semantics/query/query_resolver.py @@ -54,6 +54,8 @@ ResolverInputForWhereFilterIntersection, ) from metricflow_semantics.query.suggestion_generator import QueryItemSuggestionGenerator +from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule +from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule from metricflow_semantics.query.validation_rules.query_validator import PostResolutionQueryValidator from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec from metricflow_semantics.specs.metric_spec import MetricSpec @@ -123,9 +125,7 @@ def __init__( # noqa: D107 where_filter_pattern_factory: WhereFilterPatternFactory, ) -> None: self._manifest_lookup = manifest_lookup - 
self._post_resolution_query_validator = PostResolutionQueryValidator( - manifest_lookup=self._manifest_lookup, - ) + self._post_resolution_query_validator = PostResolutionQueryValidator() self._where_filter_pattern_factory = where_filter_pattern_factory @staticmethod @@ -491,6 +491,10 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met query_level_issue_set = self._post_resolution_query_validator.validate_query( resolution_dag=resolution_dag, resolver_input_for_query=resolver_input_for_query, + validation_rules=( + MetricTimeQueryValidationRule(self._manifest_lookup, resolver_input_for_query), + DuplicateMetricValidationRule(self._manifest_lookup, resolver_input_for_query), + ), ) if query_level_issue_set.has_issues: diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py index abfc9ab42f..7de6ebc161 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py @@ -15,8 +15,11 @@ class PostResolutionQueryValidationRule(ABC): """A validation rule that runs after all query inputs have been resolved to specs.""" - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 + def __init__( # noqa: D107 + self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery + ) -> None: self._manifest_lookup = manifest_lookup + self._resolver_input_for_query = resolver_input_for_query def _get_metric(self, metric_reference: MetricReference) -> Metric: return self._manifest_lookup.metric_lookup.get_metric(metric_reference) @@ -25,7 +28,6 @@ def _get_metric(self, metric_reference: MetricReference) -> Metric: def validate_metric_in_resolution_dag( self, metric_reference: MetricReference, - resolver_input_for_query: 
ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: """Given a metric that exists in a resolution DAG, check that the query is valid. @@ -39,7 +41,6 @@ def validate_query_in_resolution_dag( self, metrics_in_query: Sequence[MetricReference], where_filter_intersection: WhereFilterIntersection, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: """Validate the parameters to the query are valid. diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py index 7a595f4be6..249fe57943 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py @@ -7,11 +7,9 @@ from dbt_semantic_interfaces.references import MetricReference from typing_extensions import override -from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet from metricflow_semantics.query.issues.parsing.duplicate_metric import DuplicateMetricIssue -from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule logger = logging.getLogger(__name__) @@ -20,14 +18,10 @@ class DuplicateMetricValidationRule(PostResolutionQueryValidationRule): """Validates that a query does not include the same metric multiple times.""" - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 - super().__init__(manifest_lookup=manifest_lookup) - 
@override def validate_metric_in_resolution_dag( self, metric_reference: MetricReference, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: return MetricFlowQueryResolutionIssueSet.empty_instance() @@ -37,7 +31,6 @@ def validate_query_in_resolution_dag( self, metrics_in_query: Sequence[MetricReference], where_filter_intersection: WhereFilterIntersection, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: duplicate_metric_references = tuple( diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py index 572b6f0ee7..bbb0fd37c7 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import cached_property from typing import Sequence from dbt_semantic_interfaces.enum_extension import assert_values_exhausted @@ -32,8 +33,10 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule): * Derived metrics with an offset time.g """ - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 - super().__init__(manifest_lookup=manifest_lookup) + def __init__( # noqa: D107 + self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery + ) -> None: + super().__init__(manifest_lookup=manifest_lookup, resolver_input_for_query=resolver_input_for_query) self._metric_time_specs = tuple( TimeDimensionSpec.generate_possible_specs_for_time_dimension( @@ -43,13 +46,19 @@ def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D1 ) ) - def 
_group_by_items_include_metric_time(self, query_resolver_input: ResolverInputForQuery) -> bool: - for group_by_item_input in query_resolver_input.group_by_item_inputs: + @cached_property + def _group_by_items_include_metric_time(self) -> bool: + for group_by_item_input in self._resolver_input_for_query.group_by_item_inputs: if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs): return True return False + def _query_includes_metric_time_or_agg_time_dimension(self, metric_reference: MetricReference) -> bool: + return self._group_by_items_include_metric_time or self._group_by_items_include_agg_time_dimension( + query_resolver_input=self._resolver_input_for_query, metric_reference=metric_reference + ) + def _group_by_items_include_agg_time_dimension( self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference ) -> bool: @@ -66,15 +75,9 @@ def _group_by_items_include_agg_time_dimension( def validate_metric_in_resolution_dag( self, metric_reference: MetricReference, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: metric = self._get_metric(metric_reference) - query_includes_metric_time_or_agg_time_dimension = self._group_by_items_include_metric_time( - resolver_input_for_query - ) or self._group_by_items_include_agg_time_dimension( - query_resolver_input=resolver_input_for_query, metric_reference=metric_reference - ) if metric.type is MetricType.SIMPLE or metric.type is MetricType.CONVERSION: return MetricFlowQueryResolutionIssueSet.empty_instance() @@ -86,7 +89,7 @@ def validate_metric_in_resolution_dag( metric.type_params.cumulative_type_params.window is not None or metric.type_params.cumulative_type_params.grain_to_date is not None ) - and not query_includes_metric_time_or_agg_time_dimension + and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference) ): return MetricFlowQueryResolutionIssueSet.from_issue( 
CumulativeMetricRequiresMetricTimeIssue.from_parameters( @@ -102,7 +105,7 @@ def validate_metric_in_resolution_dag( for input_metric in metric.input_metrics ) - if has_time_offset and not query_includes_metric_time_or_agg_time_dimension: + if has_time_offset and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference): return MetricFlowQueryResolutionIssueSet.from_issue( OffsetMetricRequiresMetricTimeIssue.from_parameters( metric_reference=metric_reference, @@ -119,7 +122,6 @@ def validate_query_in_resolution_dag( self, metrics_in_query: Sequence[MetricReference], where_filter_intersection: WhereFilterIntersection, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: return MetricFlowQueryResolutionIssueSet.empty_instance() diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py index 2885ed0c74..2aff826e7e 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py @@ -4,7 +4,6 @@ from typing_extensions import override -from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow_semantics.query.group_by_item.candidate_push_down.push_down_visitor import DagTraversalPathTracker from metricflow_semantics.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import ( @@ -26,27 +25,21 @@ from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery from metricflow_semantics.query.validation_rules.base_validation_rule import 
PostResolutionQueryValidationRule -from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule -from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule class PostResolutionQueryValidator: """Runs query validation rules after query resolution is complete.""" - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 - self._manifest_lookup = manifest_lookup - self._validation_rules = ( - MetricTimeQueryValidationRule(self._manifest_lookup), - DuplicateMetricValidationRule(self._manifest_lookup), - ) - def validate_query( - self, resolution_dag: GroupByItemResolutionDag, resolver_input_for_query: ResolverInputForQuery + self, + resolution_dag: GroupByItemResolutionDag, + resolver_input_for_query: ResolverInputForQuery, + validation_rules: Sequence[PostResolutionQueryValidationRule], ) -> MetricFlowQueryResolutionIssueSet: """Validate according to the list of configured validation rules and return a set containing issues found.""" validation_visitor = _PostResolutionQueryValidationVisitor( resolver_input_for_query=resolver_input_for_query, - validation_rules=self._validation_rules, + validation_rules=validation_rules, ) return resolution_dag.sink_node.accept(validation_visitor) @@ -83,7 +76,6 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> MetricFlow issue_sets_to_merge.append( validation_rule.validate_metric_in_resolution_dag( metric_reference=node.metric_reference, - resolver_input_for_query=self._resolver_input_for_query, resolution_path=current_traversal_path, ) ) @@ -100,7 +92,6 @@ def visit_query_node(self, node: QueryGroupByItemResolutionNode) -> MetricFlowQu validation_rule.validate_query_in_resolution_dag( metrics_in_query=node.metrics_in_query, where_filter_intersection=node.where_filter_intersection, - resolver_input_for_query=self._resolver_input_for_query, resolution_path=current_traversal_path, ) )