Skip to content

Commit

Permalink
Write & update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Oct 11, 2023
1 parent 4f0d5d2 commit c9526e5
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 19 deletions.
1 change: 1 addition & 0 deletions metricflow/test/dataflow/builder/test_cyclic_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def cyclic_join_manifest_dataflow_plan_builder( # noqa: D

return DataflowPlanBuilder(
source_nodes=consistent_id_object_repository.cyclic_join_source_nodes,
read_nodes=list(consistent_id_object_repository.cyclic_join_read_nodes.values()),
semantic_manifest_lookup=cyclic_join_semantic_manifest_lookup,
cost_function=DefaultCostFunction(),
)
Expand Down
78 changes: 72 additions & 6 deletions metricflow/test/dataflow/builder/test_dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,15 +562,81 @@ def test_distinct_values_plan( # noqa: D
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
dataflow_plan_builder: DataflowPlanBuilder,
column_association_resolver: ColumnAssociationResolver,
) -> None:
"""Tests a plan to get distinct values of a dimension."""
dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values(
metric_specs=(MetricSpec(element_name="bookings"),),
dimension_spec=DimensionSpec(
element_name="country_latest",
entity_links=(EntityReference(element_name="listing"),),
),
limit=100,
query_spec=MetricFlowQuerySpec(
dimension_specs=(
DimensionSpec(element_name="country_latest", entity_links=(EntityReference(element_name="listing"),)),
),
where_constraint=(
WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(
PydanticWhereFilter(
where_sql_template="{{ Dimension('listing__country_latest') }} = 'us'",
)
)
),
order_by_specs=(
OrderBySpec(
instance_spec=DimensionSpec(
element_name="country_latest", entity_links=(EntityReference(element_name="listing"),)
),
descending=True,
),
),
limit=100,
)
)

assert_plan_snapshot_text_equal(
request=request,
mf_test_session_state=mf_test_session_state,
plan=dataflow_plan,
plan_snapshot_text=dataflow_plan_as_text(dataflow_plan),
)

display_graph_if_requested(
request=request,
mf_test_session_state=mf_test_session_state,
dag_graph=dataflow_plan,
)


def test_distinct_values_plan_with_join( # noqa: D
request: FixtureRequest,
mf_test_session_state: MetricFlowTestSessionState,
dataflow_plan_builder: DataflowPlanBuilder,
column_association_resolver: ColumnAssociationResolver,
) -> None:
"""Tests a plan to get distinct values of 2 dimensions, where a join is required."""
dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values(
query_spec=MetricFlowQuerySpec(
dimension_specs=(
DimensionSpec(element_name="home_state_latest", entity_links=(EntityReference(element_name="user"),)),
DimensionSpec(element_name="is_lux_latest", entity_links=(EntityReference(element_name="listing"),)),
),
where_constraint=(
WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(
PydanticWhereFilter(
where_sql_template="{{ Dimension('listing__country_latest') }} = 'us'",
)
)
),
order_by_specs=(
OrderBySpec(
instance_spec=DimensionSpec(
element_name="country_latest", entity_links=(EntityReference(element_name="listing"),)
),
descending=True,
),
),
limit=100,
)
)

assert_plan_snapshot_text_equal(
Expand Down
4 changes: 2 additions & 2 deletions metricflow/test/dataflow/builder/test_node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from metricflow.dataset.dataset import DataSet
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver
from metricflow.plan_conversion.node_processor import PreDimensionJoinNodeProcessor
from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor
from metricflow.specs.specs import (
DimensionSpec,
EntityReference,
Expand Down Expand Up @@ -65,7 +65,7 @@ def make_multihop_node_evaluator(
semantic_manifest_lookup=semantic_manifest_lookup_with_multihop_links,
)

node_processor = PreDimensionJoinNodeProcessor(
node_processor = PreJoinNodeProcessor(
semantic_model_lookup=semantic_manifest_lookup_with_multihop_links.semantic_model_lookup,
node_data_set_resolver=node_data_set_resolver,
)
Expand Down
3 changes: 3 additions & 0 deletions metricflow/test/fixtures/dataflow_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def dataflow_plan_builder( # noqa: D
) -> DataflowPlanBuilder:
return DataflowPlanBuilder(
source_nodes=consistent_id_object_repository.simple_model_source_nodes,
read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
cost_function=DefaultCostFunction(),
)
Expand All @@ -47,6 +48,7 @@ def multihop_dataflow_plan_builder( # noqa: D
) -> DataflowPlanBuilder:
return DataflowPlanBuilder(
source_nodes=consistent_id_object_repository.multihop_model_source_nodes,
read_nodes=list(consistent_id_object_repository.multihop_model_read_nodes.values()),
semantic_manifest_lookup=multi_hop_join_semantic_manifest_lookup,
cost_function=DefaultCostFunction(),
)
Expand All @@ -68,6 +70,7 @@ def scd_dataflow_plan_builder( # noqa: D
) -> DataflowPlanBuilder:
return DataflowPlanBuilder(
source_nodes=consistent_id_object_repository.scd_model_source_nodes,
read_nodes=list(consistent_id_object_repository.scd_model_read_nodes.values()),
semantic_manifest_lookup=scd_semantic_manifest_lookup,
cost_function=DefaultCostFunction(),
column_association_resolver=scd_column_association_resolver,
Expand Down
15 changes: 13 additions & 2 deletions metricflow/test/fixtures/model_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,10 @@ def query_parser_from_yaml(yaml_contents: List[YamlConfigFile]) -> MetricFlowQue
).semantic_manifest
)
SemanticManifestValidator[SemanticManifest]().checked_validations(semantic_manifest_lookup.semantic_manifest)
source_nodes = _data_set_to_source_nodes(semantic_manifest_lookup, create_data_sets(semantic_manifest_lookup))
return MetricFlowQueryParser(
model=semantic_manifest_lookup,
column_association_resolver=DunderColumnAssociationResolver(semantic_manifest_lookup),
source_nodes=source_nodes,
read_nodes=list(_data_set_to_read_nodes(create_data_sets(semantic_manifest_lookup)).values()),
node_output_resolver=DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=DunderColumnAssociationResolver(semantic_manifest_lookup),
semantic_manifest_lookup=semantic_manifest_lookup,
Expand All @@ -88,6 +87,7 @@ class ConsistentIdObjectRepository:
scd_model_read_nodes: OrderedDict[str, ReadSqlSourceNode]
scd_model_source_nodes: Sequence[BaseOutput]

cyclic_join_read_nodes: OrderedDict[str, ReadSqlSourceNode]
cyclic_join_source_nodes: Sequence[BaseOutput]


Expand Down Expand Up @@ -122,6 +122,7 @@ def consistent_id_object_repository(
scd_model_source_nodes=_data_set_to_source_nodes(
semantic_manifest_lookup=scd_semantic_manifest_lookup, data_sets=scd_data_sets
),
cyclic_join_read_nodes=_data_set_to_read_nodes(cyclic_join_data_sets),
cyclic_join_source_nodes=_data_set_to_source_nodes(
semantic_manifest_lookup=cyclic_join_semantic_manifest_lookup, data_sets=cyclic_join_data_sets
),
Expand Down Expand Up @@ -239,3 +240,13 @@ def cyclic_join_semantic_manifest_lookup(template_mapping: Dict[str, str]) -> Se
"""Manifest that contains a potential cycle in the join graph (if not handled properly)."""
build_result = load_semantic_manifest("cyclic_join_manifest", template_mapping)
return SemanticManifestLookup(build_result.semantic_manifest)


@pytest.fixture(scope="session")
def node_output_resolver(  # noqa:D
    simple_semantic_manifest_lookup: SemanticManifestLookup,
) -> DataflowPlanNodeOutputDataSetResolver:
    """Session-scoped output data set resolver built from the simple semantic manifest lookup."""
    # Both collaborators are derived from the same manifest lookup fixture.
    dunder_resolver = DunderColumnAssociationResolver(simple_semantic_manifest_lookup)
    return DataflowPlanNodeOutputDataSetResolver(
        column_association_resolver=dunder_resolver,
        semantic_manifest_lookup=simple_semantic_manifest_lookup,
    )
2 changes: 1 addition & 1 deletion metricflow/test/integration/configured_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ class Config: # noqa: D
name: str
# Name of the semantic model to use.
model: IntegrationTestModel
metrics: Tuple[str, ...]
# The SQL query that can be run to obtain the expected results.
check_query: str
file_path: str
metrics: Tuple[str, ...] = ()
group_bys: Tuple[str, ...] = ()
group_by_objs: Tuple[Dict, ...] = ()
order_bys: Tuple[str, ...] = ()
Expand Down
80 changes: 80 additions & 0 deletions metricflow/test/integration/test_cases/itest_dimensions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,83 @@ integration_test:
GROUP BY
v.ds
, u.home_state
---
integration_test:
name: query_dimension_only
description: Query dimension only
model: SIMPLE_MODEL
group_bys: ["user__home_state"]
check_query: |
SELECT
u.home_state AS user__home_state
FROM {{ source_schema }}.dim_users u
GROUP BY
u.home_state
---
integration_test:
name: query_dimensions_only
description: Query multiple dimensions without metrics
model: SIMPLE_MODEL
group_bys: ["ds__day", "user__home_state"]
check_query: |
SELECT
u.home_state AS user__home_state
, u.ds AS ds__day
FROM {{ source_schema }}.dim_users u
GROUP BY
u.ds
, u.home_state
---
integration_test:
name: query_dimensions_from_different_tables
description: Query multiple dimensions without metrics, requiring a join
model: SIMPLE_MODEL
group_bys: ["user__home_state_latest", "listing__is_lux_latest"]
check_query: |
SELECT
u.home_state_latest AS user__home_state_latest
, l.is_lux AS listing__is_lux_latest
FROM {{ source_schema }}.dim_listings_latest l
LEFT OUTER JOIN {{ source_schema }}.dim_users_latest u
ON u.user_id = l.user_id
GROUP BY
u.home_state_latest
, l.is_lux
---
integration_test:
name: query_time_dimension_without_granularity
description: Query just a time dimension, no granularity specified. Should assume default granularity for dimension.
model: SIMPLE_MODEL
group_bys: [ "verification__ds"]
check_query: |
SELECT
v.ds as verification__ds__day
FROM {{ source_schema }}.fct_id_verifications v
GROUP BY
v.ds
---
integration_test:
name: query_non_default_time_dimension_without_granularity
description: Query just a time dimension defined at a non-day (monthly) grain, no granularity specified. Should assume default granularity for dimension.
model: EXTENDED_DATE_MODEL
group_bys: [ "monthly_ds"]
check_query: |
SELECT
ds AS monthly_ds__month
FROM {{ source_schema }}.fct_bookings_extended_monthly
GROUP BY
ds
---
integration_test:
name: query_dimension_only_with_constraint
description: Query dimension only, with a where constraint
model: SIMPLE_MODEL
group_bys: ["user__home_state"]
where_filter: "{{ render_dimension_template('user__home_state') }} = 'CA'"
check_query: |
SELECT
u.home_state AS user__home_state
FROM {{ source_schema }}.dim_users u
WHERE user__home_state = 'CA'
GROUP BY
u.home_state
30 changes: 24 additions & 6 deletions metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,16 +1249,34 @@ def test_distinct_values( # noqa: D
mf_test_session_state: MetricFlowTestSessionState,
dataflow_plan_builder: DataflowPlanBuilder,
dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
column_association_resolver: ColumnAssociationResolver,
sql_client: SqlClient,
) -> None:
"""Tests a plan to get distinct values for a dimension."""
dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values(
metric_specs=(MetricSpec(element_name="bookings"),),
dimension_spec=DimensionSpec(
element_name="country_latest",
entity_links=(EntityReference(element_name="listing"),),
),
limit=100,
query_spec=MetricFlowQuerySpec(
dimension_specs=(
DimensionSpec(element_name="country_latest", entity_links=(EntityReference(element_name="listing"),)),
),
where_constraint=(
WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(
PydanticWhereFilter(
where_sql_template="{{ Dimension('listing__country_latest') }} = 'us'",
)
)
),
order_by_specs=(
OrderBySpec(
instance_spec=DimensionSpec(
element_name="country_latest", entity_links=(EntityReference(element_name="listing"),)
),
descending=True,
),
),
limit=100,
)
)

convert_and_check(
Expand Down
3 changes: 3 additions & 0 deletions metricflow/test/query/test_query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def test_query_parser_case_insensitivity(bookings_query_parser: MetricFlowQueryP
),
)

with pytest.raises(UnableToSatisfyQueryError):
bookings_query_parser.parse_and_validate_query(group_by_names=["random_stuff"])


def test_query_parser_with_object_params(bookings_query_parser: MetricFlowQueryParser) -> None: # noqa: D
Metric = namedtuple("Metric", ["name", "descending"])
Expand Down
22 changes: 20 additions & 2 deletions metricflow/test/time/test_time_granularity_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from dbt_semantic_interfaces.references import MetricReference
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataset.dataset import DataSet
from metricflow.filters.time_constraint import TimeRangeConstraint
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.test.fixtures.model_fixtures import ConsistentIdObjectRepository
from metricflow.test.time.metric_time_dimension import MTD_SPEC_DAY, MTD_SPEC_MONTH
from metricflow.time.time_granularity_solver import (
PartialTimeDimensionSpec,
Expand Down Expand Up @@ -89,30 +91,46 @@ def test_validate_day_granularity_for_day_and_month_metric( # noqa: D
PARTIAL_PTD_SPEC = PartialTimeDimensionSpec(element_name=DataSet.metric_time_dimension_name(), entity_links=())


def test_granularity_solution_for_day_metric(time_granularity_solver: TimeGranularitySolver) -> None: # noqa: D
def test_granularity_solution_for_day_metric( # noqa: D
time_granularity_solver: TimeGranularitySolver,
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
consistent_id_object_repository: ConsistentIdObjectRepository,
) -> None:
assert time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs(
metric_references=[MetricReference(element_name="bookings")],
partial_time_dimension_specs=[PARTIAL_PTD_SPEC],
node_output_resolver=node_output_resolver,
read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()),
) == {
PARTIAL_PTD_SPEC: MTD_SPEC_DAY,
}


def test_granularity_solution_for_month_metric(time_granularity_solver: TimeGranularitySolver) -> None: # noqa: D
def test_granularity_solution_for_month_metric( # noqa: D
time_granularity_solver: TimeGranularitySolver,
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
consistent_id_object_repository: ConsistentIdObjectRepository,
) -> None:
assert time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs(
metric_references=[MetricReference(element_name="bookings_monthly")],
partial_time_dimension_specs=[PARTIAL_PTD_SPEC],
node_output_resolver=node_output_resolver,
read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()),
) == {
PARTIAL_PTD_SPEC: MTD_SPEC_MONTH,
}


def test_granularity_solution_for_day_and_month_metrics(  # noqa: D
    time_granularity_solver: TimeGranularitySolver,
    node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
    consistent_id_object_repository: ConsistentIdObjectRepository,
) -> None:
    """Querying `bookings` together with `bookings_monthly` resolves the partial metric-time spec to month."""
    requested_metrics = [
        MetricReference(element_name="bookings"),
        MetricReference(element_name="bookings_monthly"),
    ]
    partial_specs = [PARTIAL_PTD_SPEC]
    simple_read_nodes = list(consistent_id_object_repository.simple_model_read_nodes.values())

    resolved_specs = time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs(
        metric_references=requested_metrics,
        partial_time_dimension_specs=partial_specs,
        node_output_resolver=node_output_resolver,
        read_nodes=simple_read_nodes,
    )

    expected_specs = {PARTIAL_PTD_SPEC: MTD_SPEC_MONTH}
    assert resolved_specs == expected_specs


Expand Down

0 comments on commit c9526e5

Please sign in to comment.