From c9526e564566111831b158170d21911b2d74dd0a Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Tue, 10 Oct 2023 18:10:13 -0700 Subject: [PATCH] Write & update tests --- .../test/dataflow/builder/test_cyclic_join.py | 1 + .../builder/test_dataflow_plan_builder.py | 78 ++++++++++++++++-- .../dataflow/builder/test_node_evaluator.py | 4 +- metricflow/test/fixtures/dataflow_fixtures.py | 3 + metricflow/test/fixtures/model_fixtures.py | 15 +++- .../test/integration/configured_test_case.py | 2 +- .../test_cases/itest_dimensions.yaml | 80 +++++++++++++++++++ .../test_dataflow_to_sql_plan.py | 30 +++++-- metricflow/test/query/test_query_parser.py | 3 + .../test/time/test_time_granularity_solver.py | 22 ++++- 10 files changed, 219 insertions(+), 19 deletions(-) diff --git a/metricflow/test/dataflow/builder/test_cyclic_join.py b/metricflow/test/dataflow/builder/test_cyclic_join.py index 8762e45ce3..d35a2b3987 100644 --- a/metricflow/test/dataflow/builder/test_cyclic_join.py +++ b/metricflow/test/dataflow/builder/test_cyclic_join.py @@ -34,6 +34,7 @@ def cyclic_join_manifest_dataflow_plan_builder( # noqa: D return DataflowPlanBuilder( source_nodes=consistent_id_object_repository.cyclic_join_source_nodes, + read_nodes=list(consistent_id_object_repository.cyclic_join_read_nodes.values()), semantic_manifest_lookup=cyclic_join_semantic_manifest_lookup, cost_function=DefaultCostFunction(), ) diff --git a/metricflow/test/dataflow/builder/test_dataflow_plan_builder.py b/metricflow/test/dataflow/builder/test_dataflow_plan_builder.py index cb97cee1e2..586e2e6d21 100644 --- a/metricflow/test/dataflow/builder/test_dataflow_plan_builder.py +++ b/metricflow/test/dataflow/builder/test_dataflow_plan_builder.py @@ -562,15 +562,81 @@ def test_distinct_values_plan( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, dataflow_plan_builder: DataflowPlanBuilder, + column_association_resolver: ColumnAssociationResolver, ) -> None: """Tests a plan to get 
distinct values of a dimension.""" dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values( - metric_specs=(MetricSpec(element_name="bookings"),), - dimension_spec=DimensionSpec( - element_name="country_latest", - entity_links=(EntityReference(element_name="listing"),), - ), - limit=100, + query_spec=MetricFlowQuerySpec( + dimension_specs=( + DimensionSpec(element_name="country_latest", entity_links=(EntityReference(element_name="listing"),)), + ), + where_constraint=( + WhereSpecFactory( + column_association_resolver=column_association_resolver, + ).create_from_where_filter( + PydanticWhereFilter( + where_sql_template="{{ Dimension('listing__country_latest') }} = 'us'", + ) + ) + ), + order_by_specs=( + OrderBySpec( + instance_spec=DimensionSpec( + element_name="country_latest", entity_links=(EntityReference(element_name="listing"),) + ), + descending=True, + ), + ), + limit=100, + ) + ) + + assert_plan_snapshot_text_equal( + request=request, + mf_test_session_state=mf_test_session_state, + plan=dataflow_plan, + plan_snapshot_text=dataflow_plan_as_text(dataflow_plan), + ) + + display_graph_if_requested( + request=request, + mf_test_session_state=mf_test_session_state, + dag_graph=dataflow_plan, + ) + + +def test_distinct_values_plan_with_join( # noqa: D + request: FixtureRequest, + mf_test_session_state: MetricFlowTestSessionState, + dataflow_plan_builder: DataflowPlanBuilder, + column_association_resolver: ColumnAssociationResolver, +) -> None: + """Tests a plan to get distinct values of 2 dimensions, where a join is required.""" + dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values( + query_spec=MetricFlowQuerySpec( + dimension_specs=( + DimensionSpec(element_name="home_state_latest", entity_links=(EntityReference(element_name="user"),)), + DimensionSpec(element_name="is_lux_latest", entity_links=(EntityReference(element_name="listing"),)), + ), + where_constraint=( + WhereSpecFactory( + 
column_association_resolver=column_association_resolver, + ).create_from_where_filter( + PydanticWhereFilter( + where_sql_template="{{ Dimension('listing__country_latest') }} = 'us'", + ) + ) + ), + order_by_specs=( + OrderBySpec( + instance_spec=DimensionSpec( + element_name="country_latest", entity_links=(EntityReference(element_name="listing"),) + ), + descending=True, + ), + ), + limit=100, + ) ) assert_plan_snapshot_text_equal( diff --git a/metricflow/test/dataflow/builder/test_node_evaluator.py b/metricflow/test/dataflow/builder/test_node_evaluator.py index 523cb416ed..c88fa7319e 100644 --- a/metricflow/test/dataflow/builder/test_node_evaluator.py +++ b/metricflow/test/dataflow/builder/test_node_evaluator.py @@ -18,7 +18,7 @@ from metricflow.dataset.dataset import DataSet from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver -from metricflow.plan_conversion.node_processor import PreDimensionJoinNodeProcessor +from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor from metricflow.specs.specs import ( DimensionSpec, EntityReference, @@ -65,7 +65,7 @@ def make_multihop_node_evaluator( semantic_manifest_lookup=semantic_manifest_lookup_with_multihop_links, ) - node_processor = PreDimensionJoinNodeProcessor( + node_processor = PreJoinNodeProcessor( semantic_model_lookup=semantic_manifest_lookup_with_multihop_links.semantic_model_lookup, node_data_set_resolver=node_data_set_resolver, ) diff --git a/metricflow/test/fixtures/dataflow_fixtures.py b/metricflow/test/fixtures/dataflow_fixtures.py index 06616aa7f1..e1af9103d2 100644 --- a/metricflow/test/fixtures/dataflow_fixtures.py +++ b/metricflow/test/fixtures/dataflow_fixtures.py @@ -34,6 +34,7 @@ def dataflow_plan_builder( # noqa: D ) -> DataflowPlanBuilder: return DataflowPlanBuilder( source_nodes=consistent_id_object_repository.simple_model_source_nodes, + 
read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()), semantic_manifest_lookup=simple_semantic_manifest_lookup, cost_function=DefaultCostFunction(), ) @@ -47,6 +48,7 @@ def multihop_dataflow_plan_builder( # noqa: D ) -> DataflowPlanBuilder: return DataflowPlanBuilder( source_nodes=consistent_id_object_repository.multihop_model_source_nodes, + read_nodes=list(consistent_id_object_repository.multihop_model_read_nodes.values()), semantic_manifest_lookup=multi_hop_join_semantic_manifest_lookup, cost_function=DefaultCostFunction(), ) @@ -68,6 +70,7 @@ def scd_dataflow_plan_builder( # noqa: D ) -> DataflowPlanBuilder: return DataflowPlanBuilder( source_nodes=consistent_id_object_repository.scd_model_source_nodes, + read_nodes=list(consistent_id_object_repository.scd_model_read_nodes.values()), semantic_manifest_lookup=scd_semantic_manifest_lookup, cost_function=DefaultCostFunction(), column_association_resolver=scd_column_association_resolver, diff --git a/metricflow/test/fixtures/model_fixtures.py b/metricflow/test/fixtures/model_fixtures.py index c863a1eb00..afcf7bf38e 100644 --- a/metricflow/test/fixtures/model_fixtures.py +++ b/metricflow/test/fixtures/model_fixtures.py @@ -61,11 +61,10 @@ def query_parser_from_yaml(yaml_contents: List[YamlConfigFile]) -> MetricFlowQue ).semantic_manifest ) SemanticManifestValidator[SemanticManifest]().checked_validations(semantic_manifest_lookup.semantic_manifest) - source_nodes = _data_set_to_source_nodes(semantic_manifest_lookup, create_data_sets(semantic_manifest_lookup)) return MetricFlowQueryParser( model=semantic_manifest_lookup, column_association_resolver=DunderColumnAssociationResolver(semantic_manifest_lookup), - source_nodes=source_nodes, + read_nodes=list(_data_set_to_read_nodes(create_data_sets(semantic_manifest_lookup)).values()), node_output_resolver=DataflowPlanNodeOutputDataSetResolver( column_association_resolver=DunderColumnAssociationResolver(semantic_manifest_lookup), 
semantic_manifest_lookup=semantic_manifest_lookup, @@ -88,6 +87,7 @@ class ConsistentIdObjectRepository: scd_model_read_nodes: OrderedDict[str, ReadSqlSourceNode] scd_model_source_nodes: Sequence[BaseOutput] + cyclic_join_read_nodes: OrderedDict[str, ReadSqlSourceNode] cyclic_join_source_nodes: Sequence[BaseOutput] @@ -122,6 +122,7 @@ def consistent_id_object_repository( scd_model_source_nodes=_data_set_to_source_nodes( semantic_manifest_lookup=scd_semantic_manifest_lookup, data_sets=scd_data_sets ), + cyclic_join_read_nodes=_data_set_to_read_nodes(cyclic_join_data_sets), cyclic_join_source_nodes=_data_set_to_source_nodes( semantic_manifest_lookup=cyclic_join_semantic_manifest_lookup, data_sets=cyclic_join_data_sets ), @@ -239,3 +240,13 @@ def cyclic_join_semantic_manifest_lookup(template_mapping: Dict[str, str]) -> Se """Manifest that contains a potential cycle in the join graph (if not handled properly).""" build_result = load_semantic_manifest("cyclic_join_manifest", template_mapping) return SemanticManifestLookup(build_result.semantic_manifest) + + +@pytest.fixture(scope="session") +def node_output_resolver( # noqa:D + simple_semantic_manifest_lookup: SemanticManifestLookup, +) -> DataflowPlanNodeOutputDataSetResolver: + return DataflowPlanNodeOutputDataSetResolver( + column_association_resolver=DunderColumnAssociationResolver(simple_semantic_manifest_lookup), + semantic_manifest_lookup=simple_semantic_manifest_lookup, + ) diff --git a/metricflow/test/integration/configured_test_case.py b/metricflow/test/integration/configured_test_case.py index 766873b9b0..b0e59d3ab2 100644 --- a/metricflow/test/integration/configured_test_case.py +++ b/metricflow/test/integration/configured_test_case.py @@ -48,10 +48,10 @@ class Config: # noqa: D name: str # Name of the semantic model to use. model: IntegrationTestModel - metrics: Tuple[str, ...] # The SQL query that can be run to obtain the expected results. check_query: str file_path: str + metrics: Tuple[str, ...] 
= () group_bys: Tuple[str, ...] = () group_by_objs: Tuple[Dict, ...] = () order_bys: Tuple[str, ...] = () diff --git a/metricflow/test/integration/test_cases/itest_dimensions.yaml b/metricflow/test/integration/test_cases/itest_dimensions.yaml index ea5affa47c..c6d5912629 100644 --- a/metricflow/test/integration/test_cases/itest_dimensions.yaml +++ b/metricflow/test/integration/test_cases/itest_dimensions.yaml @@ -102,3 +102,83 @@ integration_test: GROUP BY v.ds , u.home_state +--- +integration_test: + name: query_dimension_only + description: Query dimension only + model: SIMPLE_MODEL + group_bys: ["user__home_state"] + check_query: | + SELECT + u.home_state AS user__home_state + FROM {{ source_schema }}.dim_users u + GROUP BY + u.home_state +--- +integration_test: + name: query_dimensions_only + description: Query multiple dimensions without metrics + model: SIMPLE_MODEL + group_bys: ["ds__day", "user__home_state"] + check_query: | + SELECT + u.home_state AS user__home_state + , u.ds AS ds__day + FROM {{ source_schema }}.dim_users u + GROUP BY + u.ds + , u.home_state +--- +integration_test: + name: query_dimensions_from_different_tables + description: Query multiple dimensions without metrics, requiring a join + model: SIMPLE_MODEL + group_bys: ["user__home_state_latest", "listing__is_lux_latest"] + check_query: | + SELECT + u.home_state_latest AS user__home_state_latest + , l.is_lux AS listing__is_lux_latest + FROM {{ source_schema }}.dim_listings_latest l + LEFT OUTER JOIN {{ source_schema }}.dim_users_latest u + ON u.user_id = l.user_id + GROUP BY + u.home_state_latest + , l.is_lux +--- +integration_test: + name: query_time_dimension_without_granularity + description: Query just a time dimension, no granularity specified. Should assume default granularity for dimension. 
+ model: SIMPLE_MODEL + group_bys: [ "verification__ds"] + check_query: | + SELECT + v.ds as verification__ds__day + FROM {{ source_schema }}.fct_id_verifications v + GROUP BY + v.ds +--- +integration_test: + name: query_non_default_time_dimension_without_granularity + description: Query just a time dimension, no granularity specified. Should assume default granularity for dimension. + model: EXTENDED_DATE_MODEL + group_bys: [ "monthly_ds"] + check_query: | + SELECT + ds AS monthly_ds__month + FROM {{ source_schema }}.fct_bookings_extended_monthly + GROUP BY + ds +--- +integration_test: + name: query_dimension_only_with_constraint + description: Query dimension only, with a where filter + model: SIMPLE_MODEL + group_bys: ["user__home_state"] + where_filter: "{{ render_dimension_template('user__home_state') }} = 'CA'" + check_query: | + SELECT + u.home_state AS user__home_state + FROM {{ source_schema }}.dim_users u + WHERE user__home_state = 'CA' + GROUP BY + u.home_state diff --git a/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py b/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py index 71d78e45fe..16e3f696ba 100644 --- a/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py +++ b/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py @@ -1249,16 +1249,34 @@ def test_distinct_values( # noqa: D mf_test_session_state: MetricFlowTestSessionState, dataflow_plan_builder: DataflowPlanBuilder, dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, + column_association_resolver: ColumnAssociationResolver, sql_client: SqlClient, ) -> None: """Tests a plan to get distinct values for a dimension.""" dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values( - metric_specs=(MetricSpec(element_name="bookings"),), - dimension_spec=DimensionSpec( - element_name="country_latest", - entity_links=(EntityReference(element_name="listing"),), - ), - limit=100, + query_spec=MetricFlowQuerySpec( + dimension_specs=( + 
DimensionSpec(element_name="country_latest", entity_links=(EntityReference(element_name="listing"),)), + ), + where_constraint=( + WhereSpecFactory( + column_association_resolver=column_association_resolver, + ).create_from_where_filter( + PydanticWhereFilter( + where_sql_template="{{ Dimension('listing__country_latest') }} = 'us'", + ) + ) + ), + order_by_specs=( + OrderBySpec( + instance_spec=DimensionSpec( + element_name="country_latest", entity_links=(EntityReference(element_name="listing"),) + ), + descending=True, + ), + ), + limit=100, + ) ) convert_and_check( diff --git a/metricflow/test/query/test_query_parser.py b/metricflow/test/query/test_query_parser.py index 19ece527b4..4c2bc78680 100644 --- a/metricflow/test/query/test_query_parser.py +++ b/metricflow/test/query/test_query_parser.py @@ -249,6 +249,9 @@ def test_query_parser_case_insensitivity(bookings_query_parser: MetricFlowQueryP ), ) + with pytest.raises(UnableToSatisfyQueryError): + bookings_query_parser.parse_and_validate_query(group_by_names=["random_stuff"]) + def test_query_parser_with_object_params(bookings_query_parser: MetricFlowQueryParser) -> None: # noqa: D Metric = namedtuple("Metric", ["name", "descending"]) diff --git a/metricflow/test/time/test_time_granularity_solver.py b/metricflow/test/time/test_time_granularity_solver.py index 025778df10..0fcdf0ce28 100644 --- a/metricflow/test/time/test_time_granularity_solver.py +++ b/metricflow/test/time/test_time_granularity_solver.py @@ -6,9 +6,11 @@ from dbt_semantic_interfaces.references import MetricReference from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity +from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver from metricflow.dataset.dataset import DataSet from metricflow.filters.time_constraint import TimeRangeConstraint from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup +from metricflow.test.fixtures.model_fixtures import 
ConsistentIdObjectRepository from metricflow.test.time.metric_time_dimension import MTD_SPEC_DAY, MTD_SPEC_MONTH from metricflow.time.time_granularity_solver import ( PartialTimeDimensionSpec, @@ -89,19 +91,31 @@ def test_validate_day_granularity_for_day_and_month_metric( # noqa: D PARTIAL_PTD_SPEC = PartialTimeDimensionSpec(element_name=DataSet.metric_time_dimension_name(), entity_links=()) -def test_granularity_solution_for_day_metric(time_granularity_solver: TimeGranularitySolver) -> None: # noqa: D +def test_granularity_solution_for_day_metric( # noqa: D + time_granularity_solver: TimeGranularitySolver, + node_output_resolver: DataflowPlanNodeOutputDataSetResolver, + consistent_id_object_repository: ConsistentIdObjectRepository, +) -> None: assert time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs( metric_references=[MetricReference(element_name="bookings")], partial_time_dimension_specs=[PARTIAL_PTD_SPEC], + node_output_resolver=node_output_resolver, + read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()), ) == { PARTIAL_PTD_SPEC: MTD_SPEC_DAY, } -def test_granularity_solution_for_month_metric(time_granularity_solver: TimeGranularitySolver) -> None: # noqa: D +def test_granularity_solution_for_month_metric( # noqa: D + time_granularity_solver: TimeGranularitySolver, + node_output_resolver: DataflowPlanNodeOutputDataSetResolver, + consistent_id_object_repository: ConsistentIdObjectRepository, +) -> None: assert time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs( metric_references=[MetricReference(element_name="bookings_monthly")], partial_time_dimension_specs=[PARTIAL_PTD_SPEC], + node_output_resolver=node_output_resolver, + read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()), ) == { PARTIAL_PTD_SPEC: MTD_SPEC_MONTH, } @@ -109,10 +123,14 @@ def test_granularity_solution_for_month_metric(time_granularity_solver: TimeGran def 
test_granularity_solution_for_day_and_month_metrics( # noqa: D time_granularity_solver: TimeGranularitySolver, + node_output_resolver: DataflowPlanNodeOutputDataSetResolver, + consistent_id_object_repository: ConsistentIdObjectRepository, ) -> None: assert time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs( metric_references=[MetricReference(element_name="bookings"), MetricReference(element_name="bookings_monthly")], partial_time_dimension_specs=[PARTIAL_PTD_SPEC], + node_output_resolver=node_output_resolver, + read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()), ) == {PARTIAL_PTD_SPEC: MTD_SPEC_MONTH}