Enable autodetection and merging for submodules. If parameter names match, or match when a prefix is removed (e.g. vision_block.layer.0 and layer.0), their overlapping layers can now be merged.

ElliotStein committed Nov 4, 2024
1 parent dcf8c31 commit 5b96b97
Showing 2 changed files with 193 additions and 66 deletions.
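
As a standalone sketch of the matching rule described in the commit message (the tensor names below are hypothetical and this snippet is not part of the diff): two models' parameter names overlap once a submodule prefix is stripped, so the shared layers can be merged.

# Hypothetical tensor names from two models being merged.
vision_names = ["vision_block.layer.0.weight", "vision_block.layer.1.weight"]
text_names = ["layer.0.weight", "layer.1.weight"]

prefix = "vision_block."
stripped = [n[len(prefix):] if n.startswith(prefix) else n for n in vision_names]
assert stripped == text_names  # the overlapping layers line up and can be merged
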
194 changes: 185 additions & 9 deletions mergekit/architecture.py
@@ -19,13 +19,17 @@
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import ClassVar, Dict, List, Optional, Tuple, Union

from huggingface_hub import snapshot_download
from huggingface_hub.utils import HfHubHTTPError
from pydantic import BaseModel, Field
from transformers import PretrainedConfig
from typing_extensions import Literal

import mergekit._data.architectures
from mergekit.io.lazy_tensor_loader import ShardedTensorIndex


class WeightInfo(BaseModel, frozen=True):
@@ -228,42 +232,83 @@ def _hierarchy(names, layer_prefix=r"\.\d+\.") -> Dict[str, List[str]]:
class AutomaticArchitectureInfo(ArchitectureInfo, BaseModel):
arch_name: str = Field(default="")
parameter_names: List[str] = Field(default_factory=list)
embed: List[str] = Field(default_factory=list)
layered_parameter_names: Dict[str, List[str]] = Field(default_factory=dict)
prefix_tracker: Dict[str, str] = Field(default_factory=dict)

def __init__(self, arch_name: str, parameter_names: List[str]):
def __init__(
self,
arch_name: str,
parameter_names: List[str],
prefix_tracker: Optional[Dict[str, str]] = None,
):
super().__init__()

self.arch_name = arch_name
self.parameter_names = parameter_names
self.layered_parameter_names = _hierarchy(self.parameter_names)
self.prefix_tracker = prefix_tracker or {}
self.embed = self._find_embed_params()

def _find_embed_params(self) -> List[str]:
"""Identify embedding parameters (e.g., 'lm_head', 'embed') that may require special handling."""
embed_params = []
for name in self.parameter_names:
if any(embedding_name in name for embedding_name in ["lm_head", "embed"]):
embed_params.append(name)
return embed_params

def name(self) -> str:
"""Returns the architecture name."""
return self.arch_name

def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
# AutomaticArchitectureInfo places all parameters into layer_weights, rather than pre/post weights
# Since many models do not have a clear distinction between pre/post weights
"""This architecture does not distinguish pre-weights."""
return []

def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
# AutomaticArchitectureInfo places all parameters into layer_weights, rather than pre/post weights
# Since many models do not have a clear distinction between pre/post weights
"""This architecture does not distinguish post-weights."""
return []

def layer_weights(
self, index: int, config: PretrainedConfig
) -> Optional[List[WeightInfo]]:
"""
Retrieves the weights for a specified layer, adjusting names for prefixes if applicable.
"""
layer_name = list(self.layered_parameter_names.keys())[index]
return [
WeightInfo(name=(layer_name + ("." + param if param else "")))
adjusted_layer_name = self._adjust_layer_name(layer_name, config)

weights = [
WeightInfo(
name=f"{adjusted_layer_name}.{param}" if param else adjusted_layer_name,
is_embed=(layer_name in self.embed),
)
for param in self.layered_parameter_names[layer_name]
]
return (
weights
if weights
else [
WeightInfo(
name=adjusted_layer_name, is_embed=(layer_name in self.embed)
)
]
)

def _adjust_layer_name(self, layer_name: str, config: PretrainedConfig) -> str:
"""Adjust layer names by removing any prefix as indicated in the prefix tracker."""
if config and config.name_or_path in self.prefix_tracker:
prefix = self.prefix_tracker.get(config.name_or_path, "")
if layer_name.startswith(prefix):
return layer_name[len(prefix) :]
return layer_name

def sliceable(self) -> bool:
"""Indicates if the architecture supports slicing."""
return True

def num_layers(self, config: PretrainedConfig) -> int:
# Note lack of pre/post weights distinction means 'model.layer.i' may not correspond to the ith layer
"""Returns the number of layers based on layered parameter names."""
return len(self.layered_parameter_names)


@@ -450,3 +495,134 @@ def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:
f"Unsupported model_type {config.model_type} for architecture {arch_name}"
)
return False


def strip_prefix(name: str, prefixes: List[str]) -> str:
"""Remove any prefix in prefixes from the start of the name."""
for prefix in prefixes:
if name.startswith(prefix + "."):
return name[len(prefix) + 1 :]
return name
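# Illustration only (not part of this commit), with hypothetical names:
#   strip_prefix("vision_block.layer.0.weight", ["vision_block"])
# returns "layer.0.weight"; a name that does not start with the prefix is
# returned unchanged.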


def is_ordered_sublist_with_prefix(
list1: List[str], list2: List[str], prefixes: List[str]
) -> bool:
"""
Check if list1 matches a subset of list2 in the correct order after optional prefix removal.
"""
stripped_list2 = [strip_prefix(name, prefixes) for name in list2]

try:
start_index = stripped_list2.index(list1[0])
for i, item in enumerate(list1):
if stripped_list2[start_index + i] != item:
return False
return True
except (ValueError, IndexError):
return False


def find_prefix_and_check_sublist(list1: List[str], list2: List[str]) -> Optional[str]:
"""
Attempts to find a prefix from elements in list2 that makes list1 an ordered sublist of list2.
"""
if len(list1) > len(list2):
list1, list2 = list2, list1

possible_prefixes = {item.split(".")[0] for item in list2 if "." in item}

for prefix in possible_prefixes:
if is_ordered_sublist_with_prefix(list1, list2, [prefix]):
return prefix

return None


def find_prefixes_for_alignment(param_names: List[List[str]]) -> List[str]:
"""Determine prefixes needed to align parameter names in order of the longest list."""
prefixes = [""]
for i in range(1, len(param_names)):
if param_names[0] != param_names[i]:
prefix = find_prefix_and_check_sublist(param_names[0], param_names[i])
if not prefix:
raise ValueError("Could not resolve model architecture automatically.")
else:
prefix = ""
prefixes.append(prefix)
return prefixes


def find_common_ordered_names(
param_names: List[List[str]], prefixes: List[str]
) -> List[str]:
"""Identify and return common parameter names across all models, ensuring correct order."""
common_names = set(param_names[0])
for i in range(1, len(param_names)):
prefix = f"{prefixes[i]}." if prefixes[i] else ""
common_names.intersection_update({prefix + name for name in param_names[i]})
return [name for name in param_names[0] if name in common_names]


def _get_model_parameter_names(repo_id: str) -> list:
"""
Get the parameter names of a model from a Hugging Face repo or local directory.
"""
model_dir = _resolve_model_directory(repo_id)
return list(ShardedTensorIndex.from_disk(str(model_dir)).tensor_paths.keys())



def _resolve_model_directory(repo_id: str) -> Path:
"""
Resolve the model directory either from a local path, URL, or by downloading from Hugging Face.
"""
if Path(repo_id).is_dir():
return Path(repo_id)

try:
return Path(snapshot_download(repo_id))
except HfHubHTTPError:
raise ValueError(f"Model {repo_id} not found on Hugging Face Hub.")
except Exception as e:
raise ValueError(f"Error locating model {repo_id}: {e}")


def _infer_architecture_info(merge_config):
"""
Infers and returns architecture info, including parameter names and prefixes for alignment.
"""
param_names = [
_get_model_parameter_names(source_model.model.path)
for source_model in merge_config.referenced_models()
]

if all(param_names[0] == param_names[i] for i in range(1, len(param_names))):
arch_name = merge_config.referenced_models()[0].model.path
parameter_names = param_names[0]
prefix_tracker = {}
else:

# Pair param_names with referenced models and sort by length
paired_list = list(zip(param_names, merge_config.referenced_models()))
paired_list.sort(key=lambda x: len(x[0]), reverse=True)
param_names, referenced_models = zip(*paired_list)

prefixes = find_prefixes_for_alignment(param_names)
common_names = find_common_ordered_names(param_names, prefixes)

prefix_tracker = {
model.model.path: f"{prefix}." if prefix else ""
for model, prefix in zip(referenced_models, prefixes)
}

arch_name = referenced_models[0].model.path
parameter_names = common_names

return [
AutomaticArchitectureInfo(
arch_name=arch_name,
parameter_names=parameter_names,
prefix_tracker=prefix_tracker,
)
]
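
A minimal sketch of what the new helper functions do together, assuming mergekit with this commit is importable; the parameter-name lists below are invented:

from mergekit.architecture import (
    find_common_ordered_names,
    find_prefixes_for_alignment,
)

# Longest name list first, mirroring the sort in _infer_architecture_info.
param_names = [
    [
        "vision_block.layer.0.weight",
        "vision_block.layer.1.weight",
        "projector.weight",
    ],
    ["layer.0.weight", "layer.1.weight"],
]

prefixes = find_prefixes_for_alignment(param_names)
# -> ["", "vision_block"]: the second model's names match once the
#    "vision_block" prefix may be stripped from the first model's names.

common = find_common_ordered_names(param_names, prefixes)
# -> ["vision_block.layer.0.weight", "vision_block.layer.1.weight"],
#    i.e. the overlapping layers, expressed in the first model's namespace.
print(prefixes, common)
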
65 changes: 8 additions & 57 deletions mergekit/merge.py
@@ -20,7 +20,7 @@
import shutil
from collections import Counter
from pathlib import Path
from typing import Optional
from typing import List, Optional

import tqdm
import transformers
@@ -31,6 +31,7 @@
from mergekit.architecture import (
ArchitectureInfo,
AutomaticArchitectureInfo,
_infer_architecture_info,
get_architecture_info,
)
from mergekit.card import generate_card
@@ -276,75 +277,25 @@ def _update_config_vocab(


def _load_arch_info(merge_config, options):
"""
Loads architecture information, handling cases where models lack predefined architecture info.
"""
model_arch_info = [
get_architecture_info(m.config(trust_remote_code=options.trust_remote_code))
for m in merge_config.referenced_models()
]

# Check if any of the models failed to load architecture info
if any(a is False for a in model_arch_info):
# Attempt to load the architecture automatically
model_arch_info = [
AutomaticArchitectureInfo(
arch_name=source_model.model.path,
parameter_names=_get_model_parameter_names(source_model.model.path),
)
for source_model in merge_config.referenced_models()
]
if not all(
a.all_weights(None) == model_arch_info[0].all_weights(None)
for a in model_arch_info[1:]
):
raise RuntimeError(
"AutomaticArchitectureInfo only supports models with the same architecture"
)
else:
if not any(a is False for a in model_arch_info):
if not options.allow_crimes and not all(
a == model_arch_info[0] for a in model_arch_info[1:]
):
raise RuntimeError(
"Must specify --allow-crimes to attempt to mix different architectures"
)
else:
model_arch_info = _infer_architecture_info(merge_config)

return model_arch_info[0]


def _get_model_parameter_names(repo_id: str) -> list:
"""
Get the parameter names of a model from a Hugging Face repo or local directory.
This function checks if the model is available locally or in the Hugging Face cache.
If the model is not available, it attempts to download it. If the download fails,
it raises an error. Once the model is resolved, it returns the list of tensor paths.
:param repo_id: The model's repo ID, URL, or local directory path.
:return: A list of parameter names.
"""
# Try to resolve the model directory, either locally or by downloading
model_dir = _resolve_model_directory(repo_id)

# Attempt to get the tensor paths from the resolved directory
return list(ShardedTensorIndex.from_disk(str(model_dir)).tensor_paths.keys())


def _resolve_model_directory(repo_id: str) -> Path:
"""
Resolve the model directory either from a local path, URL, or by downloading from Hugging Face.
:param repo_id: The model's repo ID, URL, or local directory path.
:return: The path to the resolved model directory.
"""
if Path(repo_id).is_dir():
# If it's a local directory, return the path
return Path(repo_id)

try:
# Use Hugging Face snapshot_download to check cache or download the model
return Path(snapshot_download(repo_id))
except HfHubHTTPError:
raise ValueError(f"Model {repo_id} not found on Hugging Face Hub.")
except Exception as e:
raise ValueError(f"Error locating model {repo_id}: {e}")


__all__ = ["MergeOptions", "run_merge"]

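And a hypothetical sketch of how the prefix_tracker produced by _infer_architecture_info is consumed when weights are enumerated (model identifiers are invented; assumes mergekit with this commit and transformers are installed):

from transformers import PretrainedConfig
from mergekit.architecture import AutomaticArchitectureInfo

arch = AutomaticArchitectureInfo(
    arch_name="vision-model",
    parameter_names=["vision_block.layer.0.weight", "vision_block.layer.1.weight"],
    prefix_tracker={"text-model": "vision_block."},
)

# For a model whose tensors lack the submodule prefix, _adjust_layer_name strips
# the tracked prefix before the weight name is used.
cfg = PretrainedConfig(name_or_path="text-model")
print([w.name for w in arch.layer_weights(0, cfg)])
# expected to print names without the "vision_block." prefix, e.g. ['layer.0.weight']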