From 16ba138650ff780b8393f48498c7339f52b872b8 Mon Sep 17 00:00:00 2001
From: moeiniamir
Date: Tue, 25 Jun 2024 10:06:37 +0200
Subject: [PATCH] utils/gptq refactor

---
 llmfoundry/utils/gptq.py | 153 +++++++++++++++++++++++++--------------
 1 file changed, 97 insertions(+), 56 deletions(-)

diff --git a/llmfoundry/utils/gptq.py b/llmfoundry/utils/gptq.py
index a43f4e4c9f..6fbae6ae69 100644
--- a/llmfoundry/utils/gptq.py
+++ b/llmfoundry/utils/gptq.py
@@ -4,23 +4,43 @@
 from pathlib import Path
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 from composer.utils import reproducibility
-import torch
 import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "5"
-
+from auto_gptq import AutoGPTQForCausalLM
+import transformers
+
+
+def vmware_open_instruct_prepreprocess(example):
+    return {"messages": [
+        {
+            'role': 'user',
+            'content': example['instruction']
+        },
+        {
+            "role": "assistant",
+            'content': example['response']
+        }
+    ]}
+
+PREPREPROCESS_MAP = {
+    "VMware/open-instruct": vmware_open_instruct_prepreprocess
+}
 
 def get_calibration_data(cfg: omegaconf.DictConfig):
     MAX_SEQ_LEN = cfg.get("max_seq_len", 8192)
     NUM_EXAMPLES = cfg.get("num_examples", 512)
     MODEL_ID = cfg.get("model_id")
     DATASET = cfg.get("dataset", "HuggingFaceH4/ultrachat_200k")
+    SPLIT = cfg.get("split", "train_sft")
+
+    prepreprocess = PREPREPROCESS_MAP.get(DATASET, lambda *_, **__: None)
 
     def preprocess(example):
         return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
 
-    dataset = load_dataset(DATASET, split="train_sft")
+    dataset = load_dataset(DATASET, split=SPLIT)
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
     ds = dataset.shuffle().select(range(NUM_EXAMPLES))
+    ds = ds.map(prepreprocess)
     ds = ds.map(preprocess)
 
     examples = [
@@ -56,13 +76,15 @@ def chunk_examples(examples):
     return examples
 
 # %%
-def quantize_and_save(cfg: omegaconf.DictConfig, examples, path, **kwargs):
+def quantize(cfg: omegaconf.DictConfig, examples, **kwargs):
     quantize_config = BaseQuantizeConfig(
         bits=cfg.wbits,
         group_size=cfg.gs,
         desc_act=cfg.actorder,
         clip=cfg.get('clip', False),
         mse=cfg.get('mse', False),
+        sym=cfg.get('sym', True),
+        static_groups=cfg.get('static_groups', False),
     )
 
     model = AutoGPTQForCausalLM.from_pretrained(
@@ -71,86 +93,105 @@ def quantize_and_save(cfg: omegaconf.DictConfig, examples, path, **kwargs):
         **kwargs)
 
     model.quantize(examples)
-    if cfg.get('no_save', False):
-        return model
-    print(f"Saving gptq model to {path}")
-    model.save_quantized(path)
-
-    import gc
-    del model
-    gc.collect()
-    return None
+    return model
 
 
-def get_llama_marlin_factory(cfg, original_from_pretrained=None):
-    def get_llama_marlin(*args, **kwargs):
-        if original_from_pretrained:
-            import transformers
-            transformers.AutoModelForCausalLM.from_pretrained = original_from_pretrained
-        from auto_gptq import AutoGPTQForCausalLM
-
-        gptq_save_dir = cfg.get('gptq_save_dir')
-        if gptq_save_dir:
-            gptq_save_dir = Path(gptq_save_dir)
+def should_quantize(cfg: omegaconf.DictConfig):
+    if cfg.get('no_save', False):
+        print('no_save flag is set, not loading or saving the newly quantized model')
+        return True, None
+
+    gptq_save_dir = cfg.get('gptq_save_dir')
+    if gptq_save_dir:
+        if Path(gptq_save_dir).is_dir():
+            print(f"found gptq model at {gptq_save_dir}")
+            return False, Path(gptq_save_dir)
         else:
+            raise FileNotFoundError(f'{gptq_save_dir} not found, please provide a valid directory')
+    else:
+        try:
             gptq_save_dir = Path.home() / "saved" / \
                 f"{cfg.model_id.split('/')[-1]}-gptq{cfg.wbits}-{cfg.gs}-{cfg.actorder}"
             if cfg.get('suffix'):
                 gptq_save_dir = gptq_save_dir.with_name(
                     gptq_save_dir.name + f"-{cfg.get('suffix')}")
+        except:
+            raise ValueError("not enough information to determine if quantization is needed")
+        if Path(gptq_save_dir).is_dir():
+            print(f"found gptq model at the derived {gptq_save_dir}")
+            return False, Path(gptq_save_dir)
+        else:
+            print(f"no gptq model found at the derived {gptq_save_dir}")
+            return True, Path(gptq_save_dir)
+
 
-        # check if checkpoint exists as a folder
-        if not gptq_save_dir.is_dir() or cfg.get('no_save', False):
-            if cfg.get('no_save', False):
-                assert not cfg.get('use_marlin', False), "use_marlin=True is not supported with no_save=True"
-            print(f'{gptq_save_dir} not found, creating...')
-            print(cfg)
+def get_llama_marlin_factory(cfg, original_from_pretrained=None):
+    def get_llama_marlin(*args, **kwargs):
+        if original_from_pretrained:
+            transformers.AutoModelForCausalLM.from_pretrained = original_from_pretrained
+
+        assert not (cfg.get('use_marlin', False) and cfg.get('no_save', False)), "no_save and use_marlin cannot be used together"
+
+        _should_quantize, potential_save_dir = should_quantize(cfg)
+
+        if _should_quantize:
+            print('creating gptq model from following config:\n', cfg)
             if cfg.get('chunked', False):
                 calibration_data = get_custom_calibration_data(cfg)
             else:
                 calibration_data = get_calibration_data(cfg)
-            model = quantize_and_save(cfg, calibration_data, gptq_save_dir, **kwargs)
+            model = quantize(cfg, calibration_data, **kwargs)
+
             if cfg.get('no_save', False):
                 return model
-
-        print(f"Loading gptq model from {gptq_save_dir}")
+
+            print(f"saving gptq model to {potential_save_dir}")
+            model.save_quantized(potential_save_dir)
+            del model
+            import gc
+            gc.collect()
+
+        print(f"Loading gptq model from {potential_save_dir}")
         model = AutoGPTQForCausalLM.from_quantized(
-            gptq_save_dir, use_marlin=cfg.get('use_marlin', False), **kwargs)
+            potential_save_dir, use_marlin=cfg.get('use_marlin', False), **kwargs)
         return model
     return get_llama_marlin
 
 
-if __name__ == "!__main__":
+if __name__ == "__main__":
+    # os.environ["CUDA_VISIBLE_DEVICES"] = "7"
     from omegaconf import OmegaConf
     cfg = OmegaConf.from_cli()
     cfg_default = OmegaConf.create(
         """
-        model_id: meta-llama/Meta-Llama-3-8B-Instruct
-        wbits: 4
-        gs: 128
-        actorder: True
-        # no_save: True
-        # clip: True
-        seed: 1
-        suffix: seed${seed}-clip
-        # dataset: /nfs/scistore19/alistgrp/amoeini/group_10_merged.txt
-        # chunked: true
-        # chunk_size: 1024
+model_id: meta-llama/Meta-Llama-3-8B-Instruct
+wbits: 4
+gs: 128
+actorder: True
+no_save: True
+clip: True
+mse: 1
+seed: 1
+# suffix: seed${seed}-hqq
+# sym: False
+# dataset: VMware/open-instruct
+# split: train
+# dataset: /nfs/scistore19/alistgrp/amoeini/group_10_merged.txt
+# chunked: true
+# chunk_size: 1024
+python_log_level: INFO
         """
     )
     cfg = OmegaConf.merge(cfg_default, cfg)
    if cfg.get('seed'):
-        cfg.suffix = f"seed{cfg.seed}"
         reproducibility.seed_all(cfg.seed)
-    import os
-    if os.environ.get("DEBUG"):
-        import logging
-        logging.basicConfig(level=logging.DEBUG)
+    import logging
+    logging.basicConfig(level=cfg.get('python_log_level', 'WARNING'))
 
     model = get_llama_marlin_factory(cfg)()
 
-if __name__ == "__main__":
-    from transformers import AutoModelForCausalLM, GPTQConfig
-    gptq_config = GPTQConfig(bits=4, use_exllama=False)
-    model1 = AutoModelForCausalLM.from_pretrained("ISTA-DASLab/Llama-3-8B-Instruct-GPTQ-4bit", device_map="auto", quantization_config=gptq_config)
-    pass
\ No newline at end of file
+# if __name__ == "__main__":
+#     from transformers import AutoModelForCausalLM, GPTQConfig
#     gptq_config = GPTQConfig(bits=4, use_exllama=False)
+#     model1 = AutoModelForCausalLM.from_pretrained("/nfs/scistore19/alistgrp/amoeini/saved/Meta-Llama-3-8B-Instruct-gptq4-128-True-seed1-clip", device_map="auto", quantization_config=gptq_config)
+#     pass
\ No newline at end of file
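
A minimal sketch of how another instruction-style dataset could be plugged into the PREPREPROCESS_MAP hook this patch introduces, assuming the file is importable as llmfoundry.utils.gptq; the dataset name and column names ("my-org/my-sft-data", "prompt", "completion") are hypothetical, and only VMware/open-instruct is registered in the patch itself.

from llmfoundry.utils.gptq import PREPREPROCESS_MAP


def my_sft_prepreprocess(example):
    # Reshape a prompt/completion row into the chat "messages" layout that
    # get_calibration_data later passes to tokenizer.apply_chat_template.
    return {"messages": [
        {"role": "user", "content": example["prompt"]},
        {"role": "assistant", "content": example["completion"]},
    ]}


# get_calibration_data looks cfg.dataset up in PREPREPROCESS_MAP, so registering
# the adapter is enough before running with dataset=my-org/my-sft-data.
PREPREPROCESS_MAP["my-org/my-sft-data"] = my_sft_prepreprocess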
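
A second sketch, of reloading a checkpoint produced by this script without going through get_llama_marlin_factory, assuming the default location derived in should_quantize (~/saved/<model>-gptq<wbits>-<gs>-<actorder>[-<suffix>]) filled in with the default config and a seed1-clip suffix; use_marlin=False and device_map="auto" are illustrative values, not something the patch fixes.

from pathlib import Path

from auto_gptq import AutoGPTQForCausalLM

# Same naming scheme as should_quantize: <model_id tail>-gptq<wbits>-<gs>-<actorder>-<suffix>.
save_dir = Path.home() / "saved" / "Meta-Llama-3-8B-Instruct-gptq4-128-True-seed1-clip"

# Mirrors the AutoGPTQForCausalLM.from_quantized call in get_llama_marlin.
model = AutoGPTQForCausalLM.from_quantized(str(save_dir), use_marlin=False, device_map="auto")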