From 31d71c071cc5162dc474a3e2a2373a300cf70a28 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 3 Jun 2024 10:19:16 -0400 Subject: [PATCH] Move parse_model_arg(...) to a util file. --- .../model_merge/scripts/merge_models.py | 12 +----------- .../model_merge/utils/parse_model_arg.py | 9 +++++++++ 2 files changed, 10 insertions(+), 11 deletions(-) create mode 100644 src/invoke_training/model_merge/utils/parse_model_arg.py diff --git a/src/invoke_training/model_merge/scripts/merge_models.py b/src/invoke_training/model_merge/scripts/merge_models.py index 7d48b8d8..ca46cfc6 100644 --- a/src/invoke_training/model_merge/scripts/merge_models.py +++ b/src/invoke_training/model_merge/scripts/merge_models.py @@ -9,6 +9,7 @@ from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline from invoke_training.model_merge.merge_models import merge_models +from invoke_training.model_merge.utils import parse_model_arg @dataclass @@ -69,17 +70,6 @@ def run_merge_models( logger.info(f"Saved merged model to '{out_dir_path}'.") -def parse_model_arg(model: str) -> tuple[str, str | None]: - """Parse a --models argument into a model and a variant.""" - parts = model.split("::") - if len(parts) == 1: - return parts[0], None - elif len(parts) == 2: - return parts[0], parts[1] - else: - raise ValueError(f"Unexpected format for --models arg: '{model}'.") - - def parse_model_args(models: list[str], weights: list[str]) -> list[MergeModel]: """Parse a list of --models arguments and --weights arguments into a list of MergeModels.""" merge_model_list: list[MergeModel] = [] diff --git a/src/invoke_training/model_merge/utils/parse_model_arg.py b/src/invoke_training/model_merge/utils/parse_model_arg.py new file mode 100644 index 00000000..c061a6a5 --- /dev/null +++ b/src/invoke_training/model_merge/utils/parse_model_arg.py @@ -0,0 +1,9 @@ +def parse_model_arg(model: str, delimiter: str = "::") -> tuple[str, str | None]: + """Parse a model argument into a model and a variant.""" + parts = model.split(delimiter) + if len(parts) == 1: + return parts[0], None + elif len(parts) == 2: + return parts[0], parts[1] + else: + raise ValueError(f"Unexpected format for --models arg: '{model}'.")