Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Add adapters for commands that are not using the snapshot evaluator #3531

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,12 +418,10 @@ def engine_adapter(self) -> EngineAdapter:
@property
def snapshot_evaluator(self) -> SnapshotEvaluator:
if not self._snapshot_evaluator:
if self._snapshot_gateways:
self._create_engine_adapters(set(self._snapshot_gateways.values()))
self._snapshot_evaluator = SnapshotEvaluator(
{
gateway: adapter.with_log_level(logging.INFO)
for gateway, adapter in self._engine_adapters.items()
for gateway, adapter in self.engine_adapters.items()
},
ddl_concurrent_tasks=self.concurrent_tasks,
selected_gateway=self.selected_gateway,
Expand Down Expand Up @@ -1476,6 +1474,7 @@ def table_diff(
source_alias, target_alias = source, target

adapter = self.engine_adapter

if model_or_snapshot:
model = self.get_model(model_or_snapshot, raise_if_missing=True)
adapter = self._get_engine_adapter(model.gateway)
Expand Down Expand Up @@ -1641,6 +1640,7 @@ def create_test(
test_adapter = self._test_connection_config.create_engine_adapter(
register_comments_override=False
)

generate_test(
model=model_to_test,
input_queries=input_queries,
Expand Down Expand Up @@ -2021,21 +2021,19 @@ def _snapshot_gateways(self) -> t.Dict[str, str]:
if snapshot.is_model and snapshot.model.gateway
}

def _create_engine_adapters(self, gateways: t.Optional[t.Set] = None) -> None:
"""Create engine adapters for the gateways, when none provided include all defined in the configs."""

@cached_property
def engine_adapters(self) -> t.Dict[str, EngineAdapter]:
"""Returns all the engine adapters for the gateways defined in the configuration."""
for gateway_name in self.config.gateways:
if gateway_name != self.selected_gateway and (
gateways is None or gateway_name in gateways
):
if gateway_name != self.selected_gateway:
connection = self.config.get_connection(gateway_name)
adapter = connection.create_engine_adapter()
self.concurrent_tasks = min(self.concurrent_tasks, connection.concurrent_tasks)
self._engine_adapters[gateway_name] = adapter
return self._engine_adapters

def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
if gateway:
if adapter := self._engine_adapters.get(gateway):
if adapter := self.engine_adapters.get(gateway):
return adapter
raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.")
return self.engine_adapter
Expand Down
1 change: 1 addition & 0 deletions tests/core/engine_adapter/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def create_context(
],
)
if config_mutator:
config.gateways = {self.gateway: config.gateways[self.gateway]}
config_mutator(self.gateway, config)

gateway_config = config.gateways[self.gateway]
Expand Down
4 changes: 4 additions & 0 deletions tests/core/engine_adapter/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,8 @@ def test_sushi(ctx: TestContext, tmp_path_factory: pytest.TempPathFactory):
personal_paths=[pathlib.Path("~/.sqlmesh/config.yaml").expanduser()],
)

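# Keep only the gateway under test: all engine adapters are created upfront, so parallel tests sharing this config would otherwise share connections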
config.gateways = {ctx.gateway: config.gateways[ctx.gateway]}
current_gateway_config = config.gateways[ctx.gateway]
current_gateway_config.state_schema = sushi_state_schema

Expand Down Expand Up @@ -1731,6 +1733,8 @@ def _normalize_snowflake(name: str, prefix_regex: str = "(sqlmesh__)(.*)"):
if config.model_defaults.dialect != ctx.dialect:
config.model_defaults = config.model_defaults.copy(update={"dialect": ctx.dialect})

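# Keep only the gateway under test: all engine adapters are created upfront, so parallel tests sharing this config would otherwise share connections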
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this impact parallelism?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To isolate the connections per test. All the engine adapters are now created upfront, and the tests in this file all build their context from the same config (https://github.com/TobikoData/sqlmesh/blob/main/tests/core/engine_adapter/integration/config.yaml), so without this they would end up using the same connections and run into concurrency issues.

config.gateways = {ctx.gateway: config.gateways[ctx.gateway]}
current_gateway_config = config.gateways[ctx.gateway]

if ctx.dialect == "athena":
Expand Down
8 changes: 3 additions & 5 deletions tests/core/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,10 +757,8 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture):
new_callable=mocker.PropertyMock(return_value={"snapshot": "athena"}),
)

ctx._create_engine_adapters()

assert isinstance(ctx._connection_config, RedshiftConnectionConfig)
assert len(ctx._engine_adapters) == 2
assert isinstance(ctx._engine_adapters["athena"], AthenaEngineAdapter)
assert isinstance(ctx._engine_adapters["redshift"], RedshiftEngineAdapter)
assert len(ctx.engine_adapters) == 2
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)
assert isinstance(ctx.engine_adapters["redshift"], RedshiftEngineAdapter)
assert ctx.engine_adapter == ctx._get_engine_adapter("redshift")
9 changes: 6 additions & 3 deletions tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,12 @@ def test_gateway_specific_adapters(copy_to_temp_path, mocker):
ctx = Context(paths=path, config="isolated_systems_config", gateway="prod")
assert len(ctx._engine_adapters) == 1
assert ctx.engine_adapter == ctx._engine_adapters["prod"]

with pytest.raises(SQLMeshError):
assert ctx._get_engine_adapter("dev")
assert ctx._get_engine_adapter("non_existing")

# This will create the requested engine adapter
assert ctx._get_engine_adapter("dev") == ctx._engine_adapters["dev"]

ctx = Context(paths=path, config="isolated_systems_config")
assert len(ctx._engine_adapters) == 1
Expand All @@ -337,8 +341,7 @@ def test_gateway_specific_adapters(copy_to_temp_path, mocker):

ctx = Context(paths=path, config="isolated_systems_config")

ctx._create_engine_adapters({"test"})
assert len(ctx._engine_adapters) == 2
assert len(ctx.engine_adapters) == 3
assert ctx.engine_adapter == ctx._get_engine_adapter()
assert ctx._get_engine_adapter("test") == ctx._engine_adapters["test"]

Expand Down
41 changes: 41 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sqlmesh.core.context import Context, ExecutionContext
from sqlmesh.core.dialect import parse
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter
from sqlmesh.core.macros import MacroEvaluator, macro
from sqlmesh.core.model import (
CustomKind,
Expand Down Expand Up @@ -6620,3 +6621,43 @@ def test_auto_restatement():
)
with pytest.raises(ValueError, match="Invalid cron expression '@invalid'.*"):
load_sql_based_model(parsed_definition)


def test_gateway_specific_render(assert_exp_eq) -> None:
    """Render a model pinned to a non-default gateway in a two-gateway context.

    Checks that the context's default adapter is the one registered for
    "main", that the rendered query matches the expected SQL, and that the
    adapter for the model's gateway ("duckdb") can be fetched afterwards.
    """
    config = Config(
        gateways={
            "main": GatewayConfig(connection=DuckDBConnectionConfig()),
            "duckdb": GatewayConfig(connection=DuckDBConnectionConfig()),
        },
        model_defaults=ModelDefaultsConfig(dialect="duckdb"),
        default_gateway="main",
    )
    context = Context(config=config)
    assert context.engine_adapter == context._engine_adapters["main"]

    @model(
        name="dummy_model",
        is_sql=True,
        kind="full",
        gateway="duckdb",
        grain='"x"',
    )
    def dummy_model_entry(evaluator: MacroEvaluator) -> exp.Select:
        values = exp.values([("1", 2)], "_v", ["x"])
        return exp.select("x").from_(values)

    # Materialize the decorated Python entry point into a model and register it.
    registered = model.get_registry()["dummy_model"]
    dummy_model = registered.model(module_path=Path("."), path=Path("."))
    context.upsert_model(dummy_model)
    assert isinstance(dummy_model, SqlModel)
    assert dummy_model.gateway == "duckdb"

    expected_sql = """
SELECT
"_v"."x" AS "x",
FROM (VALUES ('1', 2)) AS "_v"("x")
"""
    assert_exp_eq(context.render("dummy_model"), expected_sql)

    # Fetch the gateway-specific adapter first, then count the registered adapters.
    gateway_adapter = context._get_engine_adapter("duckdb")
    assert isinstance(gateway_adapter, DuckDBEngineAdapter)
    assert len(context._engine_adapters) == 2
52 changes: 51 additions & 1 deletion tests/core/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from sqlmesh.core.macros import MacroEvaluator, macro
from sqlmesh.core.model import Model, SqlModel, load_sql_based_model, model
from sqlmesh.core.test.definition import ModelTest, PythonModelTest, SqlModelTest
from sqlmesh.utils.errors import ConfigError, TestError
from sqlmesh.utils.errors import ConfigError, SQLMeshError, TestError
from sqlmesh.utils.yaml import dump as dump_yaml
from sqlmesh.utils.yaml import load as load_yaml

Expand Down Expand Up @@ -1989,3 +1989,53 @@ def test_test_generation_with_recursive_ctes(tmp_path: Path) -> None:
}

_check_successful_or_raise(context.test())


def test_test_with_gateway_specific_model(tmp_path: Path, mocker: MockerFixture) -> None:
    """Generate a unit test for a model whose gateway is not the default one.

    Sets up an example project with two DuckDB gateways, adds a model pinned
    to the "second" gateway, and checks that `create_test` writes the expected
    test file and that both engine adapters end up registered on the context.
    """
    init_example_project(tmp_path, dialect="duckdb")

    models_dir = tmp_path / "models"

    # The model has a gateway specified which isn't the default
    (models_dir / "gw_model.sql").write_text(
        "MODEL (name sqlmesh_example.gw_model, gateway second); SELECT c FROM sqlmesh_example.input_model;"
    )
    (models_dir / "input_model.sql").write_text(
        "MODEL (name sqlmesh_example.input_model); SELECT c FROM external_table;"
    )

    config = Config(
        gateways={
            "main": GatewayConfig(connection=DuckDBConnectionConfig()),
            "second": GatewayConfig(connection=DuckDBConnectionConfig()),
        },
        default_gateway="main",
        model_defaults=ModelDefaultsConfig(dialect="duckdb"),
    )
    context = Context(paths=tmp_path, config=config)

    # Stub out dataframe fetching so every query yields a single row with c = 5.
    mocker.patch(
        "sqlmesh.core.engine_adapter.base.EngineAdapter.fetchdf",
        return_value=pd.DataFrame({"c": [5]}),
    )

    assert context.engine_adapter == context._engine_adapters["main"]

    expected_error = r"Gateway 'wrong' not found in the available engine adapters."
    with pytest.raises(SQLMeshError, match=expected_error):
        context._get_engine_adapter("wrong")

    # Create test should use the gateway specific engine adapter
    context.create_test(
        "sqlmesh_example.gw_model",
        input_queries={'"memory"."sqlmesh_example"."input_model"': "SELECT 5 AS c"},
        overwrite=True,
    )
    assert context._get_engine_adapter("second") == context._engine_adapters["second"]
    assert len(context._engine_adapters) == 2

    generated = load_yaml(context.path / c.TESTS / "test_gw_model.yaml")

    assert len(generated) == 1
    assert "test_gw_model" in generated

    gw_test = generated["test_gw_model"]
    assert gw_test["inputs"] == {
        '"memory"."sqlmesh_example"."input_model"': [{"c": 5}]
    }
    assert gw_test["outputs"] == {"query": [{"c": 5}]}
Loading