From b5b739b63752c4dd2603908ef66ee526821cc885 Mon Sep 17 00:00:00 2001
From: Apurva Jain
Date: Thu, 9 Jan 2025 15:13:06 -0800
Subject: [PATCH] Skip tests on fbcode

Differential Revision: D67982501

Pull Request resolved: https://github.com/pytorch/ao/pull/1532
---
 test/quantization/test_gptq_mt.py | 171 ++++++++++++++++--------------
 1 file changed, 90 insertions(+), 81 deletions(-)

diff --git a/test/quantization/test_gptq_mt.py b/test/quantization/test_gptq_mt.py
index 387293d5de..5d4e73ed61 100644
--- a/test/quantization/test_gptq_mt.py
+++ b/test/quantization/test_gptq_mt.py
@@ -3,11 +3,16 @@
 import pytest
 import torch
 import torch.nn.functional as F
+from torch.testing._internal.common_utils import run_tests
 from torchao._models.llama.model import Transformer, prepare_inputs_for_model
 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
 from torchao.quantization.utils import _lm_eval_available
+from torchao.utils import is_fbcode
+
+if is_fbcode():
+    pytest.skip("Skipping the test in fbcode due to missing model and tokenizer files")
 
 if _lm_eval_available:
     hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")
@@ -247,88 +252,92 @@ def run_eval(self, tasks, limit):
         return result
 
 
-precision = torch.bfloat16
-device = "cuda"
-print("Loading model")
-checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
-model = Transformer.from_name(checkpoint_path.parent.name)
-checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-model.load_state_dict(checkpoint, assign=True)
-model = model.to(dtype=precision, device="cpu")
-model.eval()
-print("Model loaded")
-tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-assert tokenizer_path.is_file(), tokenizer_path
-tokenizer = get_tokenizer(  # pyre-ignore[28]
-    tokenizer_path,
-    "Llama-2-7b-chat-hf",
-)
-print("Tokenizer loaded")
-
-
-blocksize = 128
-percdamp = 0.01
-groupsize = 64
-calibration_tasks = ["wikitext"]
-calibration_limit = None
-calibration_seq_length = 100
-input_prep_func = prepare_inputs_for_model
-pad_calibration_inputs = False
-print("Recording inputs")
-inputs = (
-    InputRecorder(
-        tokenizer,
-        calibration_seq_length,
-        input_prep_func,
-        pad_calibration_inputs,
-        model.config.vocab_size,
-        device="cpu",
-    )
-    .record_inputs(
-        calibration_tasks,
-        calibration_limit,
-    )
-    .get_inputs()
-)
-print("Inputs recorded")
-quantizer = Int4WeightOnlyGPTQQuantizer(
-    blocksize,
-    percdamp,
-    groupsize,
-)
-
-model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
-multi = [
-    MultiTensor([inp for inp, _ in inputs]),
-    MultiTensor([inds for _, inds in inputs]),
-]
-print("Quantizing model")
-model = quantizer.quantize(model, multi).cuda()
-print("Model quantized")
-print("Saving model and fixing state dict")
-regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
-for key, value in model.state_dict().items():
-    if isinstance(value, MultiTensor):
-        regular_state_dict[key] = value.values[0]
-    else:
-        regular_state_dict[key] = value
-
-model = Transformer.from_name(checkpoint_path.parent.name)
-remove = [k for k in regular_state_dict if "kv_cache" in k]
-for k in remove:
-    del regular_state_dict[k]
-
-model.load_state_dict(regular_state_dict, assign=True)
-torch.save(model.state_dict(), "model.pth")
-print("Running evaluation")
-result = TransformerEvalWrapper(
-    model.to(device),  # quantized model needs to run on cuda
-    tokenizer,
-    model.config.block_size,
-    prepare_inputs_for_model,
-).run_eval(
-    ["wikitext"],
-    None,
-)
+def test_gptq_mt():
+    precision = torch.bfloat16
+    device = "cuda"
+    print("Loading model")
+    checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+    model = Transformer.from_name(checkpoint_path.parent.name)
+    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    model.load_state_dict(checkpoint, assign=True)
+    model = model.to(dtype=precision, device="cpu")
+    model.eval()
+    print("Model loaded")
+    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+    assert tokenizer_path.is_file(), tokenizer_path
+    tokenizer = get_tokenizer(  # pyre-ignore[28]
+        tokenizer_path,
+        "Llama-2-7b-chat-hf",
+    )
+    print("Tokenizer loaded")
+
+    blocksize = 128
+    percdamp = 0.01
+    groupsize = 64
+    calibration_tasks = ["wikitext"]
+    calibration_limit = None
+    calibration_seq_length = 100
+    input_prep_func = prepare_inputs_for_model
+    pad_calibration_inputs = False
+    print("Recording inputs")
+    inputs = (
+        InputRecorder(
+            tokenizer,
+            calibration_seq_length,
+            input_prep_func,
+            pad_calibration_inputs,
+            model.config.vocab_size,
+            device="cpu",
+        )
+        .record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        )
+        .get_inputs()
+    )
+    print("Inputs recorded")
+    quantizer = Int4WeightOnlyGPTQQuantizer(
+        blocksize,
+        percdamp,
+        groupsize,
+    )
+
+    model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+    multi = [
+        MultiTensor([inp for inp, _ in inputs]),
+        MultiTensor([inds for _, inds in inputs]),
+    ]
+    print("Quantizing model")
+    model = quantizer.quantize(model, multi).cuda()
+    print("Model quantized")
+    print("Saving model and fixing state dict")
+    regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
+    for key, value in model.state_dict().items():
+        if isinstance(value, MultiTensor):
+            regular_state_dict[key] = value.values[0]
+        else:
+            regular_state_dict[key] = value
+
+    model = Transformer.from_name(checkpoint_path.parent.name)
+    remove = [k for k in regular_state_dict if "kv_cache" in k]
+    for k in remove:
+        del regular_state_dict[k]
+
+    model.load_state_dict(regular_state_dict, assign=True)
+    torch.save(model.state_dict(), "model.pth")
+    print("Running evaluation")
+    TransformerEvalWrapper(
+        model.to(device),  # quantized model needs to run on cuda
+        tokenizer,
+        model.config.block_size,
+        prepare_inputs_for_model,
+    ).run_eval(
+        ["wikitext"],
+        None,
+    )
+
+
+if __name__ == "__main__":
+    run_tests()
 
 # wikitext: {'word_perplexity,none': 12.523175352665858, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6042723245990418, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.681919059499152, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
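
Note on the module-level skip: when this file is collected by plain pytest, a bare pytest.skip() at module scope is rejected with a collection error; pytest only permits module-level skips when allow_module_level=True is passed. A minimal sketch of the same fbcode guard written to survive pytest collection (is_fbcode is the torchao.utils helper used in the patch; the skip message is unchanged):

    import pytest

    from torchao.utils import is_fbcode

    # Skip collection of this entire module inside fbcode, where the Llama
    # checkpoint and tokenizer files the test loads are not available.
    if is_fbcode():
        pytest.skip(
            "Skipping the test in fbcode due to missing model and tokenizer files",
            allow_module_level=True,
        )

A per-test alternative is decorating the test with @pytest.mark.skipif(is_fbcode(), reason="..."), which keeps the module importable and marks each test individually instead of skipping the whole file.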