From 3a89d9e961c0916c8102885a901993a984e1bdd5 Mon Sep 17 00:00:00 2001
From: Vaggelis Danias
Date: Wed, 18 Dec 2024 19:20:43 +0200
Subject: [PATCH] Feat!: Support 'optimize' flag in model defs (#3512)

Co-authored-by: Iaroslav Zeigerman
---
 docs/concepts/models/overview.md | 7 ++
 docs/reference/model_configuration.md | 1 +
 sqlmesh/core/config/model.py | 2 +
 sqlmesh/core/model/common.py | 1 +
 sqlmesh/core/model/definition.py | 9 ++
 sqlmesh/core/model/meta.py | 1 +
 sqlmesh/core/renderer.py | 4 +
 .../migrations/v0065_add_model_optimize.py | 5 +
 tests/core/test_model.py | 91 ++++++++++++++++++-
 tests/core/test_snapshot.py | 7 +-
 tests/core/test_snapshot_evaluator.py | 11 +--
 tests/core/test_state_sync.py | 5 +-
 tests/schedulers/airflow/test_client.py | 4 +-
 13 files changed, 131 insertions(+), 17 deletions(-)
 create mode 100644 sqlmesh/migrations/v0065_add_model_optimize.py

diff --git a/docs/concepts/models/overview.md b/docs/concepts/models/overview.md
index 2c40df761..f35889720 100644
--- a/docs/concepts/models/overview.md
+++ b/docs/concepts/models/overview.md
@@ -354,6 +354,13 @@ Learn more about these properties and their default values in the [model configu
 ### gateway
 : Specifies the gateway to use for the execution of this model. When not specified, the default gateway is used.
 
+### optimize_query
+: Whether the model's query should be optimized. All SQL models are optimized by default. Setting this
+to `false` causes SQLMesh to disable query canonicalization & simplification. This should be turned off only if the optimized query leads to errors, such as exceeding the engine's text limit.
+
+!!! warning
+    Turning off the optimizer may prevent column-level lineage from working for the affected model and its descendants, unless all columns in the model's query are qualified and it contains no star projections (e.g. `SELECT *`).
+
 ## Incremental Model Properties
 
 These properties can be specified in an incremental model's `kind` definition.
diff --git a/docs/reference/model_configuration.md b/docs/reference/model_configuration.md
index 5c7288509..20965fea2 100644
--- a/docs/reference/model_configuration.md
+++ b/docs/reference/model_configuration.md
@@ -38,6 +38,7 @@ Configuration options for SQLMesh model properties. Supported by all model kinds
 | `allow_partials` | Whether this model can process partial (incomplete) data intervals | bool | N |
 | `enabled` | Whether the model is enabled. This attribute is `true` by default. Setting it to `false` causes SQLMesh to ignore this model when loading the project. | bool | N |
 | `gateway` | Specifies the gateway to use for the execution of this model. When not specified, the default gateway is used. | str | N |
+| `optimize_query` | Whether the model's query should be optimized. This attribute is `true` by default. Setting it to `false` causes SQLMesh to disable query canonicalization & simplification. This should be turned off only if the optimized query leads to errors, such as exceeding the engine's text limit. | bool | N |
 
 ### Model defaults
diff --git a/sqlmesh/core/config/model.py b/sqlmesh/core/config/model.py
index 2fac33ff3..3f38e6ad4 100644
--- a/sqlmesh/core/config/model.py
+++ b/sqlmesh/core/config/model.py
@@ -33,6 +33,7 @@ class ModelDefaultsConfig(BaseConfig):
             (eg. 'parquet', 'orc')
         on_destructive_change: What should happen when a forward-only model requires a destructive schema change.
         audits: The audits to be applied globally to all models in the project.
+ optimize_query: Whether the SQL models should be optimized """ kind: t.Optional[ModelKind] = None @@ -45,6 +46,7 @@ class ModelDefaultsConfig(BaseConfig): on_destructive_change: t.Optional[OnDestructiveChange] = None session_properties: t.Optional[t.Dict[str, t.Any]] = None audits: t.Optional[t.List[FunctionCall]] = None + optimize_query: t.Optional[bool] = None _model_kind_validator = model_kind_validator _on_destructive_change_validator = on_destructive_change_validator diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 7dca7661e..47ae1f73b 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -301,6 +301,7 @@ def depends_on(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.Optional[ "insert_overwrite", "allow_partials", "enabled", + "optimize_query", mode="before", check_fields=False, )(parse_bool) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index b00fc97aa..3c89177db 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -216,6 +216,7 @@ def render_definition( "default_catalog", "enabled", "inline_audits", + "optimize_query", ): expressions.append( exp.Property( @@ -840,6 +841,12 @@ def validate_definition(self) -> None: self._path, ) + if not self.is_sql and self.optimize_query is not None: + raise_config_error( + "SQLMesh query optimizer can only be enabled/disabled for SQL models", + self._path, + ) + def is_breaking_change(self, previous: Model) -> t.Optional[bool]: """Determines whether this model is a breaking change in relation to the `previous` model. @@ -881,6 +888,7 @@ def _data_hash_values(self) -> t.List[str]: self.physical_version, self.gateway, self.interval_unit.value if self.interval_unit is not None else None, + str(self.optimize_query) if self.optimize_query is not None else None, ] for column_name, column_type in (self.columns_to_types_ or {}).items(): @@ -1269,6 +1277,7 @@ def _query_renderer(self) -> QueryRenderer: only_execution_time=self.kind.only_execution_time, default_catalog=self.default_catalog, quote_identifiers=not no_quote_identifiers, + optimize_query=self.optimize_query, ) @property diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index c8215df23..2e5a09a02 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -78,6 +78,7 @@ class ModelMeta(_Node): enabled: bool = True physical_version: t.Optional[str] = None gateway: t.Optional[str] = None + optimize_query: t.Optional[bool] = None _bool_validator = bool_validator _model_kind_validator = model_kind_validator diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 5f1b651e9..c6e564bbf 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -51,6 +51,7 @@ def __init__( quote_identifiers: bool = True, model_fqn: t.Optional[str] = None, normalize_identifiers: bool = True, + optimize_query: t.Optional[bool] = True, ): self._expression = expression self._dialect = dialect @@ -65,6 +66,7 @@ def __init__( self.update_schema({} if schema is None else schema) self._cache: t.List[t.Optional[exp.Expression]] = [] self._model_fqn = model_fqn + self._optimize_query_flag = optimize_query is not False def update_schema(self, schema: t.Dict[str, t.Any]) -> None: self.schema = d.normalize_mapping_schema(schema, dialect=self._dialect) @@ -438,6 +440,8 @@ def render( runtime_stage, start, end, execution_time, *kwargs.values() ) + needs_optimization = needs_optimization and self._optimize_query_flag + if should_cache and 
self._optimized_cache: query = self._optimized_cache else: diff --git a/sqlmesh/migrations/v0065_add_model_optimize.py b/sqlmesh/migrations/v0065_add_model_optimize.py new file mode 100644 index 000000000..cf6eaa403 --- /dev/null +++ b/sqlmesh/migrations/v0065_add_model_optimize.py @@ -0,0 +1,5 @@ +"""Add the optimize_query model attribute.""" + + +def migrate(state_sync, **kwargs): # type: ignore + pass diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 2bbd56b37..ace5d552d 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -4072,7 +4072,7 @@ def test_default_catalog_sql(assert_exp_eq): The system is not designed to actually support having an engine that doesn't support default catalog to start supporting it or the reverse of that. If that did happen then bugs would occur. """ - HASH_WITH_CATALOG = "368216481" + HASH_WITH_CATALOG = "516937963" # Test setting default catalog doesn't change hash if it matches existing logic expressions = d.parse( @@ -4238,7 +4238,7 @@ def test_default_catalog_sql(assert_exp_eq): def test_default_catalog_python(): - HASH_WITH_CATALOG = "663490914" + HASH_WITH_CATALOG = "770057346" @model(name="db.table", kind="full", columns={'"COL"': "int"}) def my_model(context, **kwargs): @@ -4330,7 +4330,7 @@ def test_default_catalog_external_model(): Since external models fqns are the only thing affected by default catalog, and when they change new snapshots are made, the hash will be the same across different names. """ - EXPECTED_HASH = "627688262" + EXPECTED_HASH = "3614876346" model = create_external_model("db.table", columns={"a": "int", "limit": "int"}) assert model.default_catalog is None @@ -6319,3 +6319,88 @@ def assert_metadata_only(): model = load_sql_based_model(expressions, signal_definitions=signal.get_registry()) model.signals.clear() assert_metadata_only() + + +def test_model_optimize(tmp_path: Path, assert_exp_eq): + defaults = [ + ModelDefaultsConfig(optimize_query=True).dict(), + ModelDefaultsConfig(optimize_query=False).dict(), + ] + non_optimized_sql = 'SELECT 1 + 2 AS "new_col"' + optimized_sql = 'SELECT 3 AS "new_col"' + + # Model flag is False, overriding defaults + disabled_opt = d.parse( + """ + MODEL ( + name test, + optimize_query False, + ); + + SELECT 1 + 2 AS new_col + """ + ) + + for default in defaults: + model = load_sql_based_model(disabled_opt, defaults=default) + assert_exp_eq(model.render_query(), non_optimized_sql) + + # Model flag is True, overriding defaults + enabled_opt = d.parse( + """ + MODEL ( + name test, + optimize_query True, + ); + + SELECT 1 + 2 AS new_col + """ + ) + + for default in defaults: + model = load_sql_based_model(enabled_opt, defaults=default) + assert_exp_eq(model.render_query(), optimized_sql) + + # Model flag is not defined, behavior is set according to the defaults + none_opt = d.parse( + """ + MODEL ( + name test, + ); + + SELECT 1 + 2 AS new_col + """ + ) + + assert_exp_eq(load_sql_based_model(none_opt).render_query(), optimized_sql) + assert_exp_eq( + load_sql_based_model(none_opt, defaults=defaults[0]).render_query(), optimized_sql + ) + assert_exp_eq( + load_sql_based_model(none_opt, defaults=defaults[1]).render_query(), non_optimized_sql + ) + + # Ensure that plan works as expected (optimize_query flag affects the model's data hash) + for parsed_model in [enabled_opt, disabled_opt, none_opt]: + context = Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))) + context.upsert_model(load_sql_based_model(parsed_model)) + 
context.plan(auto_apply=True, no_prompts=True) + + # Ensure non-SQLModels raise if optimize_query is not None + with pytest.raises( + ConfigError, + match=r"SQLMesh query optimizer can only be enabled/disabled for SQL models", + ): + seed_path = tmp_path / "seed.csv" + model_kind = SeedKind(path=str(seed_path.absolute())) + with open(seed_path, "w", encoding="utf-8") as fd: + fd.write( + """ + col_a,col_b,col_c + 1,text_a,1.0 + 2,text_b,2.0""" + ) + model = create_seed_model("test_db.test_seed_model", model_kind, optimize_query=True) + context = Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))) + context.upsert_model(model) + context.plan(auto_apply=True, no_prompts=True) diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 7983037cf..ed8d3590d 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -707,12 +707,11 @@ def test_fingerprint(model: Model, parent_model: Model): fingerprint = fingerprint_from_node(model, nodes={}) original_fingerprint = SnapshotFingerprint( - data_hash="2098818222", + data_hash="1312415267", metadata_hash="2793463216", ) assert fingerprint == original_fingerprint - with_parent_fingerprint = fingerprint_from_node(model, nodes={'"parent"."tbl"': parent_model}) assert with_parent_fingerprint != fingerprint assert int(with_parent_fingerprint.parent_data_hash) > 0 @@ -767,7 +766,7 @@ def test_fingerprint_seed_model(): ) expected_fingerprint = SnapshotFingerprint( - data_hash="295987232", + data_hash="1909791099", metadata_hash="3403817841", ) @@ -806,7 +805,7 @@ def test_fingerprint_jinja_macros(model: Model): } ) original_fingerprint = SnapshotFingerprint( - data_hash="979797026", + data_hash="923305614", metadata_hash="2793463216", ) diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 3e5b648e1..35a0abf60 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -561,7 +561,7 @@ def test_evaluate_materialized_view_with_partitioned_by_cluster_by( [ call("CREATE SCHEMA IF NOT EXISTS `sqlmesh__test_schema`"), call( - "CREATE MATERIALIZED VIEW `sqlmesh__test_schema`.`test_schema__test_model__3674208014` PARTITION BY `a` CLUSTER BY `b` AS SELECT `a` AS `a`, `b` AS `b` FROM `tbl` AS `tbl`" + f"CREATE MATERIALIZED VIEW `sqlmesh__test_schema`.`test_schema__test_model__{snapshot.version}` PARTITION BY `a` CLUSTER BY `b` AS SELECT `a` AS `a`, `b` AS `b` FROM `tbl` AS `tbl`" ), ] ) @@ -2862,12 +2862,9 @@ def test_cleanup_managed(adapter_mock, make_snapshot, mocker: MockerFixture): evaluator.cleanup(target_snapshots=[cleanup_task]) - adapter_mock.drop_table.assert_called_once_with( - "sqlmesh__test_schema.test_schema__test_model__14485873__temp" - ) - adapter_mock.drop_managed_table.assert_called_once_with( - "sqlmesh__test_schema.test_schema__test_model__14485873" - ) + physical_name = f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}" + adapter_mock.drop_table.assert_called_once_with(f"{physical_name}__temp") + adapter_mock.drop_managed_table.assert_called_once_with(f"{physical_name}") def test_create_managed_forward_only_with_previous_version_doesnt_clone_for_dev_preview( diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py index fe4c7700f..830013ec0 100644 --- a/tests/core/test_state_sync.py +++ b/tests/core/test_state_sync.py @@ -2221,16 +2221,17 @@ def test_snapshot_batching(state_sync, mocker, make_snapshot): ) ) calls = mock.delete_from.call_args_list + identifiers = 
sorted([snapshot_a.identifier, snapshot_b.identifier, snapshot_c.identifier]) assert mock.delete_from.call_args_list == [ call( exp.to_table("sqlmesh._snapshots"), where=parse_one( - f"(name, identifier) in (('\"a\"', '{snapshot_b.identifier}'), ('\"a\"', '{snapshot_a.identifier}'))" + f"(name, identifier) in (('\"a\"', '{identifiers[0]}'), ('\"a\"', '{identifiers[1]}'))" ), ), call( exp.to_table("sqlmesh._snapshots"), - where=parse_one(f"(name, identifier) in (('\"a\"', '{snapshot_c.identifier}'))"), + where=parse_one(f"(name, identifier) in (('\"a\"', '{identifiers[2]}'))"), ), ] diff --git a/tests/schedulers/airflow/test_client.py b/tests/schedulers/airflow/test_client.py index 85534c37e..ebb356beb 100644 --- a/tests/schedulers/airflow/test_client.py +++ b/tests/schedulers/airflow/test_client.py @@ -184,7 +184,9 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot): "models_to_backfill": ['"test_model"'], "end_bounded": False, "ensure_finalized_snapshots": False, - "directly_modified_snapshots": [{"identifier": "1291282319", "name": '"test_model"'}], + "directly_modified_snapshots": [ + {"identifier": snapshot.identifier, "name": '"test_model"'} + ], "indirectly_modified_snapshots": {}, "removed_snapshots": [], "restatements": {
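
Usage sketch (reviewer note, not part of the patch): the snippet below mirrors the behavior exercised by the new `test_model_optimize` test, showing how the per-model `optimize_query` flag and the `ModelDefaultsConfig` default affect `render_query()`. Import paths follow the ones used in the SQLMesh test suite, and the rendered SQL in the comments is the expectation taken from that test.

```python
# Minimal sketch, assuming the import paths used by the SQLMesh test suite.
from sqlmesh.core import dialect as d
from sqlmesh.core.config.model import ModelDefaultsConfig
from sqlmesh.core.model import load_sql_based_model

# A model that explicitly opts out of query canonicalization & simplification.
disabled_opt = d.parse(
    """
    MODEL (
        name test,
        optimize_query False,
    );

    SELECT 1 + 2 AS new_col
    """
)
model = load_sql_based_model(disabled_opt)
# Rendered as written, e.g. SELECT 1 + 2 AS "new_col" (no constant folding).
print(model.render_query().sql())

# A model that doesn't set the flag falls back to the project-wide default.
none_opt = d.parse(
    """
    MODEL (
        name test,
    );

    SELECT 1 + 2 AS new_col
    """
)
defaults = ModelDefaultsConfig(optimize_query=False).dict()
model = load_sql_based_model(none_opt, defaults=defaults)
# Also rendered unoptimized; without the default it would be SELECT 3 AS "new_col".
print(model.render_query().sql())
```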
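A second sketch (assumptions: run from a scratch directory, the seed file name is illustrative): the new check appended to `validate_definition` rejects `optimize_query` on non-SQL models, raising the `ConfigError` asserted in the test above. The test triggers it through `Context.upsert_model`/`plan`; calling the validation directly should surface the same error.

```python
# Sketch of the new validation rule: optimize_query is only meaningful for SQL
# models, so setting it on a seed model is rejected.
from pathlib import Path

from sqlmesh.core.model import SeedKind, create_seed_model
from sqlmesh.utils.errors import ConfigError

seed_path = Path("seed.csv")  # hypothetical scratch file for the example
seed_path.write_text("col_a,col_b,col_c\n1,text_a,1.0\n2,text_b,2.0\n")

model = create_seed_model(
    "test_db.test_seed_model",
    SeedKind(path=str(seed_path.absolute())),
    optimize_query=True,  # not a SQL model, so the flag is not allowed
)

try:
    # Context.upsert_model / plan run this validation too, as in the test above.
    model.validate_definition()
except ConfigError as e:
    print(e)  # SQLMesh query optimizer can only be enabled/disabled for SQL models
```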