Merge pull request #138 from invoke-ai/model-merge
Model merging utility scripts
Showing 29 changed files with 1,247 additions and 368 deletions.
# Model Merging

`invoke-training` provides utility scripts for several common model merging workflows. This page contains a summary of the available tools.

## `extract_lora_from_model_diff.py`

Extract a LoRA model that represents the difference between two base models.

Note that the extracted LoRA model is a lossy representation of the difference between the models, so some degradation in quality is expected.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py -h
```
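
To build intuition for why the extraction is lossy, here is a minimal sketch (not the script itself; the layer shapes are hypothetical): the weight difference between the tuned and base models is approximated by a truncated SVD, and everything beyond the chosen rank is discarded.

```python
import torch

# Hypothetical weights standing in for one layer of a base and a fine-tuned model.
torch.manual_seed(0)
w_base = torch.randn(64, 64)
w_tuned = w_base + 0.01 * torch.randn(64, 64)

diff = w_tuned - w_base
rank = 8

# Truncated SVD: keep only the top `rank` singular values/vectors.
u, s, v_h = torch.linalg.svd(diff)
lora_up = u[:, :rank] @ torch.diag(s[:rank])  # "lora_B"
lora_down = v_h[:rank, :]                     # "lora_A"

# The rank-8 product only approximates the true difference.
approx = lora_up @ lora_down
rel_err = torch.linalg.norm(diff - approx) / torch.linalg.norm(diff)
print(f"relative reconstruction error: {rel_err:.3f}")
```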

## `merge_lora_into_model.py`

Merge a LoRA model into a base model to produce a new base model.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_lora_into_model.py -h
```
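
For intuition, merging a LoRA into a base model amounts to folding the low-rank update back into each dense weight. A minimal sketch, with hypothetical shapes and a hypothetical `scale` value:

```python
import torch

# Hypothetical 64x64 linear layer with a rank-8 LoRA attached.
w_base = torch.randn(64, 64)
lora_down = torch.randn(8, 64)   # "lora_A"
lora_up = torch.randn(64, 8)     # "lora_B"
scale = 1.0                      # commonly alpha / rank

# The merged base weight absorbs the LoRA update.
w_merged = w_base + scale * (lora_up @ lora_down)
```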

## `merge_models.py`

Merge 2 or more base models to produce a single base model (using either LERP or SLERP). This is a simple merge strategy that merges all model weights in the same way.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_models.py -h
```
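
As a rough sketch of what a LERP merge computes (the script also supports SLERP and more than two models; the state dicts here are toy placeholders):

```python
import torch

# Two hypothetical state dicts with matching keys.
sd_a = {"layer.weight": torch.randn(4, 4)}
sd_b = {"layer.weight": torch.randn(4, 4)}
weights = [0.75, 0.25]  # the utility normalizes weights to sum to 1

# LERP: a per-tensor weighted sum applied uniformly to every key.
merged = {key: weights[0] * sd_a[key] + weights[1] * sd_b[key] for key in sd_a}
```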

## `merge_task_models_to_base_model.py`

Merge 1 or more task-specific base models into a single starting base model (using either [TIES](https://arxiv.org/abs/2306.01708) or [DARE](https://arxiv.org/abs/2311.03099)). This merge strategy aims to preserve the task-specific behaviors of the task models while making only small changes to the original base model. This approach enables multiple task models to be merged without excessive interference between them.

If you want to merge a task-specific LoRA into a base model using this strategy, first use `merge_lora_into_model.py` to produce a task-specific base model, then merge that new base model with this script.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py -h
```
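
As a rough illustration of the DARE idea (not the script's actual implementation): each task model's delta from the base is randomly sparsified and rescaled before being added back, which keeps the net change to the base model small.

```python
import torch

# Hypothetical single-tensor example of DARE-style drop-and-rescale.
torch.manual_seed(0)
base = torch.randn(64, 64)
task = base + 0.05 * torch.randn(64, 64)  # stand-in for a task-specific model

drop_prob = 0.9
delta = task - base

# Randomly keep ~10% of the delta elements and rescale to preserve the expected update.
keep_mask = torch.bernoulli(torch.full_like(delta, 1.0 - drop_prob))
merged = base + (delta * keep_mask) / (1.0 - drop_prob)
```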
Among the new files, these LoRA-extraction helpers:

```python
import torch
import tqdm
from peft.peft_model import PeftModel

# All original base model weights in a PeftModel have this prefix and suffix.
PEFT_BASE_LAYER_PREFIX = "base_model.model."
PEFT_BASE_LAYER_SUFFIX = ".base_layer.weight"


def get_patched_base_weights_from_peft_model(peft_model: PeftModel) -> dict[str, torch.Tensor]:
    """Get a state_dict containing the base model weights *that are patched* in the provided PeftModel. I.e. only
    return base model weights that have associated LoRA layers, but don't return the LoRA layers themselves.
    """
    state_dict = peft_model.state_dict()
    out_state_dict: dict[str, torch.Tensor] = {}
    for weight_name in state_dict:
        # Weights that end with ".base_layer.weight" are the original weights for LoRA layers.
        if weight_name.endswith(PEFT_BASE_LAYER_SUFFIX):
            # Extract the base module name.
            module_name = weight_name[: -len(PEFT_BASE_LAYER_SUFFIX)]
            assert module_name.startswith(PEFT_BASE_LAYER_PREFIX)
            module_name = module_name[len(PEFT_BASE_LAYER_PREFIX) :]

            out_state_dict[module_name] = state_dict[weight_name]

    return out_state_dict


def get_state_dict_diff(
    state_dict_1: dict[str, torch.Tensor], state_dict_2: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Return the difference between two state_dicts: state_dict_1 - state_dict_2."""
    return {key: state_dict_1[key] - state_dict_2[key] for key in state_dict_1}


@torch.no_grad()
def extract_lora_from_diffs(
    diffs: dict[str, torch.Tensor], rank: int, clamp_quantile: float, out_dtype: torch.dtype
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    """Extract a low-rank (lora_up, lora_down) pair for each weight diff via truncated SVD."""
    lora_weights = {}
    for lora_name, mat in tqdm.tqdm(list(diffs.items())):
        # Use full precision for the intermediate calculations.
        mat = mat.to(torch.float32)

        is_conv2d = False
        if len(mat.shape) == 4:  # Conv2D
            is_conv2d = True
            out_dim, in_dim, kernel_h, kernel_w = mat.shape
            # Reshape to (out_dim, in_dim * kernel_h * kernel_w).
            mat = mat.flatten(start_dim=1)
        elif len(mat.shape) == 2:  # Linear
            out_dim, in_dim = mat.shape
        else:
            raise ValueError(f"Unexpected weight shape: {mat.shape}")

        # LoRA rank cannot exceed the original dimensions.
        assert rank < in_dim
        assert rank < out_dim

        u: torch.Tensor
        s: torch.Tensor
        v_h: torch.Tensor
        u, s, v_h = torch.linalg.svd(mat)

        # Apply the Eckart-Young-Mirsky theorem.
        # https://en.wikipedia.org/wiki/Low-rank_approximation#Proof_of_Eckart%E2%80%93Young%E2%80%93Mirsky_theorem_(for_Frobenius_norm)
        u = u[:, :rank]
        s = s[:rank]
        u = u @ torch.diag(s)

        v_h = v_h[:rank, :]

        # At this point, u is the lora_up (a.k.a. lora_B) weight, and v_h is the lora_down (a.k.a. lora_A) weight.
        # The reason we don't use more appropriate variable names is to keep memory usage low - we want the old tensors
        # to get cleaned up after each operation.

        # Clamp the outliers.
        dist = torch.cat([u.flatten(), v_h.flatten()])
        hi_val = torch.quantile(dist, clamp_quantile)
        low_val = -hi_val

        u = u.clamp(low_val, hi_val)
        v_h = v_h.clamp(low_val, hi_val)

        if is_conv2d:
            u = u.reshape(out_dim, rank, 1, 1)
            v_h = v_h.reshape(rank, in_dim, kernel_h, kernel_w)

        u = u.to(dtype=out_dtype).contiguous()
        v_h = v_h.to(dtype=out_dtype).contiguous()

        lora_weights[lora_name] = (u, v_h)
    return lora_weights
```
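
A hypothetical end-to-end use of the helpers above, with toy state dicts standing in for real model weights:

```python
import torch

# Stand-in state dicts for a fine-tuned model and its original base model.
state_dict_tuned = {"proj.weight": torch.randn(64, 64)}
state_dict_base = {"proj.weight": torch.randn(64, 64)}

diffs = get_state_dict_diff(state_dict_tuned, state_dict_base)
lora_weights = extract_lora_from_diffs(diffs, rank=8, clamp_quantile=0.99, out_dtype=torch.float16)

lora_up, lora_down = lora_weights["proj.weight"]
print(lora_up.shape, lora_down.shape)  # torch.Size([64, 8]) torch.Size([8, 64])
```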
And the LERP/SLERP merge implementation:

```python
from typing import Literal

import torch
import tqdm

from invoke_training.model_merge.utils.normalize_weights import normalize_weights


@torch.no_grad()
def merge_models(
    state_dicts: list[dict[str, torch.Tensor]], weights: list[float], merge_method: Literal["LERP", "SLERP"] = "LERP"
):
    """Merge multiple models into a single model.

    Args:
        state_dicts (list[dict[str, torch.Tensor]]): The state dicts to merge.
        weights (list[float]): The weights for each state dict. The weights will be normalized to sum to 1.
        merge_method (Literal["LERP", "SLERP"]): Merge method to use. Options:
            - "LERP": Linear interpolation a.k.a. weighted sum.
            - "SLERP": Spherical linear interpolation.
    """
    if len(state_dicts) < 2:
        raise ValueError("Must provide >=2 models to merge.")

    if len(state_dicts) != len(weights):
        raise ValueError("Must provide a weight for each model.")

    if merge_method == "LERP":
        merge_fn = lerp
    elif merge_method == "SLERP":
        merge_fn = slerp
    else:
        raise ValueError(f"Unknown merge method: {merge_method}")

    normalized_weights = normalize_weights(weights)

    out_state_dict: dict[str, torch.Tensor] = state_dicts[0].copy()
    out_state_dict_weight = normalized_weights[0]
    for state_dict, normalized_weight in zip(state_dicts[1:], normalized_weights[1:], strict=True):
        if state_dict.keys() != out_state_dict.keys():
            raise ValueError("State dicts must have the same keys.")

        cur_pair_weights = normalize_weights([out_state_dict_weight, normalized_weight])
        for key in tqdm.tqdm(out_state_dict.keys()):
            out_state_dict[key] = merge_fn(out_state_dict[key], state_dict[key], cur_pair_weights[0])

        # Update the weight of out_state_dict to be the sum of all state dicts merged so far.
        out_state_dict_weight += normalized_weight

    return out_state_dict


def lerp(a: torch.Tensor, b: torch.Tensor, weight_a: float) -> torch.Tensor:
    """Linear interpolation."""
    return torch.lerp(a, b, (1.0 - weight_a))


def slerp(a: torch.Tensor, b: torch.Tensor, weight_a: float, dot_product_thres=0.9995, epsilon=1e-10):
    """Spherical linear interpolation."""
    # TODO(ryand): For multi-dimensional matrices, it might be better to apply slerp on a subset of the dimensions
    # (e.g. per-row), rather than treating the entire matrix as a single flattened vector.

    # Normalize the vectors.
    a_norm = torch.linalg.norm(a)
    b_norm = torch.linalg.norm(b)
    a_normalized = a / a_norm
    b_normalized = b / b_norm

    if a_norm < epsilon or b_norm < epsilon:
        # If either vector is very small, fall back to lerp to avoid weird effects.
        # TODO(ryand): Is fallback here necessary?
        return lerp(a, b, weight_a)

    # Dot product of the normalized vectors.
    # We are effectively treating multi-dimensional tensors as flattened vectors.
    dot_prod = torch.sum(a_normalized * b_normalized)

    # If the absolute value of the dot product is almost 1, the vectors are ~colinear, so use lerp.
    if torch.abs(dot_prod) > dot_product_thres:
        return lerp(a, b, weight_a)

    # Calculate the initial angle between the vectors.
    theta_0 = torch.acos(dot_prod)

    # Angle at timestep t.
    t = 1.0 - weight_a
    theta_t = theta_0 * t

    sin_theta_0 = torch.sin(theta_0)
    sin_theta_t = torch.sin(theta_t)

    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    result = s0 * a + s1 * b

    return result
```
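
A hypothetical call to `merge_models` with two toy state dicts (a LERP merge with a 75/25 weighting):

```python
import torch

sd_a = {"layer.weight": torch.ones(2, 2)}
sd_b = {"layer.weight": torch.zeros(2, 2)}

# Per the docstring, weights are normalized to sum to 1, so [3.0, 1.0] acts like [0.75, 0.25].
merged = merge_models([sd_a, sd_b], weights=[3.0, 1.0], merge_method="LERP")
print(merged["layer.weight"])  # expected: every element equal to 0.75
```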