Stronger typing for read_nodes
courtneyholcomb committed Oct 12, 2023
1 parent 20cc790 commit d350fc6
Showing 4 changed files with 12 additions and 9 deletions.
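
For context, the common thread across the four files is that every read_nodes parameter is narrowed from the general dataflow-plan node type BaseOutput to the concrete ReadSqlSourceNode subclass that callers actually construct. The sketch below is a minimal illustration of why the narrowing helps and stays backward compatible; the classes are simplified stand-ins, not the real metricflow definitions, and the source_table attribute is purely illustrative.

from typing import Sequence


class BaseOutput:
    """Stand-in for the generic dataflow-plan output node."""


class ReadSqlSourceNode(BaseOutput):
    """Stand-in for the node that reads directly from a SQL source."""

    def __init__(self, source_table: str) -> None:
        self.source_table = source_table  # illustrative attribute, not the real API


def describe_read_nodes(read_nodes: Sequence[ReadSqlSourceNode]) -> None:
    # With the narrowed annotation, every element is statically known to be a
    # ReadSqlSourceNode, so subclass-specific attributes are usable without
    # isinstance checks or casts.
    for node in read_nodes:
        print(node.source_table)


nodes = [ReadSqlSourceNode("bookings_source"), ReadSqlSourceNode("listings_source")]
# Sequence is covariant, so this list also satisfies any remaining
# Sequence[BaseOutput] annotations; callers that already pass ReadSqlSourceNode
# instances need no changes.
describe_read_nodes(nodes)
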
5 changes: 3 additions & 2 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -37,6 +37,7 @@
     JoinToBaseOutputNode,
     JoinToTimeSpineNode,
     OrderByLimitNode,
+    ReadSqlSourceNode,
     SemiAdditiveJoinNode,
     SinkOutput,
     WhereConstraintNode,
@@ -144,7 +145,7 @@ class DataflowPlanBuilder:
     def __init__(  # noqa: D
         self,
         source_nodes: Sequence[BaseOutput],
-        read_nodes: Sequence[BaseOutput],
+        read_nodes: Sequence[ReadSqlSourceNode],
         semantic_manifest_lookup: SemanticManifestLookup,
         cost_function: DataflowPlanNodeCostFunction = DefaultCostFunction(),
         node_output_resolver: Optional[DataflowPlanNodeOutputDataSetResolver] = None,
@@ -418,7 +419,7 @@ def _select_source_nodes_with_measures(
         return nodes

     def _select_read_nodes_with_linkable_specs(
-        self, linkable_specs: LinkableSpecSet, read_nodes: Sequence[BaseOutput]
+        self, linkable_specs: LinkableSpecSet, read_nodes: Sequence[ReadSqlSourceNode]
     ) -> Dict[BaseOutput, Set[LinkableInstanceSpec]]:
         """Find source nodes with requested linkable specs and no measures."""
         nodes_to_linkable_specs: Dict[BaseOutput, Set[LinkableInstanceSpec]] = {}
4 changes: 3 additions & 1 deletion metricflow/dataflow/builder/source_node.py
@@ -47,6 +47,8 @@ def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> Se
         )
         return source_nodes

-    def create_read_nodes_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> Sequence[BaseOutput]:
+    def create_read_nodes_from_data_sets(
+        self, data_sets: Sequence[SemanticModelDataSet]
+    ) -> Sequence[ReadSqlSourceNode]:
         """Creates read nodes from SemanticModelDataSets."""
         return [ReadSqlSourceNode(data_set) for data_set in data_sets]
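
Since the factory above builds its result entirely from ReadSqlSourceNode(data_set) calls, tightening the return annotation to Sequence[ReadSqlSourceNode] changes nothing at runtime; it simply lets the precise type flow to the dataflow plan builder, query parser, and time granularity solver shown in the other files. A rough, self-contained sketch of that propagation, again using simplified stand-ins rather than the real metricflow classes (the enclosing source-node builder class is not shown in this diff):

from typing import Sequence


class SemanticModelDataSet:
    """Stand-in for metricflow's SemanticModelDataSet."""

    def __init__(self, name: str) -> None:
        self.name = name


class BaseOutput:
    """Stand-in for the generic dataflow-plan output node."""


class ReadSqlSourceNode(BaseOutput):
    """Stand-in read node that remembers the data set it was built from."""

    def __init__(self, data_set: SemanticModelDataSet) -> None:
        self.data_set = data_set


def create_read_nodes_from_data_sets(
    data_sets: Sequence[SemanticModelDataSet],
) -> Sequence[ReadSqlSourceNode]:
    """Mirrors the factory in source_node.py: one read node per data set."""
    return [ReadSqlSourceNode(data_set) for data_set in data_sets]


read_nodes = create_read_nodes_from_data_sets(
    [SemanticModelDataSet("bookings"), SemanticModelDataSet("listings")]
)
# Callers annotated with Sequence[ReadSqlSourceNode] can accept this value
# directly; no cast down from Sequence[BaseOutput] is needed.
print([node.data_set.name for node in read_nodes])
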
4 changes: 2 additions & 2 deletions metricflow/query/query_parser.py
@@ -22,7 +22,7 @@

 from metricflow.assert_one_arg import assert_exactly_one_arg_set
 from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
-from metricflow.dataflow.dataflow_plan import BaseOutput
+from metricflow.dataflow.dataflow_plan import ReadSqlSourceNode
 from metricflow.dataset.dataset import DataSet
 from metricflow.errors.errors import UnableToSatisfyQueryError
 from metricflow.filters.time_constraint import TimeRangeConstraint
@@ -118,7 +118,7 @@ def __init__(  # noqa: D
         self,
         column_association_resolver: ColumnAssociationResolver,
         model: SemanticManifestLookup,
-        read_nodes: Sequence[BaseOutput],
+        read_nodes: Sequence[ReadSqlSourceNode],
         node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
     ) -> None:
         self._column_association_resolver = column_association_resolver
8 changes: 4 additions & 4 deletions metricflow/time/time_granularity_solver.py
@@ -15,7 +15,7 @@
 from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

 from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
-from metricflow.dataflow.dataflow_plan import BaseOutput
+from metricflow.dataflow.dataflow_plan import ReadSqlSourceNode
 from metricflow.filters.time_constraint import TimeRangeConstraint
 from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
 from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
@@ -103,7 +103,7 @@ def resolve_granularity_for_partial_time_dimension_specs(
         self,
         metric_references: Sequence[MetricReference],
         partial_time_dimension_specs: Sequence[PartialTimeDimensionSpec],
-        read_nodes: Sequence[BaseOutput],
+        read_nodes: Sequence[ReadSqlSourceNode],
         node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
     ) -> Dict[PartialTimeDimensionSpec, TimeDimensionSpec]:
         """Figure out the lowest granularity possible for the partially specified time dimension specs.
@@ -131,7 +131,7 @@ def find_minimum_granularity_for_partial_time_dimension_spec(
         self,
         partial_time_dimension_spec: PartialTimeDimensionSpec,
         metric_references: Sequence[MetricReference],
-        read_nodes: Sequence[BaseOutput],
+        read_nodes: Sequence[ReadSqlSourceNode],
         node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
     ) -> TimeGranularity:
         """Find minimum granularity allowed for time dimension when queried with given metrics."""
@@ -173,7 +173,7 @@ def find_minimum_granularity_for_partial_time_dimension_spec(

     def get_min_granularity_for_partial_time_dimension_without_metrics(
         self,
-        read_nodes: Sequence[BaseOutput],
+        read_nodes: Sequence[ReadSqlSourceNode],
         node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
         partial_time_dimension_spec: PartialTimeDimensionSpec,
     ) -> Optional[TimeGranularity]:

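With all four files agreeing that read_nodes is a Sequence[ReadSqlSourceNode], the benefit is purely static: a type checker can now reject code that passes an arbitrary BaseOutput node where a read node is required. A hedged sketch of the kind of mypy report this enables, using the same simplified stand-ins as in the earlier sketches (the error text is approximate, not copied from a real run):

from typing import Sequence


class BaseOutput:
    """Stand-in for the generic dataflow-plan output node."""


class ReadSqlSourceNode(BaseOutput):
    """Stand-in for the concrete read node."""


def resolve_granularities(read_nodes: Sequence[ReadSqlSourceNode]) -> None:
    """Stand-in for the narrowed signatures in time_granularity_solver.py."""


mixed_nodes: Sequence[BaseOutput] = [ReadSqlSourceNode(), BaseOutput()]

# A checker such as mypy would flag the call below, roughly:
#   Argument 1 to "resolve_granularities" has incompatible type
#   "Sequence[BaseOutput]"; expected "Sequence[ReadSqlSourceNode]"
# At runtime the call still succeeds, so the change is a static-typing
# tightening with no behavior change.
resolve_granularities(mixed_nodes)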