diff --git a/Readme.md b/Readme.md
index 48d505b..2932787 100755
--- a/Readme.md
+++ b/Readme.md
@@ -191,6 +191,36 @@ quant_config['self_attn.v_proj'] = q4_config
 model.quantize_model(quant_config=quant_config)
 ```
 
+### LoRA Training
+You can use HQQ for LoRA training as follows:
+```Python
+#First, quantize/load a quantized HQQ model as described above
+from hqq.core.peft import PeftUtils
+
+base_lora_params = {'lora_type':'default', 'r':32, 'lora_alpha':64, 'dropout':0.05, 'train_dtype':torch.bfloat16}
+lora_params      = {'self_attn.q_proj': base_lora_params,
+                    'self_attn.k_proj': base_lora_params,
+                    'self_attn.v_proj': base_lora_params,
+                    'self_attn.o_proj': base_lora_params,
+                    'mlp.gate_proj'   : None,
+                    'mlp.up_proj'     : None,
+                    'mlp.down_proj'   : None}
+
+
+PeftUtils.add_lora(model, lora_params)
+
+#Optional: faster, but might not work on older GPUs
+HQQLinear.set_backend(HQQBackend.PYTORCH_BACKPROP_COMPILE)
+
+#Train ....
+
+#Convert the LoRA weights to the same dtype as the model for faster inference
+model.eval()
+PeftUtils.cast_lora_weights(model, dtype=torch.half)
+```
+
+We provide a complete example of training a model with HQQ/LoRA in ```examples/lora/train_hqq_lora_example.py```.
+
 ### Examples
 
 We provide a variety of examples demonstrating model quantization across different backends within the ```examples``` directory.
diff --git a/examples/lora/train_hqq_lora_example.py b/examples/lora/train_hqq_lora_example.py
new file mode 100755
index 0000000..fd25d83
--- /dev/null
+++ b/examples/lora/train_hqq_lora_example.py
@@ -0,0 +1,244 @@
+#Settings
+######################################################################################
+hf_auth    = None  #HuggingFace token
+cache_path = ''    #cache directory to store data
+
+#Choose a model
+model_id = "meta-llama/Llama-2-7b-hf"
+#model_id = "meta-llama/Llama-2-13b-hf"
+#model_id = "meta-llama/Llama-2-70b-hf"
+
+#HQQ Quantize
+######################################################################################
+import torch
+from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
+model     = HQQModelForCausalLM.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path)
+tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path)
+
+#Quantize the model
+from hqq.core.quantize import *
+quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False)
+model.quantize_model(quant_config=quant_config)
+
+######################################################################################
+## BNB Quantize (for comparison)
+# import transformers, torch
+# model     = transformers.AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path, load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
+# tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path)
+
+# from hqq.models.hf.llama import LlamaHQQ
+# model.base_class = LlamaHQQ
+
+#Add Peft
+######################################################################################
+
+from hqq.core.peft import PeftUtils
+
+train_dtype      = torch.bfloat16 #torch.float32
+base_lora_params = {'lora_type':'default', 'r':32, 'lora_alpha':64, 'dropout':0.05, 'train_dtype':train_dtype}
+lora_params      = {'self_attn.q_proj': base_lora_params,
+                    'self_attn.k_proj': base_lora_params,
+                    'self_attn.v_proj': base_lora_params,
+                    'self_attn.o_proj': base_lora_params,
+                    'mlp.gate_proj'   : None,
+                    'mlp.up_proj'     : None,
+                    'mlp.down_proj'   : None}
+
+
+#Apply LoRA
+PeftUtils.add_lora(model, lora_params)
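+
+#Optional sanity check (illustrative sketch, not part of the original script): after add_lora, only the
+#LoRA A/B matrices (and any trainable biases) should require gradients; the quantized base weights stay frozen.
+trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+total_params     = sum(p.numel() for p in model.parameters())
+print('trainable params:', trainable_params, '/', total_params)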
+
+#Optional: faster, but might not work properly on older GPUs
+from hqq.core.quantize import *
+HQQLinear.set_backend(HQQBackend.PYTORCH_BACKPROP_COMPILE)
+
+#Dataset
+######################################################################################
+from datasets import load_dataset, Dataset
+from tqdm import tqdm
+import transformers
+import numpy as np
+import random
+
+tokenizer.pad_token     = tokenizer.eos_token
+tokenizer.padding_side  = "right"
+tokenizer.add_bos_token = False
+tokenizer.add_eos_token = False
+
+batch_size  = 1
+num_epochs  = 1
+grad_acc    = 1
+max_tokens  = 256 #1024
+max_samples = 5000
+
+#Warmup for torch compile
+with torch.no_grad():
+    out = model(torch.ones((batch_size, max_tokens), dtype=torch.int32, device='cuda'))
+del out
+cleanup()
+
+
+#OpenAssistant
+##########################################################################
+dataset     = load_dataset("timdettmers/openassistant-guanaco", split="train")
+dataset_val = load_dataset("timdettmers/openassistant-guanaco", split="test")
+
+def pre_process_chat(chat):
+    #add proper chat preprocessing (bos/eos tokens, etc.)
+    return chat
+
+def assistant_prompt(prompt):
+    return '### Human:' + prompt + '\n### Assistant:'
+
+random.seed(100)
+idx = random.sample(range(len(dataset)), max_samples)
+
+dataset     = Dataset.from_dict({'text':[pre_process_chat(dataset[i]['text']) for i in tqdm(idx)]})
+dataset_val = Dataset.from_dict({'text':[pre_process_chat(dataset_val[i]['text']) for i in range(len(dataset_val))]})
+
+#####################################################################################
+#Train
+from trl import SFTTrainer
+
+grad_acc   = 2
+logging_st = 1
+max_steps  = -1
+lr         = 1e-4
+batch_size = 1
+n_epochs   = 1
+
+training_args = transformers.TrainingArguments(
+    output_dir='.',
+    per_device_train_batch_size=batch_size,
+    #per_device_eval_batch_size=batch_size,
+    gradient_accumulation_steps=grad_acc,
+    learning_rate=lr,
+    logging_steps=logging_st,
+    num_train_epochs=n_epochs,
+    max_steps=max_steps,
+    #evaluation_strategy="epoch",
+    remove_unused_columns=False,
+    #logging_strategy="epoch",
+    fp16=train_dtype==torch.float32, #use fp16 mixed-precision only when the LoRA weights are kept in float32
+    max_grad_norm=1.0,
+    save_steps=10000000,
+    lr_scheduler_type="linear", #constant | linear
+)
+
+#Wrap the model to avoid accelerate issues
+class WrappedModel(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, *args, **kwargs):
+        return self.model.forward(*args, **kwargs)
+
+    def train(self):
+        self.model.train()
+
+    def eval(self):
+        self.model.eval()
+
+    def parameters(self):
+        return self.model.parameters()
+
+trainer = SFTTrainer(
+    model=WrappedModel(model),
+    tokenizer=tokenizer,
+    max_seq_length=max_tokens,
+    train_dataset=dataset,
+    eval_dataset=None,
+    peft_config=None,
+    args=training_args,
+    dataset_text_field="text",
+)
+
+model.is_parallelizable       = False
+trainer.is_model_parallel     = False
+trainer.place_model_on_device = False
+
+model.train()
+trainer.train()
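+
+#Optional (illustrative sketch, not part of the original script): persist/restore the trained adapters via the
+#PeftUtils helpers introduced in this PR; 'lora_weights.pt' is a hypothetical filename.
+PeftUtils.save_lora_weights(model, 'lora_weights.pt')
+#PeftUtils.load_lora_weights(model, 'lora_weights.pt')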
+
+#Prediction/Eval
+######################################################################################
+
+#from https://huggingface.co/spaces/evaluate-metric/perplexity/blob/main/perplexity.py
+def compute_perplexity_batched(model, tokenizer, predictions, encodings=None, batch_size=1, add_start_token=True, device='cuda', max_length=None):
+    if tokenizer.pad_token is None and batch_size > 1:
+        existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
+        # check that the model already has at least one special token defined
+        assert (len(existing_special_tokens) > 0), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
+        # assign one of the special tokens to also be the pad token
+        tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
+
+    if add_start_token and max_length:
+        # leave room for the BOS token to be added:
+        assert (tokenizer.bos_token is not None), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
+        max_tokenized_len = max_length - 1
+    else:
+        max_tokenized_len = max_length
+
+
+    if(encodings is None):
+        encodings = tokenizer(
+            predictions,
+            add_special_tokens=False,
+            padding=True,
+            truncation=True if max_tokenized_len else False,
+            max_length=max_tokenized_len,
+            return_tensors="pt",
+            return_attention_mask=True).to(device)
+
+    encoded_texts = encodings["input_ids"]
+    attn_masks    = encodings["attention_mask"]
+
+    # check that each input is long enough:
+    if add_start_token:
+        assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
+    else:
+        assert torch.all(
+            torch.ge(attn_masks.sum(1), 2)
+        ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."
+
+    ppls = []
+    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+
+    for start_index in tqdm(range(0, len(encoded_texts), batch_size)):
+        end_index     = min(start_index + batch_size, len(encoded_texts))
+        encoded_batch = encoded_texts[start_index:end_index]
+        attn_mask     = attn_masks[start_index:end_index]
+
+        if add_start_token:
+            bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
+            encoded_batch     = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
+            attn_mask         = torch.cat([torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1)
+
+        labels = encoded_batch
+
+        with torch.no_grad():
+            out_logits = model(encoded_batch, attention_mask=attn_mask).logits
+
+        shift_logits               = out_logits[..., :-1, :].contiguous()
+        shift_labels               = labels[..., 1:].contiguous()
+        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
+
+        perplexity_batch = torch.exp(
+            (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
+            / shift_attention_mask_batch.sum(1))
+
+        ppls += perplexity_batch.tolist()
+
+    return np.mean(ppls)
+
+
+
+tokenizer.add_bos_token = True
+tokenizer.add_eos_token = False
+model.eval()
+
+#Convert the LoRA weights to the same dtype as the model for faster inference
+PeftUtils.cast_lora_weights(model, dtype=torch.half)
+
+print('perplexity', compute_perplexity_batched(model=model, tokenizer=tokenizer, predictions=[s['text'] for s in dataset_val], batch_size=1, max_length=max_tokens))
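+
+#(Illustrative addition, not part of the original script) Quick qualitative check with greedy decoding,
+#assuming the standard transformers generate() API on the patched model; the prompt below is arbitrary.
+gen_inputs = tokenizer('### Human: Explain weight quantization in one sentence.\n### Assistant:', return_tensors="pt").to('cuda')
+with torch.no_grad():
+    gen_out = model.generate(**gen_inputs, max_new_tokens=64, do_sample=False)
+print(tokenizer.decode(gen_out[0], skip_special_tokens=True))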
diff --git a/hqq/core/optimize.py b/hqq/core/optimize.py
index 22e304a..6807bec 100755
--- a/hqq/core/optimize.py
+++ b/hqq/core/optimize.py
@@ -1,12 +1,150 @@
 #Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023
 #####################################################
-
 import torch
 import numpy as np
 
+#re-estimate the scale based on the inverse median
+def update_scale_inverse_median(W_f, scale, zero, axis, min_max):
+    scale_rng = 2e4
+    z_val     = 1e-4
+    delta     = 1e-2
+
+    W_q = torch.round(W_f*scale + zero).clamp(min_max[0], min_max[1])
+
+    #Correct zero to avoid W_q==zero
+    zero_c      = zero.clone()
+    zero_c_indx = torch.sum(1.*((W_q - zero)==0), axis=axis, keepdim=True)>0
+    zero_c[zero_c_indx] = zero_c[zero_c_indx] + delta
+
+    #Build scale tensor
+    W_f_c      = W_f.clone()
+    W_f_c_mask = torch.abs(W_f_c)=0, torch.abs(scale_shifted)<=z_val)] = z_val
+    scale_shifted[torch.logical_and(scale_shifted<0, torch.abs(scale_shifted)<=z_val)] = -z_val
+
+    err = torch.empty([N, n_clusters], dtype=dtype, device=device)
+    for i in range(N):
+        W_r    = (W_q - zero)/scale_shifted[i][None,:]
+        err[i] = torch.abs(W_f - W_r).mean(axis=axis, keepdim=True)
+
+    ind_r   = torch.argmin(err, axis=axis).to(torch.int32)
+    ind_c   = torch.arange(len(ind_r), dtype=torch.int32, device=device)
+    scale_b = scale_shifted[ind_r, ind_c]
+
+    return scale_b
+
+
+#Proximal solver || W - dequantize(quantize(W))||_p^p
 @torch.inference_mode()
-def optimize_weights_proximal(tensor, scale, zero, min_max, axis=0, device='cuda', opt_params={'lp_norm':0.7, 'beta':1e1, 'kappa':1.01, 'iters':20}, verbose=False):
+def optimize_weights_proximal_v2(tensor, scale, zero, min_max, axis=0, device='cuda',
+                                 opt_params={'lp_norm':0.7, 'beta':1e1, 'kappa':1.01, 'iters':20, 'tol':0., 'early_stop':True, 'scale_gridsearch':False},
+                                 verbose=False):
+
+    #Params
+    lp_norm    = max(opt_params['lp_norm'], 0.1)
+    beta       = opt_params['beta']
+    kappa      = opt_params['kappa']
+    iters      = opt_params['iters']
+    early_stop = opt_params['early_stop']
+    tol        = opt_params['tol']
+
+    #Check
+    assert lp_norm<=1., "lp_norm should be <=1"
+    assert beta>0.,     "beta should be > 0"
+    assert kappa>1.,    "kappa should be > 1"
+    assert iters>1,     "iters should be > 1"
+
+    #Cast/device
+    dtype = torch.float16 if (device=='cuda') else torch.float32
+    W_f   = tensor.to(dtype).to(device)
+    scale = scale.to(dtype).to(device)
+    zero  = zero.to(dtype).to(device)
+
+    if(lp_norm==1):
+        shrink_op = lambda x, beta: torch.sign(x)*torch.nn.functional.relu(torch.abs(x) - 1./beta)
+    else:
+        shrink_op = lambda x, beta, p=lp_norm: torch.sign(x)*torch.nn.functional.relu(torch.abs(x) - (1./beta)*torch.pow(torch.abs(x), p-1))
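+
+    #Note (descriptive summary of the loop below): the solver alternates between
+    #(1) quantizing with the current (scale, zero), (2) computing the reconstruction error W_f - W_r,
+    #(3) shrinking that error with the lp-norm proximal operator defined above, and (4) re-estimating
+    #the zero-point in closed form; beta is multiplied by kappa after every iteration.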
+
+    #Update scale: works slightly better. Tested on Llama2 only
+    if(opt_params['scale_gridsearch']):
+        scale = update_scale_grid_search(W_f, scale, zero, axis, min_max)
+
+    #Optimize for zero-point
+    best_error = 1e4
+    scale_prev, zero_prev = scale.clone(), zero.clone()
+    for i in range(iters):
+        W_q = torch.round(W_f*scale + zero).clamp(min_max[0], min_max[1])
+        W_r = (W_q - zero)/scale
+
+        #current_error = float(torch.pow(torch.abs(W_f - W_r), max(0.80, lp_norm)).mean())
+        current_error = float(torch.abs(W_f - W_r).mean())
+
+        if(verbose):
+            print(i, np.round(current_error, 6))
+
+        if(early_stop):
+            if(best_error - current_error > tol):
+                best_error = current_error
+                scale_prev, zero_prev = scale.clone(), zero.clone()
+            else:
+                scale, zero = scale_prev.clone(), zero_prev.clone()
+                break
+
+        W_e   = shrink_op(W_f - W_r, beta)
+        zero  = torch.mean(W_q - (W_f - W_e)*scale, axis=axis, keepdim=True)
+        beta *= kappa
+
+
+    #Clean-up
+    scale = scale.to(tensor.device)
+    zero  = zero.to(tensor.device)
+    del W_f, W_q, W_r, W_e, scale_prev, zero_prev
+    torch.cuda.empty_cache()
+
+    return scale, zero
+
+#Proximal solver || W - dequantize(quantize(W))||_p^p
+@torch.inference_mode()
+def optimize_weights_proximal_legacy(tensor, scale, zero, min_max, axis=0, device='cuda', opt_params={'lp_norm':0.7, 'beta':1e1, 'kappa':1.01, 'iters':20}, verbose=False):
     lp_norm, beta, kappa, iters = opt_params['lp_norm'], opt_params['beta'], opt_params['kappa'], opt_params['iters']
 
     dtype  = torch.float16 if (device=='cuda') else torch.float32
@@ -42,6 +180,9 @@ def optimize_weights_proximal(tensor, scale, zero, min_max, axis=0, device='cuda
 
     return scale, zero
 
+#optimize_weights_proximal = optimize_weights_proximal_legacy
+optimize_weights_proximal  = optimize_weights_proximal_v2
+
 #SGD solver || W - dequantize(quantize(W))||_1 (p=1 only)
 def optimize_weights_autograd(tensor, scale, zero, min_max, axis=0, device='cuda', opt_params={'lr':2e-3, 'iters':2500}, verbose=False):
     W_f = tensor.to(device)
diff --git a/hqq/core/peft.py b/hqq/core/peft.py
new file mode 100755
index 0000000..0e0f225
--- /dev/null
+++ b/hqq/core/peft.py
@@ -0,0 +1,251 @@
+import torch
+import numpy as np
+from .quantize import HQQLinear, HQQBackend, Quantizer
+
+def _get_dense_param(in_features, out_features, device='cuda', trainable=True, dtype=torch.bfloat16):
+    W = torch.nn.Linear(in_features, out_features, bias=None).weight.data.t().to(dtype).to(device).contiguous()
+    return torch.nn.Parameter(W, requires_grad=trainable)
+
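+#Descriptive note: HQQLinearLoRA wraps a (typically HQQ-quantized) linear layer with trainable low-rank adapters.
+#Its forward pass computes base_layer(x) + dropout(x) @ lora_A @ lora_B * (lora_alpha/r) + bias, where only
+#lora_A/lora_B are kept in 'train_dtype' and updated during training while the quantized base weights stay frozen.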
+class HQQLinearLoRA(torch.nn.Module):
+    def __init__(self, linear_layer, peft_config):
+        super().__init__()
+
+        #Device
+        self.device      = next(linear_layer.parameters()).device
+        self.train_dtype = peft_config['train_dtype'] if ('train_dtype' in peft_config) else torch.float
+
+        #Linear layer
+        self.linear_layer = linear_layer
+        self.in_features  = linear_layer.in_features
+        self.out_features = linear_layer.out_features
+        self.bias         = None if (linear_layer.bias is None) else linear_layer.bias.clone()
+
+        #Turn-off bias in the linear layer
+        self.linear_layer.bias = None
+
+        #Dropout
+        if('dropout' in peft_config):
+            self.peft_drop = torch.nn.Dropout(p=peft_config['dropout']) if (peft_config['dropout']>0.) else torch.nn.Identity()
+        else:
+            self.peft_drop = torch.nn.Identity()
+
+        #LoRA A/B
+        self.peft_config = peft_config
+        self.lora_alpha  = peft_config['lora_alpha']
+        self.r           = peft_config['r']
+        self.scaling     = self.lora_alpha/self.r
+        self.lora_A      = _get_dense_param(self.in_features, self.r,  device=self.device, trainable=True, dtype=self.train_dtype)
+        self.lora_B      = _get_dense_param(self.r, self.out_features, device=self.device, trainable=True, dtype=self.train_dtype)
+
+        #Init weights, as in the original LoRA implementation
+        torch.nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))
+        torch.nn.init.zeros_(self.lora_B)
+
+    def forward(self, x):
+        x_type = x.dtype
+
+        #Forward with base linear
+        out = self.linear_layer(x)
+
+        #LoRA
+        out += (torch.matmul(torch.matmul(self.peft_drop(x.to(self.lora_A.dtype)), self.lora_A), self.lora_B)*self.scaling).to(x_type)
+
+        #Bias
+        if(self.bias is not None):
+            out += self.bias
+
+        return out
+
+    def merge_and_quantize(self, quant_config):
+
+        #Get initial weights
+        W = self.linear_layer(torch.eye(self.in_features, device=self.device, dtype=torch.float16)).t()
+
+        #Merge weights
+        W += (torch.matmul(self.lora_A.data, self.lora_B.data).t()*self.scaling).to(W.dtype)
+
+        new_hqq_layer      = HQQLinear(None, quant_config)
+        new_hqq_layer.bias = None if (self.bias is None) else self.bias.clone()
+        new_hqq_layer.quantize(W, **quant_config)
+
+        return new_hqq_layer
+
+    def cast(self, dtype=torch.float16):
+        self.lora_A.data = self.lora_A.data.to(dtype)
+        self.lora_B.data = self.lora_B.data.to(dtype)
+        if(self.bias is not None):
+            if(self.bias.requires_grad):
+                self.bias.data = self.bias.data.to(dtype)
+            else:
+                self.bias = self.bias.to(dtype)
+        return self
+
+    def state_dict(self):
+        return {'lora_A':self.lora_A.data, 'lora_B':self.lora_B.data, 'scaling':self.scaling, 'bias':self.bias, 'peft_config':self.peft_config}
+
+    def load_state_dict(self, state_dict):
+        self.lora_A.data = state_dict['lora_A'].data.to(self.device)
+        self.lora_B.data = state_dict['lora_B'].data.to(self.device)
+        self.scaling     = state_dict['scaling']
+        self.bias        = state_dict['bias'] if ('bias' in state_dict) else None
+        self.bias        = self.bias.to(self.device) if (self.bias is not None) else None
+        self.peft_config = state_dict['peft_config']
+
+
+#LoRA with fake quantization
+class HQQLinearLoRAWithFakeQuant(HQQLinearLoRA):
+    def __init__(self, linear_layer, peft_config, quant_param):
+        super(HQQLinearLoRAWithFakeQuant, self).__init__(linear_layer, peft_config)
+        self.quant_param = quant_param
+
+    def fake_quant(self, weight):
+        if(self.quant_param):
+            W_q, meta  = Quantizer.quantize(weight, **self.quant_param, bitpack=False)
+            weight_est = Quantizer.dequantize(W_q, meta)
+        else:
+            weight_est = weight
+        return weight_est
+
+    def forward(self, x):
+        weight = self.linear_layer.dequantize() + (torch.matmul(self.lora_A, self.lora_B)*self.scaling).t()
+        weight = self.fake_quant(weight)
+        out    = torch.matmul(x, weight.t())
+
+        #Bias
+        if(self.bias is not None):
+            out += self.bias
+
+        return out
+
+
+_HQQ_LORA_CLASSES = [HQQLinearLoRA, HQQLinearLoRAWithFakeQuant]
+_HQQ_LORA_MAPPING = {'default':HQQLinearLoRA, 'lora_with_fakequant':HQQLinearLoRAWithFakeQuant}
+
+def is_hqq_lora_layer(layer):
+    return type(layer) in _HQQ_LORA_CLASSES
+
+##################################################################################################################
+def autoname_modules(model):
+    for name, module in model.named_modules():
+        module.name = name
+
+#Patching functions
+def patch_linear_add_peft(layer, patch_params):
+    _peft_config = patch_params
+    if(_peft_config):
+        lora_type = _peft_config['lora_type'] if ('lora_type' in _peft_config) else 'default'
+        new_layer = _HQQ_LORA_MAPPING[lora_type](layer, _peft_config)
+    else:
+        new_layer = layer
+    return new_layer
+
+def patch_linear_merge_peft(layer, patch_params):
+    _quant_config = patch_params
+    if(_quant_config):
+        new_layer = layer.merge_and_quantize(_quant_config)
+        del layer
+        cleanup()
+    else:
+        new_layer = layer
+    return new_layer
+
+def patch_linear_cast_peft(layer, patch_params):
+    if(is_hqq_lora_layer(layer)):
+        layer.cast(patch_params)
+    return layer
+
+#Putting it all together
+class PeftUtils:
+
+    @classmethod
+    def get_base_class(cls, model, base_class):
+        #Get base class
+        if((base_class is None) and hasattr(model, 'base_class')):
+            base_class = model.base_class
+
+        assert (base_class is not None), "You need to provide the base HQQ class (LlamaHQQ, MixtralHQQ, etc.) as model.base_class or as an argument base_class=LlamaHQQ"
+        return base_class
+
+    @classmethod
+    def add_lora(cls, model, lora_params, base_class=None, verbose=True):
+
+        #Base class
+        base_class = cls.get_base_class(model, base_class)
+
+        #Freeze
+        for param in model.parameters():
+            param.requires_grad = False
+
+        #Patch
+        base_class.patch_linearlayers(model, patch_linear_add_peft, lora_params, verbose=verbose)
+
+        #Rename modules
+        autoname_modules(model)
+
+        #Default backprop backend
+        HQQLinear.set_backend(HQQBackend.PYTORCH_BACKPROP)
+
+    @classmethod
+    def merge_lora(cls, model, merge_lora_params, base_class=None, verbose=True):
+        #Base class
+        base_class = cls.get_base_class(model, base_class)
+
+        #Patch
+        base_class.patch_linearlayers(model, patch_linear_merge_peft, merge_lora_params, verbose=verbose)
+
+    @classmethod
+    def cast_lora_weights(cls, model, dtype, base_class=None, verbose=True):
+        #Base class
+        base_class = cls.get_base_class(model, base_class)
+
+        #Linear tags
+        linear_tags = base_class.get_linear_tags()
+
+        #Patch
+        base_class.patch_linearlayers(model,
+                                      patch_linear_cast_peft,
+                                      dict([(linear_tag, dtype) for linear_tag in linear_tags]),
+                                      verbose=verbose)
+
+    @classmethod
+    def save_lora_weights(cls, model, filename, base_class=None, verbose=True):
+        #Base class
+        base_class = cls.get_base_class(model, base_class)
+
+        lora_global_params = {}
+        def _patch_linear_save_weights(layer, patch_params, return_layer=True):
+            if(is_hqq_lora_layer(layer)):
+                lora_global_params[layer.name] = layer.state_dict()
+            if(return_layer): return layer
+
+        #Linear tags
+        linear_tags = base_class.get_linear_tags()
+
+        #Patch
+        base_class.patch_linearlayers(model,
+                                      _patch_linear_save_weights,
+                                      dict([(linear_tag, None) for linear_tag in linear_tags]),
+                                      verbose=verbose)
+
+        #save
+        torch.save(lora_global_params, filename)
+
+    @classmethod
+    def load_lora_weights(cls, model, filename, base_class=None, verbose=True):
+        #Base class
+        base_class = cls.get_base_class(model, base_class)
+
+        lora_global_params = torch.load(filename, map_location='cpu')
+
+        def _patch_linear_load_weights(layer, patch_params, return_layer=True):
+            if(is_hqq_lora_layer(layer)):
+                layer.load_state_dict(lora_global_params[layer.name])
+            if(return_layer): return layer
+
+        #Linear tags
+        linear_tags = base_class.get_linear_tags()
+
+        #Patch
+        base_class.patch_linearlayers(model,
+                                      _patch_linear_load_weights,
+                                      dict([(linear_tag, None) for linear_tag in linear_tags]),
+                                      verbose=verbose)
diff --git a/hqq/core/quantize.py b/hqq/core/quantize.py
index 010f086..3aa99cc 100755
--- a/hqq/core/quantize.py
+++ b/hqq/core/quantize.py
@@ -27,7 +27,7 @@ class Quantizer:
                  '2bit_u8':BitPack.unpack_2bit_u8}
 
     @classmethod
-    def quantize(cls, tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0):
+    def quantize(cls, tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True):
         assert nbits in Quantizer.SUPPORTED_BITS, "nbits=" + str(nbits) + " not supported."
         assert axis in [0, 1], "axis should be either 0 or 1"
         if(group_size is not None):
@@ -69,7 +69,10 @@ def quantize(cls, tensor, nbits=4, channel_wise=True, group_size=64, optimize=Fa
         meta = {'nbits':nbits, 'group_size':group_size, 'shape':shape, 'scale':1./scale, 'zero':zero, 'axis':axis, 'packing':Quantizer.bit_to_packing[nbits]}
 
         #Pack bits
-        W_q = Quantizer.pack[meta['packing']](W_q)
+        if(bitpack):
+            W_q = Quantizer.pack[meta['packing']](W_q)
+        else:
+            meta['packing'] = None
 
         #cleanup
         del W, _min, _max
@@ -80,11 +83,13 @@ def quantize(cls, tensor, nbits=4, channel_wise=True, group_size=64, optimize=Fa
     #Main dequantization: bit_unpacking > (W_q - z)*s > reshape
     @classmethod
     def dequantize(cls, W_q, meta):
-        W_q_p = Quantizer.unpack[meta['packing']](W_q).half()
-        if((meta['group_size'] is not None) and (meta['nbits']==3)):
-            W_q_p = W_q_p[:meta['group_size']] if (meta['axis']==0) else W_q_p[:,:meta['group_size']]
-        W_r = ((W_q_p - meta['zero'])*meta['scale']).reshape(meta['shape'])
-        del W_q_p
+        if(meta['packing']):
+            W_r = Quantizer.unpack[meta['packing']](W_q).half()
+            if((meta['group_size'] is not None) and (meta['nbits']==3)):
+                W_r = W_r[:meta['group_size']] if (meta['axis']==0) else W_r[:,:meta['group_size']]
+        else:
+            W_r = W_q
+        W_r = ((W_r - meta['zero'])*meta['scale']).reshape(meta['shape'])
         return W_r
 
     @classmethod
@@ -119,15 +124,113 @@ def cpu(cls, W_q, meta):
 try:
     import hqq_aten
 except:
-    print(colored('hqq_aten package not installed. HQQBackend.ATEN backend will not work unless you install the hqq_aten lib in hqq/kernels.', 'cyan'))
+    #print(colored('hqq_aten package not installed. HQQBackend.ATEN backend will not work unless you install the hqq_aten lib in hqq/kernels.', 'cyan'))
+    hqq_aten = None
 
 from enum import Enum
 class HQQBackend(Enum):
     #Name of the forward functions
-    PYTORCH                  = "forward_pytorch"
-    PYTORCH_COMPILE          = "forward_pytorch_compile"
-    ATEN                     = "forward_aten"
+    PYTORCH                  = "forward_pytorch"
+    PYTORCH_COMPILE          = "forward_pytorch_compile"
+    PYTORCH_BACKPROP         = "forward_pytorch_backprop"
+    PYTORCH_BACKPROP_COMPILE = "forward_pytorch_backprop_compile"
+    ATEN                     = "forward_aten"
+
+
+#No cache: less memory, slower
+class HQQMatmulNoCacheDeq(torch.autograd.Function):
+
+    @staticmethod
+    def forward(x, dequantize, bias):
+        out = torch.matmul(x, dequantize().t())
+        if(bias!=None): out += bias
+        return out
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        x, dequantize, bias = inputs
+        ctx.save_for_backward(x, bias)
+        ctx.dequantize = dequantize
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, bias   = ctx.saved_tensors
+        dtype_out = grad_output.dtype
+
+        grad_input = grad_weight = grad_bias = None
+
+        if ctx.needs_input_grad[0]:
+            grad_input = torch.matmul(grad_output, ctx.dequantize())
+
+        # weight grad for frozen quantized weights not defined
+        # if ctx.needs_input_grad[1]:
+        #     grad_weight = torch.matmul(grad_output.t(), x)
+
+        if bias is not None and ctx.needs_input_grad[2]:
+            grad_bias = grad_output.sum(0)
+
+        return grad_input, grad_weight, grad_bias
+
+
+class HQQMatmulNoCacheMul(torch.autograd.Function):
+
+    @staticmethod
+    def forward(x, matmul, bias):
+        out = matmul(x, transpose=True)
+        if(bias!=None): out += bias
+        return out
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        x, matmul, bias = inputs
+        ctx.save_for_backward(x, bias)
+        ctx.matmul = matmul
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, bias = ctx.saved_tensors
+
+        grad_input = grad_weight = grad_bias = None
+
+        if ctx.needs_input_grad[0]:
+            grad_input = ctx.matmul(grad_output, transpose=False)
+
+        # weight grad for frozen quantized weights not defined
+        # if ctx.needs_input_grad[1]:
+        #     grad_weight = torch.matmul(grad_output.t(), x)
+
+        if bias is not None and ctx.needs_input_grad[2]:
+            grad_bias = grad_output.sum(0)
+
+        return grad_input, grad_weight, grad_bias
+
+#Cache dequantized tensor: Faster but needs more memory
+class HQQMatmulCachedDeq(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, hqq_layer, bias):
+        weight_tmp = hqq_layer.dequantize()
+        out        = torch.matmul(x, weight_tmp.t())
+        if(bias!=None): out += bias
+
+        ctx.save_for_backward(x, bias, weight_tmp)
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, bias, weight_tmp = ctx.saved_tensors
+
+        grad_input = grad_weight = grad_bias = None
+
+        if ctx.needs_input_grad[0]:
+            grad_input = torch.matmul(grad_output, weight_tmp)
+
+        del weight_tmp
+
+        if bias is not None and ctx.needs_input_grad[2]:
+            grad_bias = grad_output.sum(0)
+
+        return grad_input, grad_weight, grad_bias
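+
+#Summary (descriptive comment): the three autograd functions above only differ in how the dequantized
+#weight is obtained for the backward pass. The "NoCache" variants re-dequantize on the fly (less memory,
+#slower), while HQQMatmulCachedDeq reuses the dequantized weight saved during the forward pass (faster,
+#more memory). In all cases the quantized weight receives no gradient; only the input and bias do.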
 
 #Main linear layer
 class HQQLinear(torch.nn.Module):
@@ -137,14 +240,24 @@ def __init__(self, linear_layer, quant_config, del_orig=True):
         super().__init__()
         self.ready        = False
         self.in_gpu       = False
+        self.device       = None
+        self.bias         = None
         self.quant_config = quant_config
         self.set_backend(HQQLinear.backend) #Default backend
+        if(linear_layer is not None):
             self.bias = None if (linear_layer.bias==None) else linear_layer.bias.half().cuda()
             self.quantize(linear_layer.weight.data, **quant_config)
+        if(del_orig): del linear_layer
         torch.cuda.empty_cache()
-
+
+    #Set backends
+    @classmethod
+    def set_backend(cls, backend: HQQBackend):
+        HQQLinear.backend = backend
+        cls.forward = getattr(cls, backend.value)
+
     def cuda(self, device_n=0):
         if(self.in_gpu): return
         self.W_q, self.meta = Quantizer.cuda(self.W_q, self.meta, device_n)
@@ -156,12 +269,14 @@ def cuda(self, device_n=0):
         if(self.bias is not None):
             self.bias = self.bias.half().cuda(device_n)
 
+        self.W_q    = torch.nn.Parameter(self.W_q, requires_grad=False)
+        self.device = self.W_q.device
         self.in_gpu = True
 
-    def to(self, device):
+    def to(self, *args, **kwargs):
         pass
 
-    def half(self):
+    def half(self, *args, **kwargs):
         return self
 
     def state_dict(self):
@@ -175,9 +290,12 @@ def load_state_dict(self, state_dict):
         if(self.in_gpu==False): self.cuda()
         self.ready = True
 
+    @torch.inference_mode()
     def quantize(self, W, weight_quant_params, scale_quant_params, zero_quant_params):
         quant_scale = scale_quant_params is not None
         quant_zero  = zero_quant_params is not None
+
+        self.in_features, self.out_features = W.t().shape
 
         #Quantize
         W_q , meta = Quantizer.quantize(W, **weight_quant_params)
@@ -192,46 +310,50 @@ def quantize(self, W, weight_quant_params, scale_quant_params, zero_quant_params
         self.cuda()
         self.ready = True
 
-    @torch.inference_mode()
     def dequantize(self):
         assert self.ready, "model was not quantized"
         W_q, meta = self.W_q, self.meta
+
         del_keys = []
         if(meta['quant_scale']): meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale')
         if(meta['quant_zero']):  meta['zero']  = Quantizer.dequantize(meta['zero_q'],  meta['meta_zero']);  del_keys.append('zero')
+
         W_est = Quantizer.dequantize(W_q, meta)
+
         #Cleanup
         for key in del_keys: del meta[key]
         return W_est
 
-    @classmethod
-    def set_backend(cls, backend: HQQBackend):
-        HQQLinear.backend = backend
-        cls.forward = getattr(cls, HQQLinear.backend.value)
-
-    @torch.no_grad()
-    def forward_pytorch(self, x):
-        W_est = self.dequantize()
-        out   = torch.matmul(x, W_est.t())
-        if(self.bias!=None): out += self.bias
-        del W_est
+    def matmul(self, x, transpose=True):
+        weight = self.dequantize()
+        return torch.matmul(x, weight.t() if (transpose) else weight)
+
+    @torch.compile()
+    def matmul_compile(self, *args, **kwargs):
+        return self.matmul(*args, **kwargs)
+
+    def forward_pytorch_backprop(self, x):
+        return HQQMatmulNoCacheMul.apply(x, self.matmul, self.bias)
+
+    def forward_pytorch_backprop_compile(self, x):
+        return HQQMatmulNoCacheMul.apply(x, self.matmul_compile, self.bias)
+
+    def forward_pytorch(self, x):
+        out = torch.matmul(x, self.dequantize().t())
+        if(self.bias is not None):
+            out += self.bias
         return out
 
+    @torch.compile()
+    def forward_pytorch_compile(self, x):
+        return self.forward_pytorch(x)
+
     ##############################################
     #Experimental
     #############################################
-    @torch.no_grad()
-    @torch.compile()
-    def forward_pytorch_compile(self, x):
-        W_est = self.dequantize()
-        out   = torch.matmul(x, W_est.t())
-        if(self.bias!=None): out += self.bias
-        del W_est
-        return out
-
-    @torch.no_grad()
+    #Requires building the aten backend
     def forward_aten(self, x):
         empt = torch.empty([0])
         W_q  = self.W_q
@@ -274,7 +396,8 @@ def forward_aten(self, x):
 
 def hqq_base_quant_config(nbits=4, group_size=64, quant_zero=True, quant_scale=False):
     assert nbits in Quantizer.SUPPORTED_BITS, "nbits value not supported. Check Quantizer.SUPPORTED_BITS."
-    assert is_divisible(group_size, 8), "Invalid group_size param: the value should be a multiple of 8."
+    if(group_size is not None):
+        assert is_divisible(group_size, 8), "Invalid group_size param: the value should be a multiple of 8."
 
     weight_quant_params = {'nbits':nbits, 'channel_wise':True,  'group_size':group_size, 'optimize':True,  'round_zero':True if nbits==4 else False}
     scale_quant_params  = {'nbits':8,     'channel_wise':True,  'group_size':128,        'optimize':False} if (quant_scale) else None
     zero_quant_params   = {'nbits':8,     'channel_wise':False, 'group_size':None,       'optimize':False} if (quant_zero)  else None
diff --git a/hqq/engine/base.py b/hqq/engine/base.py
index 0acb533..ca00109 100755
--- a/hqq/engine/base.py
+++ b/hqq/engine/base.py
@@ -26,6 +26,7 @@ def _is_quantizable(cls, model):
     @classmethod
     def _make_quantizable(cls, model, quantized):
         model.hqq_quantized = quantized
+        model.base_class    = cls._get_hqq_class(model)
 
     @classmethod
     def _check_arch_support(cls, arg):
diff --git a/hqq/engine/hf.py b/hqq/engine/hf.py
index 525da73..6893fc4 100755
--- a/hqq/engine/hf.py
+++ b/hqq/engine/hf.py
@@ -34,6 +34,7 @@ def _make_quantizable(cls, model, quantized):
         model.to         = lambda *args, **kwargs: model if(quantized) else model.to
         model.float      = lambda *args, **kwargs: model if(quantized) else model.float
         model.half       = lambda *args, **kwargs: model if(quantized) else model.half
+        model.base_class = cls._get_hqq_class(model)
 
     #Force loading the model on CPU and unquantized
     @classmethod
diff --git a/hqq/engine/timm.py b/hqq/engine/timm.py
index 4c7712c..19e7510 100755
--- a/hqq/engine/timm.py
+++ b/hqq/engine/timm.py
@@ -32,6 +32,7 @@ def _make_quantizable(cls, model, quantized):
         model.to         = lambda *args, **kwargs: model if(quantized) else model.to
         model.float      = lambda *args, **kwargs: model if(quantized) else model.float
         model.half       = lambda *args, **kwargs: model if(quantized) else model.half
+        model.base_class = cls._get_hqq_class(model)
 
     @classmethod
     def _validate_params(cls, params:Dict):
diff --git a/setup.py b/setup.py
index e37836d..3e8a2e7 100755
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name='hqq',
-    version='0.1.1.post1',
+    version='0.1.2.alpha', #0.1.1.post1
     description='Half-Quadratic Quantization (HQQ)',
     url='https://github.com/mobiusml/hqq/',
     author='Dr. Hicham Badri',