Skip to content

Commit

Permalink
Merge pull request #91 from dbt-labs/fix/duplicate-versions
Browse files Browse the repository at this point in the history
fix version bug when incrementing versions on a prereleased model
  • Loading branch information
graciegoheen authored Jul 10, 2023
2 parents 764ad1f + 65a351d commit 474617d
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 13 deletions.
24 changes: 16 additions & 8 deletions dbt_meshify/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@
logger.add(sys.stdout, format=log_format)


class FatalMeshifyException(click.ClickException):
def __init__(self, message):
super().__init__(message)

def show(self):
logger.error(self.message)
if self.__cause__ is not None:
logger.exception(self.__cause__)


# define cli group
@click.group()
def cli():
Expand Down Expand Up @@ -96,8 +106,7 @@ def split(project_name, select, exclude, project_path, selector, create_path):
subproject_creator.initialize()
logger.success(f"Successfully created subproject {subproject.name}")
except Exception as e:
logger.error(f"Error creating subproject {subproject.name}")
logger.exception(e)
raise FatalMeshifyException(f"Error creating subproject {subproject.name}")


@operation.command(name="add-contract")
Expand Down Expand Up @@ -134,8 +143,7 @@ def add_contract(select, exclude, project_path, selector, public_only=False):
meshify_constructor.add_model_contract()
logger.success(f"Successfully added contract to model: {model_unique_id}")
except Exception as e:
logger.error(f"Error adding contract to model: {model_unique_id}")
logger.exception(e)
raise FatalMeshifyException(f"Error adding contract to model: {model_unique_id}")


@operation.command(name="add-version")
Expand Down Expand Up @@ -168,8 +176,9 @@ def add_version(select, exclude, project_path, selector, prerelease, defined_in)
meshify_constructor.add_model_version(prerelease=prerelease, defined_in=defined_in)
logger.success(f"Successfully added version to model: {model_unique_id}")
except Exception as e:
logger.error(f"Error adding version to model: {model_unique_id}")
logger.exception(e)
raise FatalMeshifyException(
f"Error adding version to model: {model_unique_id}"
) from e


@operation.command(name="create-group")
Expand Down Expand Up @@ -230,8 +239,7 @@ def create_group(
)
logger.success(f"Successfully created group: {name}")
except Exception as e:
logger.error(f"Error creating group: {name}")
logger.exception(e)
raise FatalMeshifyException(f"Error creating group: {name}")


@cli.command(name="group")
Expand Down
26 changes: 22 additions & 4 deletions dbt_meshify/storage/file_content_editors.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,20 @@ def add_model_contract_to_yml(
models_yml["models"] = list(models.values())
return models_yml

def get_latest_yml_defined_version(self, model_yml: Dict[str, Any]):
"""
Returns the latest version defined in the yml file for a given model name
the format of `model_yml` should be a single model yml entry
if no versions, returns 0
"""
model_yml_versions = model_yml.get("versions", [])
try:
return max([int(v.get("v")) for v in model_yml_versions]) if model_yml_versions else 0
except ValueError:
raise ValueError(
f"Version not an integer, can't increment version for {model_yml.get('name')}"
)

def add_model_version_to_yml(
self,
model_name,
Expand All @@ -239,16 +253,17 @@ def add_model_version_to_yml(
# add the version to the model yml entry
versions_list = model_yml.get("versions") or []
latest_version = model_yml.get("latest_version") or 0
latest_yml_version = self.get_latest_yml_defined_version(model_yml)
version_dict: Dict[str, Union[int, str, os.PathLike]] = {}
if not versions_list:
version_dict["v"] = 1
latest_version += 1
# if the model has versions, add the next version
# if prerelease flag is true, do not increment the latest_version
elif prerelease:
version_dict = {"v": latest_version + 1}
version_dict = {"v": latest_yml_version + 1}
else:
version_dict = {"v": latest_version + 1}
version_dict = {"v": latest_yml_version + 1}
latest_version += 1
# add the defined_in key if it exists
if defined_in:
Expand Down Expand Up @@ -389,6 +404,9 @@ def add_model_version(
# read the yml file
# pass empty dict if no file contents returned
models_yml = self.file_manager.read_file(yml_path) or {}
latest_yml_version = self.get_latest_yml_defined_version(
resources_yml_to_dict(models_yml).get(self.node.name, {}) # type: ignore
)
try:
updated_yml = self.add_model_version_to_yml(
model_name=self.node.name,
Expand All @@ -410,7 +428,7 @@ def add_model_version(
next_version_file_name = (
f"{defined_in}.{self.node.language}"
if defined_in
else f"{self.node.name}_v{latest_version + 1}.{self.node.language}"
else f"{self.node.name}_v{latest_yml_version + 1}.{self.node.language}"
)
model_path = self.get_resource_path()

Expand All @@ -432,7 +450,7 @@ def add_model_version(
logger.info(f"Creating new version of {self.node.name} at {next_version_path}")
self.file_manager.write_file(next_version_path, self.node.raw_code)
# if the existing version doesn't use the _v{version} naming convention, rename it to the previous version
if not model_path.root.endswith(f"_v{latest_version}.{self.node.language}"):
if not model_path.stem.endswith(f"_v{latest_version}"):
logger.info(
f"Renaming existing version of {self.node.name} from {model_path.name} to {last_version_path.name}"
)
Expand Down
49 changes: 48 additions & 1 deletion tests/integration/test_version_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
from dbt_meshify.main import add_version

from ..sql_and_yml_fixtures import (
expected_versioned_model_yml_increment_prerelease_version,
expected_versioned_model_yml_increment_prerelease_version_with_second_prerelease,
expected_versioned_model_yml_increment_version_defined_in,
expected_versioned_model_yml_increment_version_no_prerelease,
expected_versioned_model_yml_increment_version_with_prerelease,
expected_versioned_model_yml_no_version,
expected_versioned_model_yml_no_yml,
model_yml_increment_version,
model_yml_no_col_no_version,
model_yml_string_version,
shared_model_sql,
)

Expand Down Expand Up @@ -71,8 +74,22 @@ def reset_model_files(files_list):
["shared_model_v1.sql"],
[],
),
(
expected_versioned_model_yml_increment_version_with_prerelease,
expected_versioned_model_yml_increment_prerelease_version_with_second_prerelease,
["shared_model_v1.sql", "shared_model_v2.sql"],
["shared_model_v1.sql", "shared_model_v2.sql", "shared_model_v3.sql"],
["--prerelease"],
),
(
expected_versioned_model_yml_increment_version_with_prerelease,
expected_versioned_model_yml_increment_prerelease_version,
["shared_model_v1.sql", "shared_model_v2.sql"],
["shared_model_v1.sql", "shared_model_v2.sql", "shared_model_v3.sql"],
[],
),
],
ids=["1", "2", "3", "4", "5"],
ids=["1", "2", "3", "4", "5", "6", "7"],
)
def test_add_version_to_yml(start_yml, end_yml, start_files, expected_files, command_options):
yml_file = proj_path / "models" / "_models.yml"
Expand Down Expand Up @@ -103,3 +120,33 @@ def test_add_version_to_yml(start_yml, end_yml, start_files, expected_files, com
yml_file.unlink()
reset_model_files(["shared_model.sql"])
assert actual == yaml.safe_load(end_yml)


@pytest.mark.parametrize(
"start_yml,start_files",
[
(
model_yml_string_version,
["shared_model.sql"],
),
],
ids=["1"],
)
def test_add_version_to_invalid_yml(start_yml, start_files):
yml_file = proj_path / "models" / "_models.yml"
reset_model_files(start_files)
yml_file.parent.mkdir(parents=True, exist_ok=True)
runner = CliRunner()
# only create file if start_yml is not None
# in situations where models don't have a patch path, there isn't a yml file to read from
if start_yml:
yml_file.touch()
start_yml_content = yaml.safe_load(start_yml)
with open(yml_file, "w+") as f:
yaml.safe_dump(start_yml_content, f, sort_keys=False)
base_command = ["--select", "shared_model", "--project-path", proj_path_string]
result = runner.invoke(add_version, base_command, catch_exceptions=True)
assert result.exit_code == 1
# reset the read path to the default in the logic
yml_file.unlink()
reset_model_files(["shared_model.sql"])
31 changes: 31 additions & 0 deletions tests/sql_and_yml_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,28 @@
- v: 2
"""

expected_versioned_model_yml_increment_prerelease_version_with_second_prerelease = """
models:
- name: shared_model
latest_version: 1
description: "this is a test model"
versions:
- v: 1
- v: 2
- v: 3
"""

expected_versioned_model_yml_increment_prerelease_version = """
models:
- name: shared_model
latest_version: 2
description: "this is a test model"
versions:
- v: 1
- v: 2
- v: 3
"""

expected_versioned_model_yml_increment_version_defined_in = """
models:
- name: shared_model
Expand All @@ -203,6 +225,15 @@
defined_in: daves_model
"""

model_yml_string_version = """
models:
- name: shared_model
latest_version: john_olerud
description: "this is a test model"
versions:
- v: john_olerud
"""

# expected result when removing the shared_model entry from model_yml_no_col_no_version
expected_remove_model_yml__model_yml_no_col_no_version = """
name: shared_model
Expand Down

0 comments on commit 474617d

Please sign in to comment.