diff --git a/mergekit/architecture.py b/mergekit/architecture.py
index 7e19ef31..5d1004f0 100644
--- a/mergekit/architecture.py
+++ b/mergekit/architecture.py
@@ -19,13 +19,17 @@ import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
+from pathlib import Path
 from typing import ClassVar, Dict, List, Optional, Tuple, Union
 
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import HfHubHTTPError
 from pydantic import BaseModel, Field
 from transformers import PretrainedConfig
 from typing_extensions import Literal
 
 import mergekit._data.architectures
+from mergekit.io.lazy_tensor_loader import ShardedTensorIndex
 
 
 class WeightInfo(BaseModel, frozen=True):
@@ -228,42 +232,83 @@ def _hierarchy(names, layer_prefix=r"\.\d+\.") -> Dict[str, List[str]]:
 class AutomaticArchitectureInfo(ArchitectureInfo, BaseModel):
     arch_name: str = Field(default="")
     parameter_names: List[str] = Field(default_factory=list)
+    embed: List[str] = Field(default_factory=list)
     layered_parameter_names: Dict[str, List[str]] = Field(default_factory=dict)
+    prefix_tracker: Dict[str, str] = Field(default_factory=dict)
 
-    def __init__(self, arch_name: str, parameter_names: List[str]):
+    def __init__(
+        self,
+        arch_name: str,
+        parameter_names: List[str],
+        prefix_tracker: Optional[Dict[str, str]] = None,
+    ):
         super().__init__()
-
         self.arch_name = arch_name
         self.parameter_names = parameter_names
         self.layered_parameter_names = _hierarchy(self.parameter_names)
+        self.prefix_tracker = prefix_tracker or {}
+        self.embed = self._find_embed_params()
+
+    def _find_embed_params(self) -> List[str]:
+        """Identify embedding parameters (e.g., 'lm_head', 'embed') that may require special handling."""
+        embed_params = []
+        for name in self.parameter_names:
+            if any(embedding_name in name for embedding_name in ["lm_head", "embed"]):
+                embed_params.append(name)
+        return embed_params
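+
+    # For instance (illustrative values, not from the patch): with
+    # parameter_names == ["model.embed_tokens.weight",
+    # "model.layers.0.mlp.up_proj.weight", "lm_head.weight"], this returns
+    # ["model.embed_tokens.weight", "lm_head.weight"], since the substring
+    # match on "embed" / "lm_head" flags both input and output embeddings.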
 
     def name(self) -> str:
+        """Returns the architecture name."""
         return self.arch_name
 
     def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        # AutomaticArchitectureInfo places all parameters into layer_weights, rather than pre/post weights
-        # Since many models do not have a clear distinction between pre/post weights
+        """This architecture does not distinguish pre-weights."""
         return []
 
     def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        # AutomaticArchitectureInfo places all parameters into layer_weights, rather than pre/post weights
-        # Since many models do not have a clear distinction between pre/post weights
+        """This architecture does not distinguish post-weights."""
         return []
 
     def layer_weights(
         self, index: int, config: PretrainedConfig
     ) -> Optional[List[WeightInfo]]:
+        """
+        Retrieves the weights for a specified layer, adjusting names for prefixes if applicable.
+        """
         layer_name = list(self.layered_parameter_names.keys())[index]
-        return [
-            WeightInfo(name=(layer_name + ("." + param if param else "")))
+        adjusted_layer_name = self._adjust_layer_name(layer_name, config)
+
+        weights = [
+            WeightInfo(
+                name=f"{adjusted_layer_name}.{param}" if param else adjusted_layer_name,
+                is_embed=(layer_name in self.embed),
+            )
             for param in self.layered_parameter_names[layer_name]
         ]
+        return (
+            weights
+            if weights
+            else [
+                WeightInfo(
+                    name=adjusted_layer_name, is_embed=(layer_name in self.embed)
+                )
+            ]
+        )
+
+    def _adjust_layer_name(self, layer_name: str, config: PretrainedConfig) -> str:
+        """Adjust a layer name by removing this model's prefix, as recorded in the prefix tracker."""
+        if config and config.name_or_path in self.prefix_tracker:
+            prefix = self.prefix_tracker.get(config.name_or_path, "")
+            if layer_name.startswith(prefix):
+                return layer_name[len(prefix) :]
+        return layer_name
 
     def sliceable(self) -> bool:
+        """Indicates if the architecture supports slicing."""
         return True
 
     def num_layers(self, config: PretrainedConfig) -> int:
-        # Note lack of pre/post weights distinction means 'model.layer.i' may not correspond to the ith layer
+        """Returns the number of layers based on the layered parameter names."""
         return len(self.layered_parameter_names)
@@ -450,3 +495,134 @@ def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:
         f"Unsupported model_type {config.model_type} for architecture {arch_name}"
     )
     return False
+
+
+def strip_prefix(name: str, prefixes: List[str]) -> str:
+    """Remove the first matching prefix in `prefixes` from the start of `name`."""
+    for prefix in prefixes:
+        if name.startswith(prefix + "."):
+            return name[len(prefix) + 1 :]
+    return name
+
+
+def is_ordered_sublist_with_prefix(
+    list1: List[str], list2: List[str], prefixes: List[str]
+) -> bool:
+    """
+    Check if list1 appears as a contiguous, ordered run within list2 after optional prefix removal.
+    """
+    stripped_list2 = [strip_prefix(name, prefixes) for name in list2]
+
+    try:
+        start_index = stripped_list2.index(list1[0])
+        for i, item in enumerate(list1):
+            if stripped_list2[start_index + i] != item:
+                return False
+        return True
+    except (ValueError, IndexError):
+        return False
+
+
+def find_prefix_and_check_sublist(list1: List[str], list2: List[str]) -> Optional[str]:
+    """
+    Attempt to find a prefix among the elements of list2 that makes list1 an ordered sublist of list2.
+    """
+    if len(list1) > len(list2):
+        list1, list2 = list2, list1
+
+    possible_prefixes = {item.split(".")[0] for item in list2 if "." in item}
+
+    for prefix in possible_prefixes:
+        if is_ordered_sublist_with_prefix(list1, list2, [prefix]):
+            return prefix
+
+    return None
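+
+# For instance (illustrative values, not from the patch): with
+#   list1 = ["embed_tokens.weight", "norm.weight"]
+#   list2 = ["model.embed_tokens.weight", "model.norm.weight", "lm_head.weight"]
+# find_prefix_and_check_sublist(list1, list2) returns "model": stripping
+# "model." from list2 leaves list1 as a contiguous ordered run within it,
+# while the other candidate prefix, "lm_head", fails the check.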
+
+
+def find_prefixes_for_alignment(param_names: List[List[str]]) -> List[str]:
+    """Determine the prefixes needed to align each parameter-name list with the longest one."""
+    prefixes = [""]
+    for i in range(1, len(param_names)):
+        if param_names[0] != param_names[i]:
+            prefix = find_prefix_and_check_sublist(param_names[0], param_names[i])
+            if not prefix:
+                raise ValueError("Could not resolve model architecture automatically.")
+        else:
+            prefix = ""
+        prefixes.append(prefix)
+    return prefixes
+
+
+def find_common_ordered_names(
+    param_names: List[List[str]], prefixes: List[str]
+) -> List[str]:
+    """Identify the parameter names common to all models, preserving the order of the first list."""
+    common_names = set(param_names[0])
+    for i in range(1, len(param_names)):
+        prefix = f"{prefixes[i]}." if prefixes[i] else ""
+        common_names.intersection_update({prefix + name for name in param_names[i]})
+    return [name for name in param_names[0] if name in common_names]
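+
+# For instance (illustrative values, not from the patch): given
+#   param_names = [["model.embed_tokens.weight", "model.norm.weight", "lm_head.weight"],
+#                  ["embed_tokens.weight", "norm.weight"]]
+# find_prefixes_for_alignment returns ["", "model"], and
+# find_common_ordered_names(param_names, ["", "model"]) yields
+# ["model.embed_tokens.weight", "model.norm.weight"]: the second list is
+# re-prefixed into the first list's namespace before intersecting.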
+ """ model_arch_info = [ get_architecture_info(m.config(trust_remote_code=options.trust_remote_code)) for m in merge_config.referenced_models() ] - # Check if any of the models failed to load architecture info - if any(a is False for a in model_arch_info): - # Attempt to load the architecture automatically - model_arch_info = [ - AutomaticArchitectureInfo( - arch_name=source_model.model.path, - parameter_names=_get_model_parameter_names(source_model.model.path), - ) - for source_model in merge_config.referenced_models() - ] - if not all( - a.all_weights(None) == model_arch_info[0].all_weights(None) - for a in model_arch_info[1:] - ): - raise RuntimeError( - "AutomaticArchitectureInfo only supports models with the same architecture" - ) - else: + if not any(a is False for a in model_arch_info): if not options.allow_crimes and not all( a == model_arch_info[0] for a in model_arch_info[1:] ): raise RuntimeError( "Must specify --allow-crimes to attempt to mix different architectures" ) + else: + model_arch_info = _infer_architecture_info(merge_config) return model_arch_info[0] -def _get_model_parameter_names(repo_id: str) -> list: - """ - Get the parameter names of a model from a Hugging Face repo or local directory. - - This function checks if the model is available locally or in the Hugging Face cache. - If the model is not available, it attempts to download it. If the download fails, - it raises an error. Once the model is resolved, it returns the list of tensor paths. - - :param repo_id: The model's repo ID, URL, or local directory path. - :return: A list of parameter names. - """ - # Try to resolve the model directory, either locally or by downloading - model_dir = _resolve_model_directory(repo_id) - - # Attempt to get the tensor paths from the resolved directory - return list(ShardedTensorIndex.from_disk(str(model_dir)).tensor_paths.keys()) - - -def _resolve_model_directory(repo_id: str) -> Path: - """ - Resolve the model directory either from a local path, URL, or by downloading from Hugging Face. - - :param repo_id: The model's repo ID, URL, or local directory path. - :return: The path to the resolved model directory. - """ - if Path(repo_id).is_dir(): - # If it's a local directory, return the path - return Path(repo_id) - - try: - # Use Hugging Face snapshot_download to check cache or download the model - return Path(snapshot_download(repo_id)) - except HfHubHTTPError: - raise ValueError(f"Model {repo_id} not found on Hugging Face Hub.") - except Exception as e: - raise ValueError(f"Error locating model {repo_id}: {e}") - - __all__ = ["MergeOptions", "run_merge"]