Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Querying dimensions without metrics #804

Merged
merged 15 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231010-174851.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Changelog entry (changie format): feature announcement for release notes.
kind: Features
body: Support querying dimensions without metrics.
time: 2023-10-10T17:48:51.152712-07:00
custom:
  Author: courtneyholcomb
  Issue: "804"
6 changes: 5 additions & 1 deletion metricflow/dataflow/builder/costing.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> Defa
return DefaultCost.sum([x.accept(self) for x in node.parent_nodes])

def visit_pass_elements_filter_node(self, node: FilterElementsNode) -> DefaultCost:
    """Cost of a filter node: the sum of its parents' costs, plus one aggregation when distinct.

    When ``node.distinct`` is set the generated SQL groups by all selected columns to
    deduplicate rows, which is accounted for as a single aggregation.
    """
    parent_costs = [parent.accept(self) for parent in node.parent_nodes]

    # 1 aggregation if grouping by distinct values
    node_cost = DefaultCost(num_aggregations=1 if node.distinct else 0)
    return DefaultCost.sum(parent_costs + [node_cost])

def visit_combine_metrics_node(self, node: CombineMetricsNode) -> DefaultCost:
    """Combining metrics adds no cost of its own — just the sum of the parents' costs."""
    parent_costs = [parent_node.accept(self) for parent_node in node.parent_nodes]
    return DefaultCost.sum(parent_costs)
Expand Down
320 changes: 149 additions & 171 deletions metricflow/dataflow/builder/dataflow_plan_builder.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions metricflow/dataflow/builder/source_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,7 @@ def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> Se
)
)
return source_nodes

def create_read_nodes_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> Sequence[BaseOutput]:
    """Wrap each SemanticModelDataSet in a ReadSqlSourceNode, preserving input order."""
    return list(map(ReadSqlSourceNode, data_sets))
courtneyholcomb marked this conversation as resolved.
Show resolved Hide resolved
16 changes: 14 additions & 2 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
def __init__(
    self,
    parent_node: BaseOutput,
    include_specs: InstanceSpecSet,
    replace_description: Optional[str] = None,
    distinct: bool = False,
) -> None:
    """Initialize a node that filters which element specs pass through.

    Args:
        parent_node: the single upstream node whose output is filtered.
        include_specs: specs for the elements that should be passed through.
        replace_description: if set, overrides the default node description.
        distinct: if True, only distinct values of the selected specs are returned
            (the SQL conversion groups by all selected columns).
    """
    self._include_specs = include_specs
    self._replace_description = replace_description
    self._parent_node = parent_node
    self._distinct = distinct
    super().__init__(node_id=self.create_unique_id(), parent_nodes=[parent_node])

@classmethod
Expand All @@ -1113,6 +1115,11 @@ def include_specs(self) -> InstanceSpecSet:
"""Returns the specs for the elements that it should pass."""
return self._include_specs

@property
def distinct(self) -> bool:
    """True if you only want the distinct values for the selected specs.

    When set, the SQL conversion groups by all selected columns to deduplicate
    rows (see the FilterElementsNode handling in dataflow-to-SQL conversion).
    """
    return self._distinct

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
    """Dispatch to the visitor's handler for this node type (visitor pattern)."""
    return visitor.visit_pass_elements_filter_node(self)

Expand All @@ -1132,21 +1139,26 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D
if not self._replace_description:
additional_properties = [
DisplayedProperty("include_spec", include_spec) for include_spec in self._include_specs.all_specs
]
] + [DisplayedProperty("distinct", self._distinct)]
return super().displayed_properties + additional_properties

@property
def parent_node(self) -> BaseOutput:
    """The single upstream node whose output this node filters."""
    return self._parent_node

def functionally_identical(self, other_node: DataflowPlanNode) -> bool:
    """True if other_node would produce the same output: same node type, specs, and distinct flag."""
    return (
        isinstance(other_node, self.__class__)
        and other_node.include_specs == self.include_specs
        # Nodes differing only in distinct-ness produce different rows, so they are not identical.
        and other_node.distinct == self.distinct
    )

def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> FilterElementsNode:
    """Clone this node with a different (single) parent, preserving all filter settings."""
    # A FilterElementsNode has exactly one parent by construction.
    assert len(new_parent_nodes) == 1
    return FilterElementsNode(
        parent_node=new_parent_nodes[0],
        include_specs=self.include_specs,
        distinct=self.distinct,
        replace_description=self._replace_description,
    )

Expand Down
19 changes: 12 additions & 7 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def __init__(

source_node_builder = SourceNodeBuilder(self._semantic_manifest_lookup)
source_nodes = source_node_builder.create_from_data_sets(self._source_data_sets)
read_nodes = source_node_builder.create_read_nodes_from_data_sets(self._source_data_sets)

node_output_resolver = DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=DunderColumnAssociationResolver(semantic_manifest_lookup),
Expand All @@ -358,6 +359,7 @@ def __init__(

self._dataflow_plan_builder = DataflowPlanBuilder(
source_nodes=source_nodes,
read_nodes=read_nodes,
semantic_manifest_lookup=self._semantic_manifest_lookup,
)
self._to_sql_query_plan_converter = DataflowToSqlQueryPlanConverter(
Expand All @@ -374,7 +376,7 @@ def __init__(
self._query_parser = MetricFlowQueryParser(
column_association_resolver=self._column_association_resolver,
model=self._semantic_manifest_lookup,
source_nodes=source_nodes,
read_nodes=read_nodes,
node_output_resolver=node_output_resolver,
)

Expand Down Expand Up @@ -483,12 +485,15 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me
time_dimension_specs=query_spec.time_dimension_specs,
)

dataflow_plan = self._dataflow_plan_builder.build_plan(
query_spec=query_spec,
output_sql_table=output_table,
output_selection_specs=output_selection_specs,
optimizers=(SourceScanOptimizer(),),
)
if query_spec.metric_specs:
dataflow_plan = self._dataflow_plan_builder.build_plan(
query_spec=query_spec,
output_sql_table=output_table,
output_selection_specs=output_selection_specs,
optimizers=(SourceScanOptimizer(),),
)
else:
dataflow_plan = self._dataflow_plan_builder.build_plan_for_distinct_values(query_spec=query_spec)

if len(dataflow_plan.sink_output_nodes) > 1:
raise NotImplementedError(
Expand Down
14 changes: 9 additions & 5 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,18 +797,22 @@ def visit_pass_elements_filter_node(self, node: FilterElementsNode) -> SqlDataSe
# Also, the output columns should always follow the resolver format.
output_instance_set = output_instance_set.transform(ChangeAssociatedColumns(self._column_association_resolver))

# This creates select expressions for all columns referenced in the instance set.
select_columns = output_instance_set.transform(
CreateSelectColumnsForInstances(from_data_set_alias, self._column_association_resolver)
).as_tuple()

# If distinct values requested, group by all select columns.
group_bys = select_columns if node.distinct else ()
return SqlDataSet(
instance_set=output_instance_set,
sql_select_node=SqlSelectStatementNode(
description=node.description,
# This creates select expressions for all columns referenced in the instance set.
select_columns=output_instance_set.transform(
CreateSelectColumnsForInstances(from_data_set_alias, self._column_association_resolver)
).as_tuple(),
select_columns=select_columns,
from_source=from_data_set.sql_select_node,
from_source_alias=from_data_set_alias,
joins_descs=(),
group_bys=(),
group_bys=group_bys,
where=None,
order_bys=(),
),
Expand Down
4 changes: 2 additions & 2 deletions metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class MultiHopJoinCandidate:
lineage: MultiHopJoinCandidateLineage


class PreDimensionJoinNodeProcessor:
"""Processes source nodes before measures are joined to dimensions.
class PreJoinNodeProcessor:
"""Processes source nodes before other nodes are joined.

Generally, the source nodes will be combined with other dataflow plan nodes to produce a new set of nodes to realize
a condition of the query. For example, to realize a time range constraint, a ConstrainTimeRangeNode will be added
Expand Down
28 changes: 21 additions & 7 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,15 @@ def __init__( # noqa: D
self,
column_association_resolver: ColumnAssociationResolver,
model: SemanticManifestLookup,
source_nodes: Sequence[BaseOutput],
read_nodes: Sequence[BaseOutput],
courtneyholcomb marked this conversation as resolved.
Show resolved Hide resolved
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
) -> None:
self._column_association_resolver = column_association_resolver
self._model = model
self._metric_lookup = model.metric_lookup
self._semantic_model_lookup = model.semantic_model_lookup
self._node_output_resolver = node_output_resolver
self._read_nodes = read_nodes

# Set up containers for known element names
self._known_entity_element_references = self._semantic_model_lookup.get_entity_references()
Expand Down Expand Up @@ -228,12 +230,15 @@ def _validate_no_time_dimension_query(self, metric_references: Sequence[MetricRe
"dimension 'metric_time'."
)

def _validate_linkable_specs(
def _validate_linkable_specs_for_metrics(
self,
metric_references: Tuple[MetricReference, ...],
all_linkable_specs: QueryTimeLinkableSpecSet,
time_dimension_specs: Tuple[TimeDimensionSpec, ...],
) -> None:
if not metric_references:
return None

invalid_group_bys = self._get_invalid_linkable_specs(
metric_references=metric_references,
dimension_specs=all_linkable_specs.dimension_specs,
Expand Down Expand Up @@ -296,6 +301,9 @@ def _construct_metric_specs_for_query(
def _get_metric_names(
self, metric_names: Optional[Sequence[str]], metrics: Optional[Sequence[MetricQueryParameter]]
) -> Sequence[str]:
if not (metric_names or metrics):
return []

assert_exactly_one_arg_set(metric_names=metric_names, metrics=metrics)
return metric_names if metric_names else [m.name for m in metrics] if metrics else []

Expand Down Expand Up @@ -402,6 +410,8 @@ def _parse_and_validate_query(
self._time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs(
metric_references=metric_references,
partial_time_dimension_specs=requested_linkable_specs.partial_time_dimension_specs,
read_nodes=self._read_nodes,
node_output_resolver=self._node_output_resolver,
)
)

Expand All @@ -422,7 +432,7 @@ def _parse_and_validate_query(

# For each metric, verify that it's possible to retrieve all group by elements, including the ones as required
# by the filters.
# TODO: Consider moving this logic into _validate_linkable_specs().
# TODO: Consider moving this logic into _validate_linkable_specs_for_metrics().
for metric_reference in metric_references:
metric = self._metric_lookup.get_metric(metric_reference)
if metric.filter is not None:
Expand All @@ -434,7 +444,7 @@ def _parse_and_validate_query(

# Combine the group by elements from the query with the group by elements that are required by the
# metric filter to see if that's a valid set that could be queried.
self._validate_linkable_specs(
self._validate_linkable_specs_for_metrics(
metric_references=(metric_reference,),
all_linkable_specs=QueryTimeLinkableSpecSet.combine(
(
Expand All @@ -452,7 +462,7 @@ def _parse_and_validate_query(
)

# Validate all of them together.
self._validate_linkable_specs(
self._validate_linkable_specs_for_metrics(
metric_references=metric_references,
all_linkable_specs=requested_linkable_specs_with_requested_filter_specs,
time_dimension_specs=time_dimension_specs,
Expand Down Expand Up @@ -573,6 +583,8 @@ def _adjust_time_range_constraint(
self._time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs(
metric_references=metric_references,
partial_time_dimension_specs=(partial_metric_time_spec,),
read_nodes=self._read_nodes,
node_output_resolver=self._node_output_resolver,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checkpoint, this must be changed in a later commit one way or the other.....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean??

)
)
adjust_to_granularity = partial_time_dimension_spec_to_time_dimension_spec[
Expand Down Expand Up @@ -670,7 +682,6 @@ def _parse_group_by(
group_by: Optional[Tuple[GroupByParameter, ...]] = None,
) -> QueryTimeLinkableSpecSet:
"""Convert the linkable spec names into the respective specification objects."""
# TODO: refactor to only support group_by object inputs (removing group_by_names param)
assert not (
group_by_names and group_by
), "Both group_by_names and group_by were set, but if a group by is specified you should only use one of these!"
Expand Down Expand Up @@ -772,7 +783,10 @@ def _verify_resolved_granularity_for_date_part(
ensure that the correct value was passed in.
"""
resolved_granularity = self._time_granularity_solver.find_minimum_granularity_for_partial_time_dimension_spec(
partial_time_dimension_spec=partial_time_dimension_spec, metric_references=metric_references
partial_time_dimension_spec=partial_time_dimension_spec,
metric_references=metric_references,
read_nodes=self._read_nodes,
node_output_resolver=self._node_output_resolver,
)
if resolved_granularity != requested_dimension_structured_name.time_granularity:
raise RequestTimeGranularityException(
Expand Down
1 change: 1 addition & 0 deletions metricflow/test/dataflow/builder/test_cyclic_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def cyclic_join_manifest_dataflow_plan_builder( # noqa: D

return DataflowPlanBuilder(
source_nodes=consistent_id_object_repository.cyclic_join_source_nodes,
read_nodes=list(consistent_id_object_repository.cyclic_join_read_nodes.values()),
semantic_manifest_lookup=cyclic_join_semantic_manifest_lookup,
cost_function=DefaultCostFunction(),
)
Expand Down
78 changes: 72 additions & 6 deletions metricflow/test/dataflow/builder/test_dataflow_plan_builder.py
Original file line number Diff line number Diff line change
def test_distinct_values_plan(  # noqa: D
    request: FixtureRequest,
    mf_test_session_state: MetricFlowTestSessionState,
    dataflow_plan_builder: DataflowPlanBuilder,
    column_association_resolver: ColumnAssociationResolver,
) -> None:
    """Tests a plan to get distinct values of a dimension, with a filter, ordering, and a limit."""
    dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values(
        query_spec=MetricFlowQuerySpec(
            dimension_specs=(
                DimensionSpec(element_name="country_latest", entity_links=(EntityReference(element_name="listing"),)),
            ),
            where_constraint=(
                WhereSpecFactory(
                    column_association_resolver=column_association_resolver,
                ).create_from_where_filter(
                    PydanticWhereFilter(
                        where_sql_template="{{ Dimension('listing__country_latest') }} = 'us'",
                    )
                )
            ),
            order_by_specs=(
                OrderBySpec(
                    instance_spec=DimensionSpec(
                        element_name="country_latest", entity_links=(EntityReference(element_name="listing"),)
                    ),
                    descending=True,
                ),
            ),
            limit=100,
        )
    )

    # Compare against the checked-in snapshot of the rendered plan.
    assert_plan_snapshot_text_equal(
        request=request,
        mf_test_session_state=mf_test_session_state,
        plan=dataflow_plan,
        plan_snapshot_text=dataflow_plan_as_text(dataflow_plan),
    )

    display_graph_if_requested(
        request=request,
        mf_test_session_state=mf_test_session_state,
        dag_graph=dataflow_plan,
    )


def test_distinct_values_plan_with_join( # noqa: D
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
dataflow_plan_builder: DataflowPlanBuilder,
column_association_resolver: ColumnAssociationResolver,
) -> None:
"""Tests a plan to get distinct values of 2 dimensions, where a join is required."""
dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values(
query_spec=MetricFlowQuerySpec(
dimension_specs=(
DimensionSpec(element_name="home_state_latest", entity_links=(EntityReference(element_name="user"),)),
DimensionSpec(element_name="is_lux_latest", entity_links=(EntityReference(element_name="listing"),)),
),
where_constraint=(
WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(
PydanticWhereFilter(
where_sql_template="{{ Dimension('listing__country_latest') }} = 'us'",
)
)
),
order_by_specs=(
OrderBySpec(
instance_spec=DimensionSpec(
element_name="country_latest", entity_links=(EntityReference(element_name="listing"),)
),
descending=True,
),
),
limit=100,
)
)

assert_plan_snapshot_text_equal(
Expand Down
4 changes: 2 additions & 2 deletions metricflow/test/dataflow/builder/test_node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from metricflow.dataset.dataset import DataSet
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver
from metricflow.plan_conversion.node_processor import PreDimensionJoinNodeProcessor
from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor
from metricflow.specs.specs import (
DimensionSpec,
EntityReference,
Expand Down Expand Up @@ -65,7 +65,7 @@ def make_multihop_node_evaluator(
semantic_manifest_lookup=semantic_manifest_lookup_with_multihop_links,
)

node_processor = PreDimensionJoinNodeProcessor(
node_processor = PreJoinNodeProcessor(
semantic_model_lookup=semantic_manifest_lookup_with_multihop_links.semantic_model_lookup,
node_data_set_resolver=node_data_set_resolver,
)
Expand Down
Loading
Loading