diff --git a/metricflow/test/generate_snapshots.py b/metricflow/test/generate_snapshots.py index 82c5d294cf..642eebd890 100644 --- a/metricflow/test/generate_snapshots.py +++ b/metricflow/test/generate_snapshots.py @@ -41,13 +41,16 @@ from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.implementations.base import FrozenBaseModel -from dbt_semantic_interfaces.pretty_print import pformat_big_objects from metricflow.protocols.sql_client import SqlEngine +from metricflow.test.fixtures.setup_fixtures import SQL_ENGINE_SNAPSHOT_MARKER_NAME logger = logging.getLogger(__name__) +TEST_DIRECTORY = "metricflow/test" + + class MetricFlowTestCredentialSet(FrozenBaseModel): # noqa: D engine_url: Optional[str] engine_password: Optional[str] @@ -97,32 +100,6 @@ def as_configurations(self) -> Sequence[MetricFlowTestConfiguration]: # noqa: D ) -SNAPSHOT_GENERATING_TESTS = ( - "metricflow/test/cli/test_cli.py::test_saved_query", - "metricflow/test/cli/test_cli.py::test_saved_query_with_where", - "metricflow/test/cli/test_cli.py::test_saved_query_with_limit", - "metricflow/test/cli/test_cli.py::test_saved_query_explain", - "metricflow/test/dataflow/builder/test_dataflow_plan_builder.py", - "metricflow/test/dataflow/optimizer/source_scan/test_cm_branch_combiner.py", - "metricflow/test/dataflow/optimizer/source_scan/test_source_scan_optimizer.py", - "metricflow/test/dataset/test_convert_semantic_model.py", - "metricflow/test/integration/test_rendered_query.py", - "metricflow/test/integration/test_rendered_query.py", - "metricflow/test/model/test_data_warehouse_tasks.py", - "metricflow/test/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py", - "metricflow/test/plan_conversion/test_dataflow_to_execution.py", - "metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py", - "metricflow/test/sql/optimizer/test_column_pruner.py", - "metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py", - "metricflow/test/sql/optimizer/test_sub_query_reducer.py", - "metricflow/test/sql/optimizer/test_sub_query_reducer.py", - "metricflow/test/sql/optimizer/test_table_alias_simplifier.py", - "metricflow/test/sql/test_engine_specific_rendering.py", - "metricflow/test/sql/test_sql_plan_render.py", - "metricflow/test/sql/test_sql_plan_render.py", -) - - def run_command(command: str) -> None: # noqa: D logger.info(f"Running command {command}") return_code = os.system(command) @@ -130,8 +107,7 @@ def run_command(command: str) -> None: # noqa: D raise RuntimeError(f"Error running command: {command}") -def run_tests(test_configuration: MetricFlowTestConfiguration, test_file_paths: Sequence[str]) -> None: # noqa: D - combined_paths = " ".join(test_file_paths) +def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa: D if test_configuration.credential_set.engine_url is None: if "MF_SQL_ENGINE_URL" in os.environ: del os.environ["MF_SQL_ENGINE_URL"] @@ -146,7 +122,7 @@ def run_tests(test_configuration: MetricFlowTestConfiguration, test_file_paths: if test_configuration.engine is SqlEngine.DUCKDB: # Can't use --use-persistent-source-schema with duckdb since it's in memory. - run_command(f"pytest -x -vv -n 4 --overwrite-snapshots {combined_paths}") + run_command(f"pytest -x -vv -n 4 --overwrite-snapshots -m '{SQL_ENGINE_SNAPSHOT_MARKER_NAME}' {TEST_DIRECTORY}") elif ( test_configuration.engine is SqlEngine.REDSHIFT or test_configuration.engine is SqlEngine.SNOWFLAKE @@ -162,7 +138,8 @@ def run_tests(test_configuration: MetricFlowTestConfiguration, test_file_paths: f"hatch -v run {hatch_env}:pytest -x -vv -n 4 " f"--overwrite-snapshots" f"{' --use-persistent-source-schema' if use_persistent_source_schema else ''}" - f" {combined_paths}" + f"-m '{SQL_ENGINE_SNAPSHOT_MARKER_NAME}' " + f"{TEST_DIRECTORY}" ) else: assert_values_exhausted(test_configuration.engine) @@ -182,13 +159,13 @@ def run_cli() -> None: # noqa: D credential_sets = MetricFlowTestCredentialSetForAllEngines.parse_raw(credential_sets_json_str) - logger.info(f"Running the following tests to generate snapshots:\n{pformat_big_objects(SNAPSHOT_GENERATING_TESTS)}") + logger.info(f"Running tests in '{TEST_DIRECTORY}' with the marker '{SQL_ENGINE_SNAPSHOT_MARKER_NAME}'") for test_configuration in credential_sets.as_configurations: logger.info( f"Running tests for {test_configuration.engine} with URL: {test_configuration.credential_set.engine_url}" ) - run_tests(test_configuration, SNAPSHOT_GENERATING_TESTS) + run_tests(test_configuration) if __name__ == "__main__":