From fda1eb9c66eaa50d3ddd69e9c70786ac9eec85d3 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Fri, 13 Dec 2024 15:09:30 -0800 Subject: [PATCH] add tests and an error to make sure schema update are safe --- .../artifacts/schemas/manifest/v12/manifest.py | 10 ++++++++++ tests/functional/artifacts/test_artifacts.py | 16 ++++++++++++++++ tests/unit/contracts/graph/test_nodes.py | 10 ++++++---- tests/unit/contracts/graph/test_nodes_parsed.py | 8 ++++++++ 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/core/dbt/artifacts/schemas/manifest/v12/manifest.py b/core/dbt/artifacts/schemas/manifest/v12/manifest.py index cc13fca43f5..e59ef02efee 100644 --- a/core/dbt/artifacts/schemas/manifest/v12/manifest.py +++ b/core/dbt/artifacts/schemas/manifest/v12/manifest.py @@ -180,3 +180,13 @@ def upgrade_schema_version(cls, data): if manifest_schema_version < cls.dbt_schema_version.version: data = upgrade_manifest_json(data, manifest_schema_version) return cls.from_dict(data) + + @classmethod + def validate(cls, _): + # When dbt tries to load an artifact with additional optional fields + # that are not present in the schema, from_dict will work fine. + # As long as validate is not called, the schema will not be enforced. + # This is intentional, as it allows for safer schema upgrades. + raise RuntimeError( + "The WritableManifest should never be validated directly to allow for schema upgrades." 
+ ) diff --git a/tests/functional/artifacts/test_artifacts.py b/tests/functional/artifacts/test_artifacts.py index 92ec8abe196..52cd0605524 100644 --- a/tests/functional/artifacts/test_artifacts.py +++ b/tests/functional/artifacts/test_artifacts.py @@ -663,6 +663,22 @@ def test_run_and_generate(self, project, manifest_schema_path, run_results_schem > 0 ) + # Test that an artifact with additional fields loads fine + def test_load_artifact(self, project, manifest_schema_path, run_results_schema_path): + catcher = EventCatcher(ArtifactWritten) + results = run_dbt(args=["compile"], callbacks=[catcher.catch]) + assert len(results) == 7 + manifest_dct = get_artifact(os.path.join(project.project_root, "target", "manifest.json")) + # add a field that is not in the schema + for _, node in manifest_dct["nodes"].items(): + node["something_else"] = "something_else" + # load the manifest with the additional field + loaded_manifest = WritableManifest.from_dict(manifest_dct) + + # successfully loaded the manifest with the additional field, but the field should not be present + for _, node in loaded_manifest.nodes.items(): + assert not hasattr(node, "something_else") + class TestVerifyArtifactsReferences(BaseVerifyProject): @pytest.fixture(scope="class") diff --git a/tests/unit/contracts/graph/test_nodes.py b/tests/unit/contracts/graph/test_nodes.py index 3b509d0d20d..0648fe1174d 100644 --- a/tests/unit/contracts/graph/test_nodes.py +++ b/tests/unit/contracts/graph/test_nodes.py @@ -236,10 +236,12 @@ def test_basic_compiled_model(basic_compiled_dict, basic_compiled_model): assert node.is_ephemeral is False -def test_invalid_extra_fields_model(minimal_uncompiled_dict): - bad_extra = minimal_uncompiled_dict - bad_extra["notvalid"] = "nope" - assert_fails_validation(bad_extra, ModelNode) +def test_extra_fields_model_okay(minimal_uncompiled_dict): + extra = minimal_uncompiled_dict + extra["notvalid"] = "nope" + # Model still loads fine with extra fields + loaded_model = 
ModelNode.from_dict(extra) + assert not hasattr(loaded_model, "notvalid") def test_invalid_bad_type_model(minimal_uncompiled_dict): diff --git a/tests/unit/contracts/graph/test_nodes_parsed.py b/tests/unit/contracts/graph/test_nodes_parsed.py index dc5a326f4d9..75d451b9956 100644 --- a/tests/unit/contracts/graph/test_nodes_parsed.py +++ b/tests/unit/contracts/graph/test_nodes_parsed.py @@ -1961,6 +1961,14 @@ def test_basic_source_definition( pickle.loads(pickle.dumps(node)) +def test_extra_fields_source_definition_okay(minimum_parsed_source_definition_dict): + extra = minimum_parsed_source_definition_dict + extra["notvalid"] = "nope" + # Source definition still loads fine with extra fields + loaded_source = SourceDefinition.from_dict(extra) + assert not hasattr(loaded_source, "notvalid") + + def test_invalid_missing(minimum_parsed_source_definition_dict): bad_missing_name = minimum_parsed_source_definition_dict del bad_missing_name["name"]