Skip to content

Commit

Permalink
Updates to reflect a single input measure per metric (#843)
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb authored Nov 6, 2023
1 parent e452ad5 commit 043bbbb
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 261 deletions.
137 changes: 40 additions & 97 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

import collections
import logging
import time
from dataclasses import dataclass
from typing import DefaultDict, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.pretty_print import pformat_big_objects
Expand All @@ -15,7 +14,6 @@
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.dag.id_generation import DATAFLOW_PLAN_PREFIX, IdGeneratorRegistry
from metricflow.dataflow.builder.measure_additiveness import group_measure_specs_by_additiveness
from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.builder.node_evaluator import (
JoinLinkableInstancesRecipe,
Expand All @@ -30,7 +28,6 @@
ConstrainTimeRangeNode,
DataflowPlan,
FilterElementsNode,
JoinAggregatedMeasuresByGroupByColumnsNode,
JoinDescription,
JoinOverTimeRangeNode,
JoinToBaseOutputNode,
Expand Down Expand Up @@ -253,18 +250,22 @@ def _build_metrics_output_node(
metric_reference=metric_reference,
column_association_resolver=self._column_association_resolver,
)
assert (
len(metric_input_measure_specs) == 1
), "Simple and cumulative metrics must have one input measure."
metric_input_measure_spec = metric_input_measure_specs[0]

logger.info(
f"For {metric_spec}, needed measures are:\n"
f"{pformat_big_objects(metric_input_measure_specs=metric_input_measure_specs)}"
f"For {metric_spec}, needed measure is:\n"
f"{pformat_big_objects(metric_input_measure_spec=metric_input_measure_spec)}"
)
combined_where = where_constraint
if metric_spec.constraint:
combined_where = (
combined_where.combine(metric_spec.constraint) if combined_where else metric_spec.constraint
)
aggregated_measures_node = self.build_aggregated_measures(
metric_input_measure_specs=metric_input_measure_specs,
aggregated_measures_node = self.build_aggregated_measure(
metric_input_measure_spec=metric_input_measure_spec,
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=combined_where,
Expand Down Expand Up @@ -632,9 +633,9 @@ def build_computed_metrics_node(
metric_specs=[metric_spec],
)

def build_aggregated_measures(
def build_aggregated_measure(
self,
metric_input_measure_specs: Sequence[MetricInputMeasureSpec],
metric_input_measure_spec: MetricInputMeasureSpec,
metric_spec: MetricSpec,
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[WhereFilterSpec] = None,
Expand All @@ -649,81 +650,29 @@ def build_aggregated_measures(
a composite set of aggregations originating from multiple semantic models, and joined into a single
aggregated set of measures.
"""
output_nodes: List[BaseOutput] = []
semantic_models_and_constraints_to_measures: DefaultDict[
tuple[str, Optional[WhereFilterSpec]], List[MetricInputMeasureSpec]
] = collections.defaultdict(list)
for input_spec in metric_input_measure_specs:
semantic_model_names = [
dsource.name
for dsource in self._semantic_model_lookup.get_semantic_models_for_measure(
measure_reference=input_spec.measure_spec.as_reference
)
]
assert (
len(semantic_model_names) == 1
), f"Validation should enforce one semantic model per measure, but found {semantic_model_names} for {input_spec}!"
semantic_models_and_constraints_to_measures[(semantic_model_names[0], input_spec.constraint)].append(
input_spec
)

for (semantic_model, measure_constraint), measures in semantic_models_and_constraints_to_measures.items():
logger.info(
f"Building aggregated measures for {semantic_model}. "
f" Input measures: {measures} with constraints: {measure_constraint}"
)
if measure_constraint is None:
node_where_constraint = where_constraint
elif where_constraint is None:
node_where_constraint = measure_constraint
else:
node_where_constraint = where_constraint.combine(measure_constraint)

input_specs_by_measure_spec = {spec.measure_spec: spec for spec in measures}
grouped_measures_by_additiveness = group_measure_specs_by_additiveness(
tuple(input_specs_by_measure_spec.keys())
)
measures_by_additiveness = grouped_measures_by_additiveness.measures_by_additiveness

# Build output nodes for each distinct non-additive dimension spec, including the None case
for non_additive_spec, measure_specs in measures_by_additiveness.items():
non_additive_message = ""
if non_additive_spec is not None:
non_additive_message = f" with non-additive dimension spec: {non_additive_spec}"

logger.info(f"Building aggregated measures for {semantic_model}{non_additive_message}")
input_specs = tuple(input_specs_by_measure_spec[measure_spec] for measure_spec in measure_specs)
output_nodes.append(
self._build_aggregated_measures_from_measure_source_node(
metric_input_measure_specs=input_specs,
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=node_where_constraint,
time_range_constraint=time_range_constraint,
cumulative=cumulative,
cumulative_window=cumulative_window,
cumulative_grain_to_date=cumulative_grain_to_date,
)
)

if len(output_nodes) == 1:
return output_nodes[0]
measure_constraint = metric_input_measure_spec.constraint
logger.info(f"Building aggregated measure: {metric_input_measure_spec} with constraint: {measure_constraint}")
if measure_constraint is None:
node_where_constraint = where_constraint
elif where_constraint is None:
node_where_constraint = measure_constraint
else:
return FilterElementsNode(
parent_node=JoinAggregatedMeasuresByGroupByColumnsNode(parent_nodes=output_nodes),
include_specs=InstanceSpecSet.merge(
(
queried_linkable_specs.as_spec_set,
InstanceSpecSet(
measure_specs=tuple(x.post_aggregation_spec for x in metric_input_measure_specs)
),
)
),
)
node_where_constraint = where_constraint.combine(measure_constraint)

return self._build_aggregated_measure_from_measure_source_node(
metric_input_measure_spec=metric_input_measure_spec,
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=node_where_constraint,
time_range_constraint=time_range_constraint,
cumulative=cumulative,
cumulative_window=cumulative_window,
cumulative_grain_to_date=cumulative_grain_to_date,
)

def _build_aggregated_measures_from_measure_source_node(
def _build_aggregated_measure_from_measure_source_node(
self,
metric_input_measure_specs: Sequence[MetricInputMeasureSpec],
metric_input_measure_spec: MetricInputMeasureSpec,
metric_spec: MetricSpec,
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[WhereFilterSpec] = None,
Expand All @@ -738,8 +687,8 @@ def _build_aggregated_measures_from_measure_source_node(
if time_dimension_spec.element_name == self._metric_time_dimension_reference.element_name
]
metric_time_dimension_requested = len(metric_time_dimension_specs) > 0
measure_specs = tuple(x.measure_spec for x in metric_input_measure_specs)
measure_properties = self._build_measure_spec_properties(measure_specs)
measure_spec = metric_input_measure_spec.measure_spec
measure_properties = self._build_measure_spec_properties([measure_spec])
non_additive_dimension_spec = measure_properties.non_additive_dimension_spec

cumulative_metric_adjusted_time_constraint: Optional[TimeRangeConstraint] = None
Expand Down Expand Up @@ -774,7 +723,7 @@ def _build_aggregated_measures_from_measure_source_node(
required_linkable_specs = LinkableSpecSet.merge((queried_linkable_specs, extraneous_linkable_specs))
logger.info(
f"Looking for a recipe to get:\n"
f"{pformat_big_objects(measure_specs=measure_specs, required_linkable_set=required_linkable_specs)}"
f"{pformat_big_objects(measure_specs=[measure_spec], required_linkable_set=required_linkable_specs)}"
)

find_recipe_start_time = time.time()
Expand All @@ -793,7 +742,7 @@ def _build_aggregated_measures_from_measure_source_node(
if not measure_recipe:
# TODO: Improve for better user understandability.
raise UnableToSatisfyQueryError(
f"Recipe not found for measure specs: {measure_specs} and linkable specs: {required_linkable_specs}"
f"Recipe not found for measure spec: {measure_spec} and linkable specs: {required_linkable_specs}"
)

# If a cumulative metric is queried with metric_time, join over time range.
Expand Down Expand Up @@ -825,7 +774,7 @@ def _build_aggregated_measures_from_measure_source_node(
parent_node=join_to_time_spine_node or time_range_node or measure_recipe.source_node,
include_specs=InstanceSpecSet.merge(
(
InstanceSpecSet(measure_specs=measure_specs),
InstanceSpecSet(measure_specs=(measure_spec,)),
InstanceSpecSet.create_from_linkable_specs(measure_recipe.required_local_linkable_specs),
)
),
Expand All @@ -841,7 +790,7 @@ def _build_aggregated_measures_from_measure_source_node(

specs_to_keep_after_join = InstanceSpecSet.merge(
(
InstanceSpecSet(measure_specs=measure_specs),
InstanceSpecSet(measure_specs=(measure_spec,)),
required_linkable_specs.as_spec_set,
)
)
Expand Down Expand Up @@ -902,22 +851,16 @@ def _build_aggregated_measures_from_measure_source_node(
pre_aggregate_node = FilterElementsNode(
parent_node=pre_aggregate_node,
include_specs=InstanceSpecSet.merge(
(InstanceSpecSet(measure_specs=measure_specs), queried_linkable_specs.as_spec_set)
(InstanceSpecSet(measure_specs=(measure_spec,)), queried_linkable_specs.as_spec_set)
),
)
aggregate_measures_node = AggregateMeasuresNode(
parent_node=pre_aggregate_node,
metric_input_measure_specs=tuple(metric_input_measure_specs),
metric_input_measure_specs=(metric_input_measure_spec,),
)

join_aggregated_measure_to_time_spine = False
for metric_input_measure in metric_input_measure_specs:
if metric_input_measure.join_to_timespine:
join_aggregated_measure_to_time_spine = True
break

# Only join to time spine if metric time was requested in the query.
if join_aggregated_measure_to_time_spine and metric_time_dimension_requested:
if metric_input_measure_spec.join_to_timespine and metric_time_dimension_requested:
return JoinToTimeSpineNode(
parent_node=aggregate_measures_node,
requested_metric_time_dimension_specs=metric_time_dimension_specs,
Expand Down
61 changes: 0 additions & 61 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
DATAFLOW_NODE_COMBINE_METRICS_ID_PREFIX,
DATAFLOW_NODE_COMPUTE_METRICS_ID_PREFIX,
DATAFLOW_NODE_CONSTRAIN_TIME_RANGE_ID_PREFIX,
DATAFLOW_NODE_JOIN_AGGREGATED_MEASURES_BY_GROUPBY_COLUMNS_PREFIX,
DATAFLOW_NODE_JOIN_SELF_OVER_TIME_RANGE_ID_PREFIX,
DATAFLOW_NODE_JOIN_TO_STANDARD_OUTPUT_ID_PREFIX,
DATAFLOW_NODE_JOIN_TO_TIME_SPINE_ID_PREFIX,
Expand Down Expand Up @@ -122,12 +121,6 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> VisitorOutputT: # noqa:
def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> VisitorOutputT: # noqa: D
pass

@abstractmethod
def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D
self, node: JoinAggregatedMeasuresByGroupByColumnsNode
) -> VisitorOutputT:
pass

@abstractmethod
def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> VisitorOutputT: # noqa: D
pass
Expand Down Expand Up @@ -499,60 +492,6 @@ def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> AggregateM
)


class JoinAggregatedMeasuresByGroupByColumnsNode(AggregatedMeasuresOutput):
"""A node that joins aggregated measures with group by elements.
This is designed to link two separate semantic models with measures aggregated by the complete set of group by
elements shared across both measures. Due to the way the DataflowPlan currently processes joins, this means
each separate semantic model will be pre-aggregated, and this final join will be run across fully aggregated
sets of input data. As such, all this requires is the list of aggregated measure outputs, since they can be
transformed into a SqlDataSet containing the complete list of non-measure specs for joining.
"""

def __init__(
self,
parent_nodes: Sequence[BaseOutput],
):
"""Constructor.
Args:
parent_nodes: sequence of nodes that output aggregated measures
"""
if len(parent_nodes) < 2:
raise ValueError(
"This node is designed for joining 2 or more aggregated nodes together, but "
f"we got {len(parent_nodes)}"
)
super().__init__(node_id=self.create_unique_id(), parent_nodes=list(parent_nodes))

@classmethod
def id_prefix(cls) -> str: # noqa: D
return DATAFLOW_NODE_JOIN_AGGREGATED_MEASURES_BY_GROUPBY_COLUMNS_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D
return visitor.visit_join_aggregated_measures_by_groupby_columns_node(self)

@property
def description(self) -> str: # noqa: D
return """Join Aggregated Measures with Standard Outputs"""

@property
def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D
return super().displayed_properties + [
DisplayedProperty("Join aggregated measure nodes: ", f"{[node.node_id for node in self.parent_nodes]}")
]

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D
return isinstance(other_node, self.__class__)

def with_new_parents( # noqa: D
self, new_parent_nodes: Sequence[BaseOutput]
) -> JoinAggregatedMeasuresByGroupByColumnsNode:
return JoinAggregatedMeasuresByGroupByColumnsNode(
parent_nodes=new_parent_nodes,
)


class SemiAdditiveJoinNode(BaseOutput):
"""A node that performs a row filter by aggregating a given non-additive dimension.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
DataflowPlanNode,
DataflowPlanNodeVisitor,
FilterElementsNode,
JoinAggregatedMeasuresByGroupByColumnsNode,
JoinOverTimeRangeNode,
JoinToBaseOutputNode,
JoinToTimeSpineNode,
Expand Down Expand Up @@ -216,12 +215,6 @@ def visit_join_to_base_output_node( # noqa: D
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D
self, node: JoinAggregatedMeasuresByGroupByColumnsNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_aggregate_measures_node( # noqa: D
self, node: AggregateMeasuresNode
) -> ComputeMetricsBranchCombinerResult:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
DataflowPlanNode,
DataflowPlanNodeVisitor,
FilterElementsNode,
JoinAggregatedMeasuresByGroupByColumnsNode,
JoinOverTimeRangeNode,
JoinToBaseOutputNode,
JoinToTimeSpineNode,
Expand Down Expand Up @@ -155,12 +154,6 @@ def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> Optimize
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D
self, node: JoinAggregatedMeasuresByGroupByColumnsNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> OptimizeBranchResult: # noqa: D
self._log_visit_node_type(node)
return self._default_base_output_handler(node)
Expand Down
Loading

0 comments on commit 043bbbb

Please sign in to comment.