Abstract out some logic from extract_lora_from_checkpoint.py to make it easier to support more submodels.

RyanJDick committed May 31, 2024
1 parent 09506fa commit a77f612
Showing 2 changed files with 105 additions and 91 deletions.
93 changes: 93 additions & 0 deletions src/invoke_training/model_merge/extract_lora.py
@@ -0,0 +1,93 @@
import torch
import tqdm
from peft.peft_model import PeftModel

# All original base model weights in a PeftModel have this prefix and suffix.
PEFT_BASE_LAYER_PREFIX = "base_model.model."
PEFT_BASE_LAYER_SUFFIX = ".base_layer.weight"


def get_patched_base_weights_from_peft_model(peft_model: PeftModel) -> dict[str, torch.Tensor]:
"""Get a state_dict containing the base model weights *thath are patched* in the provided PeftModel. I.e. only
return base model weights that have associated LoRa layers, but don't return the LoRA layers.
"""
state_dict = peft_model.state_dict()
out_state_dict: dict[str, torch.Tensor] = {}
for weight_name in state_dict:
# Weights that end with ".base_layer.weight" are the original weights for LoRA layers.
if weight_name.endswith(PEFT_BASE_LAYER_SUFFIX):
# Extract the base module name.
module_name = weight_name[: -len(PEFT_BASE_LAYER_SUFFIX)]
assert module_name.startswith(PEFT_BASE_LAYER_PREFIX)
module_name = module_name[len(PEFT_BASE_LAYER_PREFIX) :]

out_state_dict[module_name] = state_dict[weight_name]

return out_state_dict


def get_state_dict_diff(
    state_dict_1: dict[str, torch.Tensor], state_dict_2: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Return the difference between two state_dicts: state_dict_1 - state_dict_2."""
    return {key: state_dict_1[key] - state_dict_2[key] for key in state_dict_1}


@torch.no_grad()
def extract_lora_from_diffs(
    diffs: dict[str, torch.Tensor], rank: int, clamp_quantile: float, out_dtype: torch.dtype
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    lora_weights = {}
    for lora_name, mat in tqdm.tqdm(list(diffs.items())):
        # Use full precision for the intermediate calculations.
        mat = mat.to(torch.float32)

        is_conv2d = False
        if len(mat.shape) == 4: # Conv2D
            is_conv2d = True
            out_dim, in_dim, kernel_h, kernel_w = mat.shape
            # Reshape to (out_dim, in_dim * kernel_h * kernel_w).
            mat = mat.flatten(start_dim=1)
        elif len(mat.shape) == 2: # Linear
            out_dim, in_dim = mat.shape
        else:
            raise ValueError(f"Unexpected weight shape: {mat.shape}")

        # The LoRA rank must be less than the original weight dimensions.
        assert rank < in_dim
        assert rank < out_dim

        u: torch.Tensor
        s: torch.Tensor
        v_h: torch.Tensor
        u, s, v_h = torch.linalg.svd(mat)

        # Apply the Eckart-Young-Mirsky theorem.
        # https://en.wikipedia.org/wiki/Low-rank_approximation#Proof_of_Eckart%E2%80%93Young%E2%80%93Mirsky_theorem_(for_Frobenius_norm)
        u = u[:, :rank]
        s = s[:rank]
        u = u @ torch.diag(s)

        v_h = v_h[:rank, :]

        # At this point, u is the lora_up (a.k.a. lora_B) weight, and v_h is the lora_down (a.k.a. lora_A) weight.
        # The reason we don't use more appropriate variable names is to keep memory usage low - we want the old tensors
        # to get cleaned up after each operation.

        # Clamp the outliers.
        dist = torch.cat([u.flatten(), v_h.flatten()])
        hi_val = torch.quantile(dist, clamp_quantile)
        low_val = -hi_val

        u = u.clamp(low_val, hi_val)
        v_h = v_h.clamp(low_val, hi_val)

        if is_conv2d:
            u = u.reshape(out_dim, rank, 1, 1)
            v_h = v_h.reshape(rank, in_dim, kernel_h, kernel_w)

        u = u.to(dtype=out_dtype).contiguous()
        v_h = v_h.to(dtype=out_dtype).contiguous()

        lora_weights[lora_name] = (u, v_h)
    return lora_weights
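
For intuition about what extract_lora_from_diffs() returns, here is a minimal sanity check (illustrative only, not part of this commit; the module name, shapes, rank, and clamp_quantile value are arbitrary examples). For a 2D (Linear) weight diff, the returned (lora_up, lora_down) pair is a truncated-SVD factorization whose product approximately reconstructs the diff:

```python
# Illustrative sanity check (not part of this commit). The module name, shapes, rank,
# and clamp_quantile below are arbitrary example values.
import torch

from invoke_training.model_merge.extract_lora import extract_lora_from_diffs

torch.manual_seed(0)
rank = 4

# Synthetic rank-4 weight diff for a hypothetical Linear layer (out_dim=64, in_dim=32).
diff = torch.randn(64, rank) @ torch.randn(rank, 32)

lora_weights = extract_lora_from_diffs(
    diffs={"example_module": diff}, rank=rank, clamp_quantile=0.99, out_dtype=torch.float32
)
lora_up, lora_down = lora_weights["example_module"]  # shapes: (64, 4) and (4, 32)

# lora_up @ lora_down is the rank-4 approximation of the diff (close to exact here, up to
# the effect of the outlier clamping and floating-point error).
rel_err = torch.linalg.norm(lora_up @ lora_down - diff) / torch.linalg.norm(diff)
print(f"relative reconstruction error: {rel_err:.3e}")
```

The lora_up and lora_down tensors correspond to the lora_B and lora_A weights that the script below writes into the LoRA state dict.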
12 changes: 12 additions & 91 deletions extract_lora_from_checkpoint.py
@@ -5,34 +5,25 @@

import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Literal

import peft
import torch
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file
from tqdm import tqdm

from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str
from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
    UNET_TARGET_MODULES,
    save_sdxl_kohya_checkpoint,
)


def save_to_file(file_name, model, state_dict, dtype):
    if dtype is not None:
        for key in list(state_dict.keys()):
            if type(state_dict[key]) == torch.Tensor:
                state_dict[key] = state_dict[key].to(dtype)

    if os.path.splitext(file_name)[1] == ".safetensors":
        save_file(model, file_name)
    else:
        torch.save(model, file_name)
from invoke_training.model_merge.extract_lora import (
    PEFT_BASE_LAYER_PREFIX,
    extract_lora_from_diffs,
    get_patched_base_weights_from_peft_model,
    get_state_dict_diff,
)


def str_to_device(device_str: Literal["cuda", "cpu"]) -> torch.device:
@@ -65,66 +56,6 @@ def state_dict_to_device(state_dict: dict[str, torch.Tensor], device: torch.devi
    return {k: v.to(device=device) for k, v in state_dict.items()}


@torch.no_grad()
def extract_lora_from_diffs(
    diffs: dict[str, torch.Tensor], rank: int, clamp_quantile: float, out_dtype: torch.dtype
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    lora_weights = {}
    for lora_name, mat in tqdm(list(diffs.items())):
        # Use full precision for the intermediate calculations.
        mat = mat.to(torch.float32)

        is_conv2d = False
        if len(mat.shape) == 4: # Conv2D
            is_conv2d = True
            out_dim, in_dim, kernel_h, kernel_w = mat.shape
            # Reshape to (out_dim, in_dim * kernel_h * kernel_w).
            mat = mat.flatten(start_dim=1)
        elif len(mat.shape) == 2: # Linear
            out_dim, in_dim = mat.shape
        else:
            raise ValueError(f"Unexpected weight shape: {mat.shape}")

        # LoRA rank cannot exceed the original dimensions.
        assert rank < in_dim
        assert rank < out_dim

        u: torch.Tensor
        s: torch.Tensor
        v_h: torch.Tensor
        u, s, v_h = torch.linalg.svd(mat)

        # Apply the Eckart-Young-Mirsky theorem.
        # https://en.wikipedia.org/wiki/Low-rank_approximation#Proof_of_Eckart%E2%80%93Young%E2%80%93Mirsky_theorem_(for_Frobenius_norm)
        u = u[:, :rank]
        s = s[:rank]
        u = u @ torch.diag(s)

        v_h = v_h[:rank, :]

        # At this point, u is the lora_up (a.k.a. lora_B) weight, and v_h is the lora_down (a.k.a. lora_A) weight.
        # The reason we don't use more appropriate variable names is to keep memory usage low - we want the old tensors
        # to get cleaned up after each operation.

        # Clamp the outliers.
        dist = torch.cat([u.flatten(), v_h.flatten()])
        hi_val = torch.quantile(dist, clamp_quantile)
        low_val = -hi_val

        u = u.clamp(low_val, hi_val)
        v_h = v_h.clamp(low_val, hi_val)

        if is_conv2d:
            u = u.reshape(out_dim, rank, 1, 1)
            v_h = v_h.reshape(rank, in_dim, kernel_h, kernel_w)

        u = u.to(dtype=out_dtype).contiguous()
        v_h = v_h.to(dtype=out_dtype).contiguous()

        lora_weights[lora_name] = (u, v_h)
    return lora_weights


@torch.no_grad()
def extract_lora(
    logger: logging.Logger,
@@ -170,20 +101,10 @@ def extract_lora(
    unet_tuned = peft.get_peft_model(unet_tuned, unet_lora_config)
    unet_orig = peft.get_peft_model(unet_orig, unet_lora_config)

    diffs: dict[str, torch.Tensor] = {}
    state_dict_tuned = unet_tuned.state_dict()
    state_dict_orig = unet_orig.state_dict()
    peft_base_layer_suffix = ".base_layer.weight"
    peft_base_layer_prefix = "base_model.model."
    for weight_name in state_dict_tuned:
        # Weights that end with ".base_layer.weight" are the original weights for LoRA layers.
        if weight_name.endswith(peft_base_layer_suffix):
            # Extract the base module name.
            module_name = weight_name[: -len(peft_base_layer_suffix)]
            assert module_name.startswith(peft_base_layer_prefix)
            module_name = module_name[len(peft_base_layer_prefix) :]

            diffs[module_name] = state_dict_tuned[weight_name] - state_dict_orig[weight_name]
    unet_tuned_base_weights = get_patched_base_weights_from_peft_model(unet_tuned)
    unet_orig_base_weights = get_patched_base_weights_from_peft_model(unet_orig)

    diffs = get_state_dict_diff(unet_tuned_base_weights, unet_orig_base_weights)

    # Clear tuned UNet to save memory.
    # TODO(ryand): We also need to clear the state_dicts. Move the diff extraction to a separate function so that memory
@@ -201,8 +122,8 @@ def extract_lora(
    # Prepare state dict for LoRA.
    lora_state_dict = {}
    for module_name, (lora_up, lora_down) in lora_weights.items():
        lora_state_dict[peft_base_layer_prefix + module_name + ".lora_A.default.weight"] = lora_down
        lora_state_dict[peft_base_layer_prefix + module_name + ".lora_B.default.weight"] = lora_up
        lora_state_dict[PEFT_BASE_LAYER_PREFIX + module_name + ".lora_A.default.weight"] = lora_down
        lora_state_dict[PEFT_BASE_LAYER_PREFIX + module_name + ".lora_B.default.weight"] = lora_up
        # TODO(ryand): Double-check that this isn't needed with peft.
        # lora_state_dict[peft_base_layer_suffix + module_name + ".alpha"] = torch.tensor(down_weight.size()[0])
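
The commit message frames this refactor as making it easier to support more submodels. As a rough sketch of how the extracted helpers could be reused for an additional submodel (for example, a text encoder), something along these lines would work; the function name, target_modules, rank, and clamp_quantile below are hypothetical placeholders, not part of this commit:

```python
# Hypothetical sketch of reusing the new helpers for another submodel (e.g. a text
# encoder). The function name, target_modules, rank, and clamp_quantile are
# illustrative placeholders, not part of this commit.
import peft
import torch

from invoke_training.model_merge.extract_lora import (
    extract_lora_from_diffs,
    get_patched_base_weights_from_peft_model,
    get_state_dict_diff,
)


def extract_submodel_lora(
    model_tuned: torch.nn.Module,
    model_orig: torch.nn.Module,
    target_modules: list[str],
    rank: int = 32,
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    # Wrap both copies of the submodel with identical LoRA configs so their patched
    # base weights can be compared key-for-key.
    lora_config = peft.LoraConfig(r=rank, target_modules=target_modules)
    model_tuned = peft.get_peft_model(model_tuned, lora_config)
    model_orig = peft.get_peft_model(model_orig, lora_config)

    # Diff only the base weights that have LoRA layers attached, then factor the diffs.
    diffs = get_state_dict_diff(
        get_patched_base_weights_from_peft_model(model_tuned),
        get_patched_base_weights_from_peft_model(model_orig),
    )
    return extract_lora_from_diffs(diffs, rank=rank, clamp_quantile=0.99, out_dtype=torch.float16)
```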
