-
-
Notifications
You must be signed in to change notification settings - Fork 220
/
lora.py
125 lines (91 loc) · 4.55 KB
/
lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from model import ExLlamaConfig, Ex4bitLinear
import torch
import json
from safetensors.torch import load_file as safe_load_file
from torch import load as load_file
class ExLlamaLora:
    """LoRA adapter loaded from disk and prepared for use with a model.

    Construction reads the adapter config (rank and alpha), loads the adapter
    weights, validates every tensor against the model layer it targets and
    stores the transposed, pre-scaled, float16 tensors in ``self.tensors``.
    Keys of ``self.tensors`` are the part of each checkpoint key that starts
    at ``model.layers.``.
    """

    lora_config_path: str
    lora_path: str
    lora_r: int
    lora_alpha: float
    lora_scaling: float
    # String annotation: only type checkers need to resolve this name.
    config: "ExLlamaConfig"
    tensors: dict[str, torch.Tensor]
    bias_ignored: bool

    # A bias tensor whose values all lie within +/- this bound counts as zero.
    _BIAS_EPSILON = 1e-6
    # Sub-modules of a decoder layer that an adapter tensor may target.
    _DECODER_PARTS = ("self_attn", "mlp")
    # Projection layers that an adapter tensor may target.
    _DECODER_LAYERS = (
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    )
    # Dtypes that are accepted and stored as float16.
    _SUPPORTED_DTYPES = (torch.bfloat16, torch.float32, torch.float16)

    def __init__(self, model, lora_config_path, lora_path):
        """Load and validate a LoRA adapter for ``model``.

        Args:
            model: Object exposing ``config`` (whose ``device_map.map(key)``
                gives the device for a tensor) and ``layers`` (indexable by
                decoder index, each entry exposing ``self_attn`` / ``mlp``).
            lora_config_path: Path to the adapter's JSON config. The keys
                ``"r"`` and ``"lora_alpha"`` are required.
            lora_path: Path to the adapter weights. Files ending in
                ``.safetensors`` are read with safetensors, anything else
                with ``torch.load``.

        Raises:
            ValueError: If the config enables ``fan_in_fan_out``, a key does
                not map to a supported layer, a bias tensor is non-zero, a
                tensor's shape or dtype is incompatible, or the targeted
                module is not an ``Ex4bitLinear``.
        """
        self.lora_config_path = lora_config_path
        self.lora_path = lora_path
        self.model = model
        self.config = model.config
        self.tensors = {}
        self.bias_ignored = False

        # Grab relevant items from LoRA config
        with open(lora_config_path, encoding = "utf-8") as config_file:
            read_config = json.load(config_file)

        self.lora_r = read_config["r"]
        self.lora_alpha = float(read_config["lora_alpha"])
        self.lora_scaling = self.lora_alpha / self.lora_r

        if read_config.get("fan_in_fan_out"):
            raise ValueError(" ## Error: fan_in_fan_out mode not supported.")

        # Load LoRA weights
        if self.lora_path.endswith(".safetensors"):
            state_dict = safe_load_file(self.lora_path, device = "cpu")
        else:
            # NOTE(security): torch.load may unpickle arbitrary objects; only
            # load adapter files from sources you trust.
            state_dict = load_file(self.lora_path, map_location = "cpu")

        for key, tensor in state_dict.items():

            # Find target
            start = key.find("model.layers.")
            if start == -1:
                raise self._unsupported_layer(key)

            target_key = key[start:]
            parts = target_key.split(".")
            # Expected layout: model.layers.<idx>.<part>.<layer>.<half>[...]
            if len(parts) < 6:
                raise self._unsupported_layer(key)
            try:
                decoder_idx = int(parts[2])
            except ValueError as err:
                raise self._unsupported_layer(key) from err
            decoder_part = parts[3]
            decoder_layer = parts[4]
            lora_half = parts[5]

            if lora_half == "bias":
                # Biases cannot be applied, so only an all-zero bias may be
                # dropped. Both extremes are checked so that large negative
                # values are rejected as well as large positive ones.
                epsilon = self._BIAS_EPSILON
                if tensor.numel() > 0 and (
                    torch.max(tensor) > epsilon or torch.min(tensor) < -epsilon
                ):
                    raise ValueError(f" ## Error: unsupported bias target {self.lora_path}: {key}")
                self.bias_ignored = True
                continue

            target_module = self.model.layers[decoder_idx]
            if decoder_part not in self._DECODER_PARTS:
                raise self._unsupported_layer(key)
            target_module = getattr(target_module, decoder_part)
            if decoder_layer not in self._DECODER_LAYERS:
                raise self._unsupported_layer(key)
            target_module = getattr(target_module, decoder_layer)

            # Explicit raise rather than assert so the check survives `python -O`.
            if not isinstance(target_module, Ex4bitLinear):
                raise ValueError(f" ## Error: unsupported target module type in {self.lora_path}: {key}")

            # Check that shape is compatible
            if lora_half == "lora_A":
                shape_ok = tensor.shape[1] == target_module.in_features
            elif lora_half == "lora_B":
                shape_ok = tensor.shape[0] == target_module.out_features
            else:
                raise self._unsupported_layer(key)
            if not shape_ok:
                raise ValueError(f" ## Error: incompatible tensor shape in {self.lora_path}: {key}")

            # Check that dtype is compatible before touching the data, so an
            # unsupported dtype is reported as such and not as a failure of
            # the in-place scaling below.
            if tensor.dtype not in self._SUPPORTED_DTYPES:
                raise ValueError(f" ## Error: unsupported tensor dtype in {self.lora_path}")

            # For efficiency, transpose adapter instead of transposing state during inference
            tensor = tensor.T.contiguous()

            # Pre-scale (in the tensor's original dtype, before any conversion)
            if lora_half == "lora_B" and self.lora_scaling != 1.0:
                tensor.mul_(self.lora_scaling)

            # Convert to float16 where needed
            if tensor.dtype != torch.float16:
                tensor = tensor.to(torch.float16)

            # Move to target device
            device = self.config.device_map.map(target_key)
            tensor = tensor.to(device, non_blocking = True)

            # Store adapter tensor
            self.tensors[target_key] = tensor

    def _unsupported_layer(self, key):
        """Build the error for a checkpoint key that maps to no supported layer."""
        return ValueError(f" ## Error: unsupported layer in {self.lora_path}: {key}")