diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index c8ea1cc50..a1f6c13b7 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -418,12 +418,10 @@ def engine_adapter(self) -> EngineAdapter: @property def snapshot_evaluator(self) -> SnapshotEvaluator: if not self._snapshot_evaluator: - if snapshot_gateways := self._snapshot_gateways: - self._create_engine_adapters(set(snapshot_gateways.values())) self._snapshot_evaluator = SnapshotEvaluator( { gateway: adapter.with_log_level(logging.INFO) - for gateway, adapter in self._engine_adapters.items() + for gateway, adapter in self.engine_adapters.items() }, ddl_concurrent_tasks=self.concurrent_tasks, selected_gateway=self.selected_gateway, @@ -2025,21 +2023,13 @@ def _snapshot_gateways(self) -> t.Dict[str, str]: @cached_property def engine_adapters(self) -> t.Dict[str, EngineAdapter]: - """Returns all engine adapters for the gateways defined in the configs.""" - self._create_engine_adapters() - return self._engine_adapters - - def _create_engine_adapters(self, gateways: t.Optional[t.Set] = None) -> None: - """Create engine adapters for the gateways, when none provided include all defined in the configs.""" - + """Returns all the engine adapters for the gateways defined in the configuration.""" for gateway_name in self.config.gateways: - if gateway_name != self.selected_gateway and ( - gateways is None or gateway_name in gateways - ): + if gateway_name != self.selected_gateway: connection = self.config.get_connection(gateway_name) adapter = connection.create_engine_adapter() - self.concurrent_tasks = min(self.concurrent_tasks, connection.concurrent_tasks) self._engine_adapters[gateway_name] = adapter + return self._engine_adapters def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: if gateway: diff --git a/tests/core/engine_adapter/integration/__init__.py b/tests/core/engine_adapter/integration/__init__.py index b1826de25..80b926783 100644 --- a/tests/core/engine_adapter/integration/__init__.py +++ b/tests/core/engine_adapter/integration/__init__.py @@ -470,6 +470,7 @@ def create_context( ], ) if config_mutator: + config.gateways = {self.gateway: config.gateways[self.gateway]} config_mutator(self.gateway, config) gateway_config = config.gateways[self.gateway] diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 1308e4a79..3f5dcdbaf 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -1328,6 +1328,8 @@ def test_sushi(ctx: TestContext, tmp_path_factory: pytest.TempPathFactory): personal_paths=[pathlib.Path("~/.sqlmesh/config.yaml").expanduser()], ) + # To enable parallelism in integration tests + config.gateways = {ctx.gateway: config.gateways[ctx.gateway]} current_gateway_config = config.gateways[ctx.gateway] current_gateway_config.state_schema = sushi_state_schema @@ -1731,6 +1733,8 @@ def _normalize_snowflake(name: str, prefix_regex: str = "(sqlmesh__)(.*)"): if config.model_defaults.dialect != ctx.dialect: config.model_defaults = config.model_defaults.copy(update={"dialect": ctx.dialect}) + # To enable parallelism in integration tests + config.gateways = {ctx.gateway: config.gateways[ctx.gateway]} current_gateway_config = config.gateways[ctx.gateway] if ctx.dialect == "athena": diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 7d10a7f37..42b234727 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -757,10 +757,8 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture): new_callable=mocker.PropertyMock(return_value={"snapshot": "athena"}), ) - ctx._create_engine_adapters() - assert isinstance(ctx._connection_config, RedshiftConnectionConfig) - assert len(ctx._engine_adapters) == 2 - assert isinstance(ctx._engine_adapters["athena"], AthenaEngineAdapter) - assert isinstance(ctx._engine_adapters["redshift"], RedshiftEngineAdapter) + assert len(ctx.engine_adapters) == 2 + assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter) + assert isinstance(ctx.engine_adapters["redshift"], RedshiftEngineAdapter) assert ctx.engine_adapter == ctx._get_engine_adapter("redshift") diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 9c161ae3f..d32f3026c 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -341,8 +341,7 @@ def test_gateway_specific_adapters(copy_to_temp_path, mocker): ctx = Context(paths=path, config="isolated_systems_config") - ctx._create_engine_adapters({"test"}) - assert len(ctx._engine_adapters) == 2 + assert len(ctx.engine_adapters) == 3 assert ctx.engine_adapter == ctx._get_engine_adapter() assert ctx._get_engine_adapter("test") == ctx._engine_adapters["test"] diff --git a/tests/core/test_model.py b/tests/core/test_model.py index bad0d80eb..7ac67cb35 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -6635,7 +6635,6 @@ def test_gateway_specific_render(assert_exp_eq) -> None: ) context = Context(config=config) assert context.engine_adapter == context._engine_adapters["main"] - assert len(context._engine_adapters) == 1 @model( name="dummy_model", @@ -6652,8 +6651,6 @@ def dummy_model_entry(evaluator: MacroEvaluator) -> exp.Select: assert isinstance(dummy_model, SqlModel) assert dummy_model.gateway == "duckdb" - # Calling render with a model with a non-default gateway should create - # the engine adapters and render with the model specified gateway assert_exp_eq( context.render("dummy_model"), """ diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 956fde41f..bf56686af 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -2021,7 +2021,6 @@ def test_test_with_gateway_specific_model(tmp_path: Path, mocker: MockerFixture) ) assert context.engine_adapter == context._engine_adapters["main"] - assert len(context._engine_adapters) == 1 with pytest.raises( SQLMeshError, match=r"Gateway 'wrong' not found in the available engine adapters." ):