diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index cd327578c..3e5b648e1 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -3369,24 +3369,21 @@ def test_multi_engine_python_model_with_macros(adapters, make_snapshot): evaluator = SnapshotEvaluator(engine_adapters) @macro() - def create_index( + def validate_engine_call( evaluator: MacroEvaluator, - index_name: str, - model_name: str, - column: str, ): if evaluator.runtime_stage == "creating": - # To validate the model-specified gateway is used for the evaluator + # To validate the model-specified gateway is used for the macro evaluator evaluator.engine_adapter.get_catalog_type() - return f"CREATE INDEX IF NOT EXISTS {index_name} ON {model_name}({column});" + return None @model( "db.test_model", kind="full", gateway="secondary", columns={"id": "string", "name": "string"}, - pre_statements=["@CREATE_INDEX(idx, db.test_model, id)"], - post_statements=["@CREATE_INDEX(idx, db.test_model, id)"], + pre_statements=["@VALIDATE_ENGINE_CALL()"], + post_statements=["@VALIDATE_ENGINE_CALL()"], ) def model_with_statements(context, **kwargs): return pd.DataFrame( @@ -3406,7 +3403,7 @@ def model_with_statements(context, **kwargs): ) assert len(python_model.python_env) == 3 - assert isinstance(python_model.python_env["create_index"], Executable) + assert isinstance(python_model.python_env["validate_engine_call"], Executable) snapshot = make_snapshot(python_model) assert snapshot.model_gateway == "secondary" @@ -3429,14 +3426,8 @@ def model_with_statements(context, **kwargs): assert view_args[0][0][0] == "db__test_env.test_model" # For the pre/post statements verify the model-specific gateway was used - expected_call = f'CREATE INDEX IF NOT EXISTS "idx" ON "sqlmesh__db"."db__test_model__{snapshot.version}" /* db.test_model */("id")' engine_adapters["default"].execute.assert_not_called() - call_args = engine_adapters["secondary"].execute.call_args_list - pre_calls = call_args[0][0][0] - assert pre_calls[0].sql(dialect="postgres") == expected_call - - post_calls = call_args[1][0][0] - assert post_calls[0].sql(dialect="postgres") == expected_call + assert len(engine_adapters["secondary"].execute.call_args_list) == 2 # Validate that the get_catalog_type method was called only on the secondary engine from the macro evaluator engine_adapters["default"].get_catalog_type.assert_not_called()