diff --git a/dbt_meshify/utilities/contractor.py b/dbt_meshify/utilities/contractor.py index 67fc4d9..f8a784d 100644 --- a/dbt_meshify/utilities/contractor.py +++ b/dbt_meshify/utilities/contractor.py @@ -15,11 +15,17 @@ def generate_contract(self, model: ModelNode) -> ResourceChange: """Generate a ChangeSet that adds a contract to a Model.""" model_catalog = self.project.get_catalog_entry(model.unique_id) + # create a mapping of the column name to its representation in the yml file to maintain the original case + original_case = {column_name.lower(): column_name for column_name in model.columns.keys()} + if not model_catalog or not model_catalog.columns: columns = None else: columns = [ - {"name": name.lower(), "data_type": value.type.lower()} + { + "name": original_case.get(name.lower()) or name.lower(), + "data_type": value.type.lower(), + } for name, value in model_catalog.columns.items() ] diff --git a/tests/sql_and_yml_fixtures.py b/tests/sql_and_yml_fixtures.py index e6b920e..0517af4 100644 --- a/tests/sql_and_yml_fixtures.py +++ b/tests/sql_and_yml_fixtures.py @@ -56,6 +56,18 @@ - name: colleague description: "this is the colleague column" """ + +model_yml_all_col_all_caps = """ +models: + - name: shared_model + description: "this is a test model" + columns: + - name: ID + description: "this is the id column" + - name: COLLEAGUE + description: "this is the colleague column" +""" + expected_contract_yml_no_col = """ models: - name: shared_model @@ -138,6 +150,22 @@ data_type: varchar """ +expected_contract_yml_all_col_all_caps = """ +models: + - name: shared_model + description: "this is a test model" + config: + contract: + enforced: true + columns: + - name: ID + description: "this is the id column" + data_type: integer + - name: COLLEAGUE + description: "this is the colleague column" + data_type: varchar +""" + expected_contract_yml_no_entry = """ models: - name: shared_model diff --git a/tests/unit/test_add_contract_to_yml.py b/tests/unit/test_add_contract_to_yml.py index 9985503..02c516f 100644 --- a/tests/unit/test_add_contract_to_yml.py +++ b/tests/unit/test_add_contract_to_yml.py @@ -3,7 +3,7 @@ import pytest from dbt.contracts.files import FileHash -from dbt.contracts.graph.nodes import ModelNode, NodeType +from dbt.contracts.graph.nodes import ColumnInfo, ModelNode, NodeType from dbt_meshify.change import ResourceChange from dbt_meshify.storage.file_content_editors import ResourceFileEditor @@ -12,12 +12,14 @@ from ..dbt_project_fixtures import shared_model_catalog_entry from ..sql_and_yml_fixtures import ( expected_contract_yml_all_col, + expected_contract_yml_all_col_all_caps, expected_contract_yml_no_col, expected_contract_yml_no_entry, expected_contract_yml_one_col, expected_contract_yml_one_col_one_test, expected_contract_yml_other_model, model_yml_all_col, + model_yml_all_col_all_caps, model_yml_no_col_no_version, model_yml_one_col, model_yml_one_col_one_test, @@ -54,12 +56,41 @@ def model() -> ModelNode: ) +@pytest.fixture +def all_cap_model() -> ModelNode: + return ModelNode( + database=None, + resource_type=NodeType.Model, + checksum=FileHash("foo", "foo"), + schema="foo", + name=model_name, + package_name="foo", + path="models/_models.yml", + original_file_path=f"models/{model_name}.sql", + unique_id="model.foo.foo", + columns={ + "ID": ColumnInfo.from_dict({"name": "ID", "description": "this is the id column"}), + "COLLEAGUE": ColumnInfo.from_dict( + {"name": "COLLEAGUE", "description": "this is the colleague column"} + ), + }, + fqn=["foo", "foo"], + alias="foo", + ) + + @pytest.fixture def change(project, model) -> ResourceChange: contractor = Contractor(project=project) return contractor.generate_contract(model) +@pytest.fixture +def all_cap_model_change(project, all_cap_model) -> ResourceChange: + contractor = Contractor(project=project) + return contractor.generate_contract(all_cap_model) + + class TestAddContractToYML: def test_add_contract_to_yml_no_col(self, change): yml_dict = ResourceFileEditor.update_resource( @@ -86,3 +117,9 @@ def test_add_contract_to_yml_no_entry(self, change): def test_add_contract_to_yml_other_model(self, change): yml_dict = ResourceFileEditor.update_resource(read_yml(model_yml_other_model), change) assert yml_dict == read_yml(expected_contract_yml_other_model) + + def test_add_contract_to_yml_all_caps_columns(self, all_cap_model_change): + yml_dict = ResourceFileEditor.update_resource( + read_yml(model_yml_all_col_all_caps), all_cap_model_change + ) + assert yml_dict == read_yml(expected_contract_yml_all_col_all_caps)