script to manually add weights from base model after merging submodule. And minor fixes.
ElliotStein committed Nov 19, 2024
1 parent d526eb9 commit f081a0b
Showing 3 changed files with 217 additions and 1 deletion.
11 changes: 11 additions & 0 deletions mergekit/architecture.py
@@ -499,6 +499,7 @@ def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:

def strip_prefix(name: str, prefixes: List[str]) -> str:
    """Remove any prefix in prefixes from the start of the name."""
    prefixes = [prefixes] if isinstance(prefixes, str) else prefixes
    for prefix in prefixes:
        if name.startswith(prefix + "."):
            return name[len(prefix) + 1 :]
@@ -609,6 +610,16 @@ def _infer_architecture_info(merge_config):
    prefixes = find_prefixes_for_alignment(param_names)
    common_names = find_common_ordered_names(param_names, prefixes)

    if not common_names:
        raise ValueError(
            "Could not resolve model architecture automatically. No common parameter names found."
        )

    if len(common_names) != len(param_names[0]):
        warnings.warn(
            f"Merging {len(common_names)} common parameters, out of {len(param_names[0])} total. Run fill_missing_params.py script after merge."
        )

    prefix_tracker = {
        model.model.path: f"{prefix}." if prefix else ""
        for model, prefix in zip(referenced_models, prefixes)
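For context, a minimal sketch of what the new single-string handling in strip_prefix permits; the parameter names below are illustrative only:

from mergekit.architecture import strip_prefix

# A bare string prefix is now wrapped into a list internally, so both forms work.
strip_prefix("language_model.model.embed_tokens.weight", "language_model")
# -> "model.embed_tokens.weight"
strip_prefix("model.embed_tokens.weight", ["language_model", "model"])
# -> "embed_tokens.weight"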
4 changes: 3 additions & 1 deletion mergekit/merge.py
@@ -276,7 +276,9 @@ def _update_config_vocab(
        )


def _load_arch_info(merge_config, options):
def _load_arch_info(
    merge_config: MergeConfiguration, options: MergeOptions
) -> ArchitectureInfo:
    """
    Loads architecture information, handling cases where models lack predefined architecture info.
    """
203 changes: 203 additions & 0 deletions mergekit/scripts/fill_missing_params.py
@@ -0,0 +1,203 @@
import logging
import shutil
from pathlib import Path

import click
import torch
from safetensors import safe_open
from tqdm import tqdm

from mergekit.architecture import (
_get_model_parameter_names,
_resolve_model_directory,
find_common_ordered_names,
find_prefix_and_check_sublist,
strip_prefix,
)
from mergekit.io.lazy_tensor_loader import ShardedTensorIndex
from mergekit.io.tensor_writer import TensorWriter

DEFAULT_SHARD_SIZE = 5 * 1024**3


def load_tensor_from_file(tensor_name: str, tensor_file: str = None) -> torch.Tensor:
"""
Load a specific tensor from a .safetensors file.
:param tensor_name: The name of the tensor to load.
:param tensor_file: The .safetensors file that contains the tensor.
:return: The loaded tensor as a PyTorch tensor.
"""
with safe_open(tensor_file, framework="pt", device="cpu") as f:
if tensor_name in f.keys():
return f.get_tensor(tensor_name)
else:
raise ValueError(
f"Tensor '{tensor_name}' not found in file '{tensor_file}'"
)


def load_tensor_from_index(tensor_name: str, index: ShardedTensorIndex) -> torch.Tensor:
"""
Load a specific tensor from a ShardedTensorIndex.
:param tensor_name: The name of the tensor to load.
:param index: The ShardedTensorIndex containing the tensor.
:return: The loaded tensor as a PyTorch tensor.
"""
return load_tensor_from_file(
tensor_name, Path(index.base_path) / index.tensor_paths[tensor_name]
)


def copy_and_fill_missing_params(
    base_model_repo_id: str,
    sub_model_dir: str,
    max_shard_size: int = DEFAULT_SHARD_SIZE,
    output_dir: str = None,
):
"""
Merge submodel weights into a base model and fill in missing parameters.
Use Case:
Given a submodel (e.g., a language model) that is structurally identical to a subset of a
larger base model (e.g., a vision-language model).
The submodel contains only a subset of the weights (e.g., for the language model part),
while the base model contains all weights required for the complete architecture.
This function replaces the shared parameters in the base model with those from the submodel,
fascilitating testing after generating submodel parameters through merging.
Parameters:
base_model_repo_id (str):
The path to the base model's directory or its Hugging Face repository ID.
This model provides all parameters and files required for the complete model.
sub_model_dir (str):
The path to the submodel's directory containing the merged weights.
Parameters in this directory replace the corresponding weights in the base model.
max_shard_size (int, optional):
The maximum shard size for saving model weights, in bytes. Defaults to 5 GiB.
output_dir (str, optional):
The directory to save the final merged model. If not provided, a default directory
is created using the names of the base and submodel.
Returns:
pathlib.Path:
The path to the directory where the final merged model is saved.
Raises:
AssertionError:
If the base model has fewer parameters than the submodel, ensuring compatibility.
ValueError:
If tensor loading or parameter alignment issues occur.
Notes:
- The function does not modify the original base or submodel directories.
- For Hugging Face repository IDs, ensure the `HF_HOME` environment variable is properly configured.
- Non-shared parameters, as well as any additional configuration files, are copied from the base model to create a fully functional model.
"""
    # Prepare paths and configurations
    output_dir = (
        Path(sub_model_dir).parent
        / f"{Path(base_model_repo_id).stem}--{Path(sub_model_dir).stem}"
        if output_dir is None
        else Path(output_dir)
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Resolve the model directory for the base model
    base_dir = _resolve_model_directory(base_model_repo_id)
    files_to_copy = [
        item
        for item in base_dir.rglob("*")
        if item.is_file() and item.suffix not in {".safetensors", ".bin"}
    ]

    # Copy non-parameter files from the base model
    with tqdm(
        total=len(files_to_copy), desc="Copying non-parameter files", unit="file"
    ) as pbar:
        for item in files_to_copy:
            target_path = output_dir / item.relative_to(base_dir)
            target_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(item, target_path)
            pbar.update(1)

    # Retrieve parameter names from both models
    base_param_names = _get_model_parameter_names(base_model_repo_id)
    submodel_param_names = _get_model_parameter_names(sub_model_dir)

    # Ensure the base model has more parameters than the submodel
    assert len(base_param_names) > len(submodel_param_names), (
        f"Base model must have more parameters than the submodel. "
        f"Base: {len(base_param_names)}, Submodel: {len(submodel_param_names)}"
    )

    # Determine parameter prefix and find common names
    prefix = find_prefix_and_check_sublist(base_param_names, submodel_param_names)
    common_param_names = find_common_ordered_names(
        [base_param_names, submodel_param_names], ["", prefix]
    )

    # Load parameter indices for tensor storage
    base_index = ShardedTensorIndex.from_disk(str(base_dir))
    submodel_index = ShardedTensorIndex.from_disk(
        str(_resolve_model_directory(sub_model_dir))
    )

    # Initialize the tensor writer
    writer = TensorWriter(
        out_path=str(output_dir), max_shard_size=max_shard_size, safe_serialization=True
    )

    # Copy base tensors, replacing shared parameters with the submodel's weights
    for name, tensor_path in tqdm(
        base_index.tensor_paths.items(),
        total=len(base_index.tensor_paths),
        desc="Merging tensors",
        unit="tensor",
    ):
        tensor = load_tensor_from_index(name, base_index)

        # Check if the parameter is common to both models
        if name in common_param_names:
            submodel_name = strip_prefix(name, prefix)
            submodel_tensor = load_tensor_from_index(submodel_name, submodel_index)

            # Log size mismatches
            if submodel_tensor.size() != tensor.size():
                logging.warning(
                    f"Size mismatch for tensor '{name}': {tensor.size()} vs {submodel_tensor.size()}"
                )

            tensor = submodel_tensor

        # Save the tensor to the output directory
        writer.save_tensor(name, tensor.clone())

    # Finalize the writer to ensure data is saved and index file is created
    writer.finalize()

    return output_dir


@click.command()
@click.argument("base_model_repo_id", type=str)
@click.argument("sub_model_dir", type=str)
@click.option("--max_shard_size", type=int, default=DEFAULT_SHARD_SIZE)
@click.option("--output_dir", type=str, default=None)
def main(
    base_model_repo_id,
    sub_model_dir,
    max_shard_size,
    output_dir,
):
    copy_and_fill_missing_params(
        base_model_repo_id, sub_model_dir, max_shard_size, output_dir
    )


if __name__ == "__main__":
    main()
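
For reference, a minimal sketch of how the new helper could be called after a merge that warned about missing parameters; the repo id and paths below are purely illustrative:

from mergekit.scripts.fill_missing_params import copy_and_fill_missing_params

# Illustrative inputs: a full base checkpoint (e.g., a vision-language model) and the
# directory holding the merged language-model weights produced by mergekit.
out_dir = copy_and_fill_missing_params(
    base_model_repo_id="org/base-vision-language-model",
    sub_model_dir="./merged-language-model",
)
print(f"Complete model written to {out_dir}")

The same operation is exposed on the command line through the click entry point, with base_model_repo_id and sub_model_dir as positional arguments and --max_shard_size / --output_dir as options.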
