From 645eb3c6adcfa618bb5b975749944cbe52e6c1f7 Mon Sep 17 00:00:00 2001
From: CrispStrobe <154636388+CrispStrobe@users.noreply.github.com>
Date: Wed, 16 Oct 2024 23:18:56 +0200
Subject: [PATCH] easier handling of consolidated.safetensors (e.g. for Ministral)

---
 convert_hf_to_gguf.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index da5feb25b1961..d58c1b62871d7 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -83,10 +83,10 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
         self.lazy = not eager
-        self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
+        self.part_names = Model.get_model_part_names(self.dir_model, ["model"], [".safetensors"])
         self.is_safetensors = len(self.part_names) > 0
         if not self.is_safetensors:
-            self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+            self.part_names = Model.get_model_part_names(self.dir_model, ["pytorch_model"], [".bin"])
         self.hparams = Model.load_hparams(self.dir_model)
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
@@ -447,10 +447,12 @@ def write_vocab(self):
         self.gguf_writer.close()
 
     @staticmethod
-    def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
+    def get_model_part_names(dir_model: Path, prefixes: list[str], suffixes: list[str]) -> list[str]:
         part_names: list[str] = []
         for filename in os.listdir(dir_model):
-            if filename.startswith(prefix) and filename.endswith(suffix):
+            if any(filename.startswith(prefix) for prefix in prefixes) and any(filename.endswith(suffix) for suffix in suffixes):
+                part_names.append(filename)
+            elif filename == "consolidated.safetensors":
                 part_names.append(filename)
 
         part_names.sort()
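
Note (not part of the patch): a minimal standalone sketch of the file-selection behaviour that Model.get_model_part_names has after this change. The free-standing function and the directory name in the usage comment are illustrative only; the real logic is a static method of the Model class in convert_hf_to_gguf.py.

    import os
    from pathlib import Path

    def get_model_part_names(dir_model: Path, prefixes: list[str], suffixes: list[str]) -> list[str]:
        part_names: list[str] = []
        for filename in os.listdir(dir_model):
            # Regular sharded checkpoints, e.g. model-00001-of-00002.safetensors
            # or pytorch_model-00001-of-00002.bin, matched by prefix and suffix.
            if any(filename.startswith(prefix) for prefix in prefixes) and any(filename.endswith(suffix) for suffix in suffixes):
                part_names.append(filename)
            # Mistral-style single-file checkpoints (e.g. Ministral) ship only
            # consolidated.safetensors, which the old prefix check missed.
            elif filename == "consolidated.safetensors":
                part_names.append(filename)
        part_names.sort()
        return part_names

    # Example (hypothetical directory): a download containing only
    # consolidated.safetensors previously yielded [] here and fell back to
    # looking for pytorch_model*.bin; with this patch it is picked up directly:
    #   get_model_part_names(Path("Ministral-8B"), ["model"], [".safetensors"])
    #   -> ["consolidated.safetensors"]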