Commit
v0.1.2 (LoRA + optimizer V2 + refactor)
mobicham committed Jan 8, 2024
1 parent be69389 commit 5d1a8e1
Showing 4 changed files with 138 additions and 50 deletions.
27 changes: 8 additions & 19 deletions examples/lora/train_hqq_lora_example.py
@@ -19,21 +19,11 @@
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False)
model.quantize_model(quant_config=quant_config)

######################################################################################
## BNB Quantize (for comparison)
# import transformers, torch
# model = transformers.AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path, load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
# tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path)

# from hqq.models.hf.llama import LlamaHQQ
# model.base_class = LlamaHQQ

#Add Peft
######################################################################################

from hqq.core.peft import PeftUtils

train_dtype = torch.bfloat16 #torch.float32
train_dtype = torch.bfloat16 #torch.float32 / torch.bfloat16
base_lora_params = {'lora_type':'default', 'r':32, 'lora_alpha':64, 'dropout':0.05, 'train_dtype':train_dtype}
lora_params = {'self_attn.q_proj': base_lora_params,
'self_attn.k_proj': base_lora_params,
@@ -43,7 +33,6 @@
'mlp.up_proj' : None,
'mlp.down_proj' : None}


#Apply LoRA
PeftUtils.add_lora(model, lora_params)

@@ -67,7 +56,7 @@
batch_size = 1
num_epochs = 1
grad_acc = 1
max_tokens = 256 #1024
max_tokens = 256
max_samples = 5000

#Warmup for torch compile
@@ -76,7 +65,6 @@
del out
cleanup()


#OpenAssistant
##########################################################################
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
@@ -89,8 +77,12 @@ def pre_process_chat(chat):
def assitant_prompt(prompt):
return '### Human:' + prompt + '\n### Assistant:'

# #Filter short samples
# dataset = Dataset.from_dict({'text':[dataset[i]['text'] for i in tqdm(range(len(dataset))) if len(dataset[i]['text'])>500]})
# dataset_val = Dataset.from_dict({'text':[dataset_val[i]['text'] for i in tqdm(range(len(dataset_val))) if len(dataset_val[i]['text'])>500]})

random.seed(100)
idx = random.sample(range(len(dataset)), max_samples)
idx = random.sample(range(len(dataset)), min(max_samples, len(dataset)))

dataset = Dataset.from_dict({'text':[pre_process_chat(dataset[i]['text']) for i in tqdm(idx)]})
dataset_val = Dataset.from_dict({'text':[pre_process_chat(dataset_val[i]['text']) for i in range(len(dataset_val))]})
@@ -121,7 +113,7 @@ def assitant_prompt(prompt):
fp16=train_dtype==torch.float32,
max_grad_norm=1.0,
save_steps=10000000,
lr_scheduler_type= "linear", #constant | linear
lr_scheduler_type= "linear",
)

#Wrap model to avoid accelerate issues
@@ -162,7 +154,6 @@ def parameters(self):

#Prediction/Eval
######################################################################################

#from https://huggingface.co/spaces/evaluate-metric/perplexity/blob/main/perplexity.py
def compute_perplexity_batched(model, tokenizer, predictions, encodings=None, batch_size=1, add_start_token=True, device='cuda', max_length=None):
if tokenizer.pad_token is None and batch_size > 1:
@@ -232,7 +223,6 @@ def compute_perplexity_batched(model, tokenizer, predictions, encodings=None, ba
return np.mean(ppls)



tokenizer.add_bos_token = True
tokenizer.add_eos_token = False
model.eval()
@@ -241,4 +231,3 @@ def compute_perplexity_batched(model, tokenizer, predictions, encodings=None, ba
PeftUtils.cast_lora_weights(model, dtype=torch.half)

print('perplexity', compute_perplexity_batched(model=model, tokenizer=tokenizer, predictions=[s['text'] for s in dataset_val], batch_size=1, max_length=max_tokens))
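
A minimal sketch of the flow this example script encodes, for orientation: it assumes model is an already-loaded hqq HF model wrapper (e.g. the LlamaHQQ class referenced in the commented-out comparison block) with a matching tokenizer, and that BaseQuantizeConfig is importable from hqq.core.quantize; only calls that appear in this file are used.

import torch
from hqq.core.quantize import BaseQuantizeConfig #assumed import path
from hqq.core.peft import PeftUtils

#4-bit HQQ quantization of the base model, as in the script
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False)
model.quantize_model(quant_config=quant_config) #model: hqq HF wrapper, assumed loaded beforehand

#Attach LoRA adapters to the attention projections only
train_dtype      = torch.bfloat16
base_lora_params = {'lora_type':'default', 'r':32, 'lora_alpha':64, 'dropout':0.05, 'train_dtype':train_dtype}
lora_params      = {'self_attn.q_proj': base_lora_params,
                    'self_attn.k_proj': base_lora_params,
                    'self_attn.v_proj': base_lora_params,
                    'self_attn.o_proj': base_lora_params,
                    'mlp.gate_proj'   : None,
                    'mlp.up_proj'     : None,
                    'mlp.down_proj'   : None}
PeftUtils.add_lora(model, lora_params)

#... train with the HF Trainer exactly as in the script ...

#Cast LoRA weights to fp16 before evaluation, as done at the end of the example
PeftUtils.cast_lora_weights(model, dtype=torch.half)
model.eval()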

152 changes: 125 additions & 27 deletions hqq/core/peft.py
@@ -11,17 +11,22 @@ def __init__(self, linear_layer, peft_config):
super().__init__()

#Device
self.device = next(linear_layer.parameters()).device
self.device = linear_layer.device if hasattr(linear_layer, 'device') else next(linear_layer.parameters()).device
self.train_dtype = peft_config['train_dtype'] if ('train_dtype' in peft_config) else torch.float

#Linear layer
self.linear_layer = linear_layer
self.in_features = linear_layer.in_features
self.out_features = linear_layer.out_features
self.bias = None if (linear_layer.bias is None) else linear_layer.bias.clone()

#Turn-off bias in the linear layer

#Bias
self.bias = None if (linear_layer.bias is None) else linear_layer.bias.clone()
self.linear_layer.bias = None
peft_config['train_bias'] = peft_config['train_bias'] if ('train_bias' in peft_config) else False
if(self.bias is not None):
self.bias = torch.nn.Parameter(self.bias, requires_grad=peft_config['train_bias'])
if((self.bias is None) and peft_config['train_bias']):
self.bias = torch.nn.Parameter(torch.zeros((self.out_features,), device=self.device, dtype=self.train_dtype), requires_grad=True)

#Dropout
if('dropout' in peft_config):
@@ -37,33 +42,44 @@ def __init__(self, linear_layer, peft_config):
self.lora_A = _get_dense_param(self.in_features, self.r, device=self.device, trainable=True, dtype=self.train_dtype)
self.lora_B = _get_dense_param(self.r, self.out_features, device=self.device, trainable=True, dtype=self.train_dtype)

#Init weights, as in the original LoRA implementation
torch.nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))
torch.nn.init.zeros_(self.lora_B)
#LoRA weights init
if('lora_init' in peft_config):
#Set lora init
assert (peft_config['lora_init']['lora_A'].shape[0], peft_config['lora_init']['lora_B'].shape[1])==(self.in_features, self.out_features), \
"Invalid init LoRA weight shapes. Expected: lora_A: " + str(self.in_features) + " x r , lora_B: r x " + str(self.out_features) + ")"
self.lora_A.data = peft_config['lora_init']['lora_A'].to(self.train_dtype).to(self.device)
self.lora_B.data = peft_config['lora_init']['lora_B'].to(self.train_dtype).to(self.device)
else:
#Init weights, as in the original LoRA implementation
torch.nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))
torch.nn.init.zeros_(self.lora_B)

def forward(self, x):
x_type = x.dtype
x_dtype = x.dtype

#Forward with base linear
out = self.linear_layer(x)

#LoRA
out += (torch.matmul(torch.matmul(self.peft_drop(x.to(self.lora_A.dtype)), self.lora_A), self.lora_B)*self.scaling).to(x_type)
out += (torch.matmul(torch.matmul(self.peft_drop(x.to(self.lora_A.dtype)), self.lora_A), self.lora_B)*self.scaling).to(x_dtype)

#Bias
if(self.bias is not None):
out += self.bias

out = out.to(x_dtype)

return out

def merge_and_quantize(self, quant_config):

#Get initial weights
W = self.linear_layer(torch.eye(self.in_features, device=self.device, dtype=torch.float16)).t()
W = self.linear_layer(torch.eye(self.in_features, device=self.device, dtype=torch.float16)).t() #== self.linear_layer.dequantize()

#Merge weights
W += (torch.matmul(self.lora_A.data, self.lora_B.data).t()*self.scaling).to(W.dtype)
W += (torch.matmul(self.lora_A.data, self.lora_B.data)*self.scaling).t().to(W.dtype)

#New HQQ layer
new_hqq_layer = HQQLinear(None, quant_config)
new_hqq_layer.bias = None if (self.bias is None) else self.bias.clone()
new_hqq_layer.quantize(W, **quant_config)
Expand All @@ -74,54 +90,136 @@ def cast(self, dtype=torch.float16):
self.lora_A.data = self.lora_A.data.to(dtype)
self.lora_B.data = self.lora_B.data.to(dtype)
if(self.bias is not None):
if(self.bias.requires_grad):
self.bias.data = self.bias.data.to(dtype)
else:
self.bias = self.bias.to(dtype)
self.bias.data = self.bias.data.to(dtype)
return self

def state_dict(self):
return {'lora_A':self.lora_A.data, 'lora_B':self.lora_B.data, 'scaling':self.scaling, 'bias':self.bias, 'peft_config':self.peft_config}
return {'lora_A':self.lora_A.data, 'lora_B':self.lora_B.data, 'scaling':self.scaling, 'bias':self.bias}

def load_state_dict(self, state_dict):
self.lora_A.data = state_dict['lora_A'].data.to(self.device)
self.lora_B.data = state_dict['lora_B'].data.to(self.device)
self.scaling = state_dict['scaling']
self.bias = state_dict['bias'] if ('bias' in state_dict) else None
self.bias = self.bias.to(self.device) if (self.bias is not None) else None
self.peft_config = state_dict['peft_config']
if(state_dict['bias'] is not None):
self.bias.data = state_dict['bias'].data.to(self.device)


#LoRA with fake quantization
class HQQLinearLoRAWithFakeQuant(HQQLinearLoRA):
def __init__(self, linear_layer, peft_config, quant_param):
def __init__(self, linear_layer, peft_config):
super(HQQLinearLoRAWithFakeQuant, self).__init__(linear_layer, peft_config)
self.quant_param = quant_param
self.quant_param = peft_config['quant_param']

#@torch.no_grad()
#@torch.compile()
def fake_quant(self, weight):
if(self.quant_param):
W_q, meta = Quantizer.quantize(weight, **self.quant_param, bitpack=False)
W_q, meta = Quantizer.quantize(weight, **self.quant_param, bitpack=False) #todo: clone() tensor
weight_est = Quantizer.dequantize(W_q, meta)
else:
weight_est = weight
return weight_est

def forward(self, x):
weight = self.linear_layer.dequantize() + (torch.matmul(self.lora_A, self.lora_B)*self.scaling).t()
weight = self.fake_quant(weight)
out = torch.matmul(x, weight.t())
x_dtype = x.dtype

#Get initial weights
W = self.linear_layer(torch.eye(self.in_features, device=self.device, dtype=x_dtype)).t() #== self.linear_layer.dequantize()

#Merge weights
W += (torch.matmul(self.lora_A, self.lora_B)*self.scaling).t().to(W.dtype)

#Fake quant
W = self.fake_quant(W).to(x_dtype)

#Matmul
out = torch.matmul(x, W.t())

#Bias
if(self.bias is not None):
out += self.bias

out = out.to(x_dtype)

return out

#Experimental
class HQQLinearGroupedProj(torch.nn.Module):
def __init__(self, linear_layer, peft_config):
super().__init__()

#Device
self.device = linear_layer.device if hasattr(linear_layer, 'device') else next(linear_layer.parameters()).device
self.train_dtype = peft_config['train_dtype'] if ('train_dtype' in peft_config) else torch.float

#Linear layer
self.linear_layer = linear_layer
self.in_features = linear_layer.in_features
self.out_features = linear_layer.out_features
self.bias = None if (linear_layer.bias is None) else linear_layer.bias.clone()

#Turn-off bias in the linear layer
self.linear_layer.bias = None

#Group proj
self.peft_config = peft_config
self.proj_size = peft_config['proj_size']
self.proj_num = peft_config['proj_num']
self.proj = torch.nn.Parameter(torch.stack([torch.eye(self.proj_size, dtype=self.train_dtype, device=self.device)]*self.proj_num))
if(peft_config['zero_trainable']):
self.linear_layer.meta['zero'] = torch.nn.Parameter(self.linear_layer.meta['zero'].to(self.train_dtype), requires_grad=True)

@torch.compile()
def forward(self, x):
x_dtype = x.dtype

#Forward with base linear
with torch.no_grad():
W = self.linear_layer.dequantize().clone()
#W = self.linear_layer(torch.eye(self.in_features, device=self.device, dtype=x_dtype)).t()
shape = W.shape

#Grouped proj
proj_b, gs = self.proj.shape[0], self.proj.shape[1]
W = torch.matmul(self.proj, W.reshape((proj_b, gs, -1)).to(self.proj.dtype)).to(x_dtype).reshape(shape)

_HQQ_LORA_CLASSES = [HQQLinearLoRA, HQQLinearLoRAWithFakeQuant]
_HQQ_LORA_MAPPING = {'default':HQQLinearLoRA, 'lora_with_fakequant':HQQLinearLoRAWithFakeQuant}
#Matmul
out = torch.matmul(x, W.t())

#Bias
if(self.bias is not None):
out += self.bias

out = out.to(x_dtype)

return out

def cast(self, dtype=torch.float16):
self.proj.data = self.proj.data.to(dtype)
self.linear_layer.meta['zero'] = self.linear_layer.meta['zero'].to(dtype)
if(self.bias is not None):
if(self.bias.requires_grad):
self.bias.data = self.bias.data.to(dtype)
else:
self.bias = self.bias.to(dtype)
return self

def state_dict(self):
return {'proj':self.proj.data, 'bias':self.bias, 'peft_config':self.peft_config}

def load_state_dict(self, state_dict):
self.proj.data = state_dict['proj'].data.to(self.device)
self.bias = state_dict['bias'] if ('bias' in state_dict) else None
self.bias = self.bias.to(self.device) if (self.bias is not None) else None
self.peft_config = state_dict['peft_config']


_HQQ_LORA_CLASSES = [HQQLinearLoRA, HQQLinearLoRAWithFakeQuant, HQQLinearGroupedProj]
_HQQ_LORA_MAPPING = {'default':HQQLinearLoRA, 'lora_with_fakequant':HQQLinearLoRAWithFakeQuant, 'grouped_proj':HQQLinearGroupedProj}

def is_hqq_lora_layer(layer):
return type(layer) in _HQQ_LORA_CLASSES

##################################################################################################################
def autoname_modules(model):
for name, module in model.named_modules():
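The reworked HQQLinearLoRA above adds two opt-in knobs: a trainable bias (train_bias, created even when the base layer has none) and externally supplied adapter weights (lora_init, shape-checked against in_features x r and r x out_features). Below is a minimal sketch of a per-module config using both, assuming it is passed through the same lora_params mapping that PeftUtils.add_lora consumes; the 4096/32 shapes and the init tensors are placeholders.

import torch

in_features, out_features, r = 4096, 4096, 32 #placeholder shapes for illustration

#Pre-computed adapter weights (e.g. exported from a previous run)
#lora_A: in_features x r, lora_B: r x out_features (enforced by the assert in __init__)
lora_A_init = torch.randn(in_features, r) * 1e-3
lora_B_init = torch.zeros(r, out_features)

base_lora_params = {'lora_type'  :'default',      #maps to HQQLinearLoRA in _HQQ_LORA_MAPPING
                    'r'          :r,
                    'lora_alpha' :64,
                    'dropout'    :0.05,
                    'train_dtype':torch.bfloat16,
                    'train_bias' :True,            #new: bias becomes a trainable Parameter
                    'lora_init'  :{'lora_A':lora_A_init, 'lora_B':lora_B_init}} #new: skips the kaiming/zeros init

The new 'grouped_proj' entry in _HQQ_LORA_MAPPING (HQQLinearGroupedProj, configured via proj_size, proj_num and zero_trainable) is marked experimental in the hunk above and is not covered by this sketch.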
7 changes: 4 additions & 3 deletions hqq/core/quantize.py
@@ -63,15 +63,17 @@ def quantize(cls, tensor, nbits=4, channel_wise=True, group_size=64, optimize=Fa
if(optimize): scale, zero = Quantizer.optimize_weights(tensor=W, scale=scale, zero=zero, min_max=min_max, axis=axis)

#Quantize
scale, zero = scale.clone(), zero.clone() #Necessary for fake quantization backprop
W_q = torch.round(W*scale + zero).clamp(min_max[0], min_max[1])

#Store meta-data (we invert the scale for dequantization)
meta = {'nbits':nbits, 'group_size':group_size, 'shape':shape, 'scale':1./scale, 'zero':zero, 'axis':axis, 'packing':Quantizer.bit_to_packing[nbits]}

#Pack bits
if(bitpack):
W_q = Quantizer.pack[meta['packing']](W_q)
W_q = Quantizer.pack[meta['packing']](W_q)
else:
W_q = W_q.to(tensor.dtype)
meta['packing'] = None

#cleanup
@@ -88,7 +90,7 @@ def dequantize(cls, W_q, meta):
if((meta['group_size'] is not None) and (meta['nbits']==3)):
W_r = W_r[:meta['group_size']] if (meta['axis']==0) else W_r[:,:meta['group_size']]
else:
W_r = W_q
W_r = W_q.half()
W_r = ((W_r - meta['zero'])*meta['scale']).reshape(meta['shape'])
return W_r

@@ -290,7 +292,6 @@ def load_state_dict(self, state_dict):
if(self.in_gpu==False): self.cuda()
self.ready = True

@torch.inference_mode()
def quantize(self, W, weight_quant_params, scale_quant_params, zero_quant_params):
quant_scale = scale_quant_params is not None
quant_zero = zero_quant_params is not None
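The Quantizer changes above clone scale/zero before rounding (so fake quantization can backpropagate through them) and, when bitpack=False, keep W_q unpacked in the input tensor's dtype with meta['packing'] set to None. A small round-trip sketch under those assumptions; the tensor size is arbitrary and device placement is left to the caller.

import torch
from hqq.core.quantize import Quantizer

W = torch.randn(128, 128, dtype=torch.float16)

#Unpacked (fake-quant friendly) path: W_q stays in W's dtype, meta['packing'] is None
W_q, meta = Quantizer.quantize(W, nbits=4, group_size=64, bitpack=False)
W_approx  = Quantizer.dequantize(W_q, meta)

print(W_q.dtype, meta['packing'])                       #expected: torch.float16, None
print((W.float() - W_approx.float()).abs().mean())      #mean absolute quantization error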
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name='hqq',
version='0.1.2.alpha', #0.1.1.post1
version='0.1.2',
description='Half-Quadratic Quantization (HQQ)',
url='https://github.com/mobiusml/hqq/',
author='Dr. Hicham Badri',