Merge pull request #138 from invoke-ai/model-merge
Model merging utility scripts
RyanJDick authored Jun 4, 2024
2 parents 38242a8 + 81da84e commit 0b44077
Showing 29 changed files with 1,247 additions and 368 deletions.
1 change: 0 additions & 1 deletion docs/concepts/index.md

This file was deleted.

2 changes: 1 addition & 1 deletion docs/get-started/quick-start.md
@@ -71,4 +71,4 @@ Follow the [`invoke-training` installation instructions](./installation.md).

 ### 2. Training

-See the [Textual Inversion - SDXL](../tutorials/stable_diffusion/textual_inversion_sdxl.md) tutorial for instructions on how to train a model via the CLI.
+See the [Textual Inversion - SDXL](../guides/stable_diffusion/textual_inversion_sdxl.md) tutorial for instructions on how to train a model via the CLI.
File renamed without changes.
43 changes: 43 additions & 0 deletions docs/guides/model_merge.md
@@ -0,0 +1,43 @@
# Model Merging

`invoke-training` provides utility scripts for several common model merging workflows. This page contains a summary of the available tools.

## `extract_lora_from_model_diff.py`

Extract a LoRA model that represents the difference between two base models.

Note that the extracted LoRA model is a lossy representation of the difference between the models, so some degradation in quality is expected.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py -h
```
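
To see why the extraction is lossy, consider a toy 2x2 weight difference. The sketch below is illustrative only and is not the script's code: the matrix and its rank-1 factors are hand-picked (for a diagonal matrix the singular values can be read off the diagonal), and `matmul` is a hypothetical helper.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Toy weight difference (tuned - original). Its singular values are 3 and 1.
diff = [[3.0, 0.0], [0.0, 1.0]]

# Rank-1 LoRA factors that keep only the largest singular value.
lora_up = [[3.0], [0.0]]  # shape (out_dim, rank)
lora_down = [[1.0, 0.0]]  # shape (rank, in_dim)

approx = matmul(lora_up, lora_down)
residual = math.sqrt(
    sum((d - a) ** 2 for d_row, a_row in zip(diff, approx) for d, a in zip(d_row, a_row))
)
print(approx)    # [[3.0, 0.0], [0.0, 0.0]]
print(residual)  # 1.0
```

The residual is the part of the difference that a rank-1 LoRA cannot represent. A higher `--lora-rank` shrinks it at the cost of a larger LoRA file.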

## `merge_lora_into_model.py`

Merge a LoRA model into a base model to produce a new base model.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_lora_into_model.py -h
```
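
Conceptually, merging a LoRA adds the product of its two low-rank matrices to each patched weight. The following is a minimal sketch of that arithmetic on plain Python lists, not the script's implementation. The helper name `merge_lora_into_weight` and the `scale` parameter are assumptions for illustration; real LoRA formats may apply an additional alpha/rank scaling.

```python
def merge_lora_into_weight(base, lora_up, lora_down, scale=1.0):
    """Return base + scale * (lora_up @ lora_down) for matrices stored as lists of rows."""
    delta = [
        [sum(u * d for u, d in zip(up_row, down_col)) for down_col in zip(*lora_down)]
        for up_row in lora_up
    ]
    return [[b + scale * dl for b, dl in zip(b_row, d_row)] for b_row, d_row in zip(base, delta)]


base = [[1.0, 0.0], [0.0, 1.0]]
lora_up = [[1.0], [2.0]]  # shape (out_dim=2, rank=1)
lora_down = [[0.5, 0.5]]  # shape (rank=1, in_dim=2)

merged = merge_lora_into_weight(base, lora_up, lora_down, scale=1.0)
print(merged)  # [[1.5, 0.5], [1.0, 2.0]]
```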

## `merge_models.py`

Merge 2 or more base models to produce a single base model (using either LERP or SLERP). This is a simple merge strategy that merges all model weights in the same way.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_models.py -h
```
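
The difference between the two methods is easiest to see on a pair of 2-D vectors. This is a simplified sketch in pure Python that follows the same formulas as `merge_models.py` but omits its fallbacks for near-zero and near-colinear inputs; the example vectors are arbitrary.

```python
import math


def lerp(a, b, weight_a):
    """Linear interpolation: weight_a * a + (1 - weight_a) * b."""
    t = 1.0 - weight_a
    return [x + t * (y - x) for x, y in zip(a, b)]


def slerp(a, b, weight_a):
    """Spherical linear interpolation. Assumes a and b are non-zero and not colinear."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    dot = sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)
    theta_0 = math.acos(dot)              # Angle between the two vectors.
    theta_t = theta_0 * (1.0 - weight_a)  # Angle travelled from a towards b.
    s0 = math.sin(theta_0 - theta_t) / math.sin(theta_0)
    s1 = math.sin(theta_t) / math.sin(theta_0)
    return [s0 * x + s1 * y for x, y in zip(a, b)]


a, b = [1.0, 0.0], [0.0, 1.0]
lerp_norm = math.hypot(*lerp(a, b, 0.5))
slerp_norm = math.hypot(*slerp(a, b, 0.5))
print(lerp_norm, slerp_norm)  # roughly 0.707 and 1.0
```

LERP cuts straight across between the two vectors, so the midpoint is shorter than either input. SLERP follows the arc between them and preserves the norm when both inputs have the same length.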

## `merge_task_models_to_base_model.py`

Merge 1 or more task-specific base models into a single starting base model (using either [TIES](https://arxiv.org/abs/2306.01708) or [DARE](https://arxiv.org/abs/2311.03099)). This merge strategy aims to preserve the task-specific behaviors of the task models while making only small changes to the original base model. This approach enables multiple task models to be merged without excessive interference between them.

If you want to merge a task-specific LoRA into a base model using this strategy, first use `merge_lora_into_model.py` to produce a task-specific base model, then merge that new base model using this strategy.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py -h
```
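
As a rough illustration of the DARE idea (drop and rescale), the sketch below applies it to a flat list of weights for a single task model. It is an assumption-laden toy, not the script's implementation: `dare_merge`, `drop_prob`, and the example values are hypothetical, and how the script combines several task models is not shown here.

```python
import random


def dare_merge(base, task, drop_prob, rng):
    """Randomly drop elements of the task vector (task - base) and rescale the survivors."""
    keep_scale = 1.0 / (1.0 - drop_prob)  # Requires drop_prob < 1.
    merged = []
    for b, t in zip(base, task):
        delta = t - b  # Task vector element.
        if rng.random() < drop_prob:
            delta = 0.0  # Dropped: this weight keeps its base value.
        else:
            delta *= keep_scale  # Kept: rescaled to preserve the expected update.
        merged.append(b + delta)
    return merged


base = [1.0, 2.0, 3.0, 4.0]
task = [1.5, 2.5, 2.5, 4.25]
merged = dare_merge(base, task, drop_prob=0.5, rng=random.Random(0))
print(merged)
```

Each merged element is either the unchanged base value or the base value plus a rescaled delta, which is what keeps most of the base model untouched.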
File renamed without changes.
@@ -116,10 +116,10 @@ An alternative to using the finetuned UNet model directly is to compare it again

 To extract a LoRA model, run the following command:
 ```bash
-python src/invoke_training/scripts/_experimental/lora_extraction/extract_lora_from_checkpoint.py \
-  --model-type sdxl \
-  --model-orig path/to/stable-diffusion-xl-base-1.0/unet \
-  --model-tuned output/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000/unet \
+python src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py \
+  --model-type SDXL \
+  --model-orig path/to/stable-diffusion-xl-base-1.0 \
+  --model-tuned output/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000 \
   --save-to robocats_lora_step_2000.safetensors \
   --lora-rank 32
 ```
3 changes: 1 addition & 2 deletions docs/index.md
@@ -7,7 +7,6 @@ A library for training custom Stable Diffusion models (fine-tuning, LoRA trainin
 The documentation is organized as follows:

 - [Get Started](get-started/installation.md): Install `invoke-training` and run your first training pipeline.
-- [Tutorials](tutorials/index.md): Full tutorials for running popular training pipelines.
-- [Concepts](concepts/index.md): General concepts for `invoke-training` users.
+- [Guides](guides/dataset_formats.md): Full tutorials for running popular training pipelines.
 - [Config Reference](reference/config/index.md): Reference documentation for all supported training configuration options.
 - [Contributing](contributing/development_environment.md): Information for `invoke-training` developers.
1 change: 0 additions & 1 deletion docs/tutorials/index.md

This file was deleted.

18 changes: 8 additions & 10 deletions mkdocs.yml
@@ -28,16 +28,14 @@ nav:
   - Get Started:
     - get-started/installation.md
     - get-started/quick-start.md
-  - Tutorials:
-    - tutorials/index.md
-    - Stable Diffusion:
-      - tutorials/stable_diffusion/robocats_finetune_sdxl.md
-      - tutorials/stable_diffusion/gnome_lora_masks_sdxl.md
-      - tutorials/stable_diffusion/textual_inversion_sdxl.md
-      - tutorials/stable_diffusion/dpo_lora_sd.md
-  - Concepts:
-    - concepts/index.md
-    - concepts/dataset_formats.md
+  - Guides:
+    - Dataset Formats: guides/dataset_formats.md
+    - Model Merging: guides/model_merge.md
+    - Stable Diffusion Training:
+      - guides/stable_diffusion/robocats_finetune_sdxl.md
+      - guides/stable_diffusion/gnome_lora_masks_sdxl.md
+      - guides/stable_diffusion/textual_inversion_sdxl.md
+      - guides/stable_diffusion/dpo_lora_sd.md
   - YAML Config Reference:
     - reference/config/index.md
     - pipelines:
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
     "invokeai==4.2.3",
     "numpy",
     "omegaconf",
-    "peft~=0.7.0",
+    "peft~=0.11.1",
     "Pillow",
     "prodigyopt",
     "pydantic",
52 changes: 39 additions & 13 deletions src/invoke_training/_shared/stable_diffusion/model_loading_utils.py
@@ -24,7 +24,11 @@ class PipelineVersionEnum(Enum):


 def load_pipeline(
-    logger: logging.Logger, model_name_or_path: str, pipeline_version: PipelineVersionEnum, variant: str | None = None
+    logger: logging.Logger,
+    model_name_or_path: str,
+    pipeline_version: PipelineVersionEnum,
+    torch_dtype: torch.dtype = None,
+    variant: str | None = None,
 ) -> typing.Union[StableDiffusionPipeline, StableDiffusionXLPipeline]:
     """Load a Stable Diffusion pipeline from disk.
@@ -46,36 +50,58 @@ def load_pipeline(
         raise ValueError(f"Unsupported pipeline_version: '{pipeline_version}'.")

     if os.path.isfile(model_name_or_path):
-        return pipeline_class.from_single_file(model_name_or_path, load_safety_checker=False)
+        return pipeline_class.from_single_file(model_name_or_path, torch_dtype=torch_dtype, load_safety_checker=False)

+    return from_pretrained_with_variant_fallback(
+        logger=logger,
+        model_class=pipeline_class,
+        model_name_or_path=model_name_or_path,
+        torch_dtype=torch_dtype,
+        variant=variant,
+        # kwargs
+        safety_checker=None,
+        requires_safety_checker=False,
+    )
+
+
+ModelT = typing.TypeVar("ModelT")
+
+
+def from_pretrained_with_variant_fallback(
+    logger: logging.Logger,
+    model_class: typing.Type[ModelT],
+    model_name_or_path: str,
+    torch_dtype: torch.dtype | None = None,
+    variant: str | None = None,
+    **kwargs,
+) -> ModelT:
+    """A wrapper for .from_pretrained() that tries multiple variants if the initial one fails."""
     variants_to_try = [variant] + [v for v in HF_VARIANT_FALLBACKS if v != variant]

-    pipeline = None
+    model: ModelT | None = None
     for variant_to_try in variants_to_try:
         if variant_to_try != variant:
             logger.warning(f"Trying fallback variant '{variant_to_try}'.")
         try:
-            pipeline = pipeline_class.from_pretrained(
+            model = model_class.from_pretrained(
                 model_name_or_path,
-                safety_checker=None,
+                torch_dtype=torch_dtype,
                 variant=variant_to_try,
-                requires_safety_checker=False,
+                **kwargs,
             )
         except OSError as e:
             if "no file named" in str(e):
                 # Ok; we'll try the variant fallbacks.
-                logger.warning(
-                    f"Failed to load pipeline '{model_name_or_path}' with variant '{variant_to_try}'. Error: {e}."
-                )
+                logger.warning(f"Failed to load '{model_name_or_path}' with variant '{variant_to_try}'. Error: {e}.")
             else:
                 raise

-        if pipeline is not None:
+        if model is not None:
             break

-    if pipeline is None:
-        raise RuntimeError(f"Failed to load pipeline '{model_name_or_path}'.")
-    return pipeline
+    if model is None:
+        raise RuntimeError(f"Failed to load model '{model_name_or_path}'.")
+    return model


 def load_models_sd(
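
The variant-fallback loop in `from_pretrained_with_variant_fallback` can be exercised without downloading any weights. The sketch below is a simplified stand-in, not the committed code: `FakeModel`, `load_with_fallback`, and the contents of `HF_VARIANT_FALLBACKS` are assumptions (the diff does not show the real list), and the error message only imitates the "no file named" text that the real check looks for.

```python
import logging

# Hypothetical fallback list; the real values are not shown in this diff.
HF_VARIANT_FALLBACKS = [None, "fp16"]


class FakeModel:
    """Stand-in for a model class; pretends only the 'fp16' variant exists on disk."""

    def __init__(self, variant):
        self.variant = variant

    @classmethod
    def from_pretrained(cls, model_name_or_path, variant=None, **kwargs):
        if variant != "fp16":
            raise OSError(f"Error no file named model.safetensors found in {model_name_or_path}.")
        return cls(variant)


def load_with_fallback(logger, model_class, model_name_or_path, variant=None, **kwargs):
    """Try the requested variant first, then each fallback, re-raising unrelated errors."""
    variants_to_try = [variant] + [v for v in HF_VARIANT_FALLBACKS if v != variant]
    for variant_to_try in variants_to_try:
        try:
            return model_class.from_pretrained(model_name_or_path, variant=variant_to_try, **kwargs)
        except OSError as e:
            if "no file named" not in str(e):
                raise
            logger.warning(f"Failed to load '{model_name_or_path}' with variant '{variant_to_try}'.")
    raise RuntimeError(f"Failed to load model '{model_name_or_path}'.")


model = load_with_fallback(logging.getLogger(__name__), FakeModel, "path/to/model", variant=None)
print(model.variant)  # fp16
```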
Empty file.
93 changes: 93 additions & 0 deletions src/invoke_training/model_merge/extract_lora.py
@@ -0,0 +1,93 @@
import torch
import tqdm
from peft.peft_model import PeftModel

# All original base model weights in a PeftModel have this prefix and suffix.
PEFT_BASE_LAYER_PREFIX = "base_model.model."
PEFT_BASE_LAYER_SUFFIX = ".base_layer.weight"


def get_patched_base_weights_from_peft_model(peft_model: PeftModel) -> dict[str, torch.Tensor]:
    """Get a state_dict containing the base model weights *that are patched* in the provided PeftModel. I.e. only
    return base model weights that have associated LoRA layers, but don't return the LoRA layers.
    """
    state_dict = peft_model.state_dict()
    out_state_dict: dict[str, torch.Tensor] = {}
    for weight_name in state_dict:
        # Weights that end with ".base_layer.weight" are the original weights for LoRA layers.
        if weight_name.endswith(PEFT_BASE_LAYER_SUFFIX):
            # Extract the base module name.
            module_name = weight_name[: -len(PEFT_BASE_LAYER_SUFFIX)]
            assert module_name.startswith(PEFT_BASE_LAYER_PREFIX)
            module_name = module_name[len(PEFT_BASE_LAYER_PREFIX) :]

            out_state_dict[module_name] = state_dict[weight_name]

    return out_state_dict


def get_state_dict_diff(
    state_dict_1: dict[str, torch.Tensor], state_dict_2: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Return the difference between two state_dicts: state_dict_1 - state_dict_2."""
    return {key: state_dict_1[key] - state_dict_2[key] for key in state_dict_1}


@torch.no_grad()
def extract_lora_from_diffs(
    diffs: dict[str, torch.Tensor], rank: int, clamp_quantile: float, out_dtype: torch.dtype
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    lora_weights = {}
    for lora_name, mat in tqdm.tqdm(list(diffs.items())):
        # Use full precision for the intermediate calculations.
        mat = mat.to(torch.float32)

        is_conv2d = False
        if len(mat.shape) == 4:  # Conv2D
            is_conv2d = True
            out_dim, in_dim, kernel_h, kernel_w = mat.shape
            # Reshape to (out_dim, in_dim * kernel_h * kernel_w).
            mat = mat.flatten(start_dim=1)
        elif len(mat.shape) == 2:  # Linear
            out_dim, in_dim = mat.shape
        else:
            raise ValueError(f"Unexpected weight shape: {mat.shape}")

        # LoRA rank cannot exceed the original dimensions.
        assert rank < in_dim
        assert rank < out_dim

        u: torch.Tensor
        s: torch.Tensor
        v_h: torch.Tensor
        u, s, v_h = torch.linalg.svd(mat)

        # Apply the Eckart-Young-Mirsky theorem.
        # https://en.wikipedia.org/wiki/Low-rank_approximation#Proof_of_Eckart%E2%80%93Young%E2%80%93Mirsky_theorem_(for_Frobenius_norm)
        u = u[:, :rank]
        s = s[:rank]
        u = u @ torch.diag(s)

        v_h = v_h[:rank, :]

        # At this point, u is the lora_up (a.k.a. lora_B) weight, and v_h is the lora_down (a.k.a. lora_A) weight.
        # The reason we don't use more appropriate variable names is to keep memory usage low - we want the old tensors
        # to get cleaned up after each operation.

        # Clamp the outliers.
        dist = torch.cat([u.flatten(), v_h.flatten()])
        hi_val = torch.quantile(dist, clamp_quantile)
        low_val = -hi_val

        u = u.clamp(low_val, hi_val)
        v_h = v_h.clamp(low_val, hi_val)

        if is_conv2d:
            u = u.reshape(out_dim, rank, 1, 1)
            v_h = v_h.reshape(rank, in_dim, kernel_h, kernel_w)

        u = u.to(dtype=out_dtype).contiguous()
        v_h = v_h.to(dtype=out_dtype).contiguous()

        lora_weights[lora_name] = (u, v_h)
    return lora_weights
96 changes: 96 additions & 0 deletions src/invoke_training/model_merge/merge_models.py
@@ -0,0 +1,96 @@
from typing import Literal

import torch
import tqdm

from invoke_training.model_merge.utils.normalize_weights import normalize_weights


@torch.no_grad()
def merge_models(
    state_dicts: list[dict[str, torch.Tensor]], weights: list[float], merge_method: Literal["LERP", "SLERP"] = "LERP"
):
    """Merge multiple models into a single model.

    Args:
        state_dicts (list[dict[str, torch.Tensor]]): The state dicts to merge.
        weights (list[float]): The weights for each state dict. The weights will be normalized to sum to 1.
        merge_method (Literal["LERP", "SLERP"]): Merge method to use. Options:
            - "LERP": Linear interpolation a.k.a. weighted sum.
            - "SLERP": Spherical linear interpolation.
    """
    if len(state_dicts) < 2:
        raise ValueError("Must provide >=2 models to merge.")

    if len(state_dicts) != len(weights):
        raise ValueError("Must provide a weight for each model.")

    if merge_method == "LERP":
        merge_fn = lerp
    elif merge_method == "SLERP":
        merge_fn = slerp
    else:
        raise ValueError(f"Unknown merge method: {merge_method}")

    normalized_weights = normalize_weights(weights)

    out_state_dict: dict[str, torch.Tensor] = state_dicts[0].copy()
    out_state_dict_weight = normalized_weights[0]
    for state_dict, normalized_weight in zip(state_dicts[1:], normalized_weights[1:], strict=True):
        if state_dict.keys() != out_state_dict.keys():
            raise ValueError("State dicts must have the same keys.")

        cur_pair_weights = normalize_weights([out_state_dict_weight, normalized_weight])
        for key in tqdm.tqdm(out_state_dict.keys()):
            out_state_dict[key] = merge_fn(out_state_dict[key], state_dict[key], cur_pair_weights[0])

        # Update the weight of out_state_dict to be the sum of all state dicts merged so far.
        out_state_dict_weight += normalized_weight

    return out_state_dict


def lerp(a: torch.Tensor, b: torch.Tensor, weight_a: float) -> torch.Tensor:
    """Linear interpolation."""
    return torch.lerp(a, b, (1.0 - weight_a))


def slerp(a: torch.Tensor, b: torch.Tensor, weight_a: float, dot_product_thres=0.9995, epsilon=1e-10):
    """Spherical linear interpolation."""
    # TODO(ryand): For multi-dimensional matrices, it might be better to apply slerp on a subset of the dimensions
    # (e.g. per-row), rather than treating the entire matrix as a single flattened vector.

    # Normalize the vectors.
    a_norm = torch.linalg.norm(a)
    b_norm = torch.linalg.norm(b)
    a_normalized = a / a_norm
    b_normalized = b / b_norm

    if a_norm < epsilon or b_norm < epsilon:
        # If either vector is very small, fallback to lerp to avoid weird effects.
        # TODO(ryand): Is fallback here necessary?
        return lerp(a, b, weight_a)

    # Dot product of the normalized vectors.
    # We are effectively treating multi-dimensional tensors as flattened vectors.
    dot_prod = torch.sum(a_normalized * b_normalized)

    # If the absolute value of the dot product is almost 1, the vectors are ~colinear, so use lerp.
    if torch.abs(dot_prod) > dot_product_thres:
        return lerp(a, b, weight_a)

    # Calculate initial angle between the vectors.
    theta_0 = torch.acos(dot_prod)

    # Angle at timestep t.
    t = 1.0 - weight_a
    theta_t = theta_0 * t

    sin_theta_0 = torch.sin(theta_0)
    sin_theta_t = torch.sin(theta_t)

    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    result = s0 * a + s1 * b

    return result
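
The pairwise loop in `merge_models` re-normalizes the running weight against each new model's weight. For LERP, this should give the same result as a single weighted sum, which the scalar sketch below checks. It is an illustration under stated assumptions: `normalize_weights` here is a hypothetical stand-in for the imported utility (assumed to divide by the sum), and each "model" is reduced to one float.

```python
def normalize_weights(weights):
    """Hypothetical stand-in: scale the weights so that they sum to 1."""
    total = sum(weights)
    return [w / total for w in weights]


def lerp(a, b, weight_a):
    """Scalar linear interpolation: weight_a * a + (1 - weight_a) * b."""
    return a + (1.0 - weight_a) * (b - a)


values = [10.0, 20.0, 40.0]  # One scalar "weight tensor" per model.
weights = normalize_weights([5.0, 3.0, 2.0])

# Merge the models one pair at a time, as the loop in merge_models does.
out, out_weight = values[0], weights[0]
for value, weight in zip(values[1:], weights[1:]):
    pair_weights = normalize_weights([out_weight, weight])
    out = lerp(out, value, pair_weights[0])
    out_weight += weight  # Running total of the weights merged so far.

expected = sum(w * v for w, v in zip(weights, values))
print(out, expected)  # Both are close to 19.0.
```

The same equivalence is not guaranteed for SLERP, since spherical interpolation is not associative in general, so the order of the input models can affect a multi-model SLERP merge.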