From 5b8c3083365fdf7ca3616c8375fef52ddfd37a59 Mon Sep 17 00:00:00 2001
From: Apurva Jain
Date: Tue, 7 Jan 2025 10:44:27 -0800
Subject: [PATCH] Move test file to right location (#1503)

---
 .../quantization/test_gptq_mt.py | 105 ++++++++----------
 1 file changed, 49 insertions(+), 56 deletions(-)
 rename test_gptq_mt.py => test/quantization/test_gptq_mt.py (87%)

diff --git a/test_gptq_mt.py b/test/quantization/test_gptq_mt.py
similarity index 87%
rename from test_gptq_mt.py
rename to test/quantization/test_gptq_mt.py
index 7b15f55428..387293d5de 100644
--- a/test_gptq_mt.py
+++ b/test/quantization/test_gptq_mt.py
@@ -1,19 +1,18 @@
+from pathlib import Path
 
-import unittest
+import pytest
 import torch
-import os
-from pathlib import Path
-from torchao._models.llama.tokenizer import get_tokenizer
-from torchao._models.llama.model import Transformer, prepare_inputs_for_model
-from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
-import sys
-from safetensors.torch import load_file # Import safetensors loader
 import torch.nn.functional as F
 
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.tokenizer import get_tokenizer
+from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
 from torchao.quantization.utils import _lm_eval_available
+
 if _lm_eval_available:
-
+    hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")
     import lm_eval
+
     try:  # lm_eval version 0.4
         from lm_eval.evaluator import evaluate
         from lm_eval.models.huggingface import HFLM as eval_wrapper
@@ -49,8 +48,7 @@ def __init__(
         self.calibration_seq_length = calibration_seq_length
 
         self.input_prep_func = (
-            input_prep_func if input_prep_func is not None
-            else lambda x: (x,)
+            input_prep_func if input_prep_func is not None else lambda x: (x,)
         )
 
         self.pad_calibration_inputs = pad_calibration_inputs
@@ -164,13 +162,9 @@ class TransformerEvalWrapper(InputRecorder):
     """
     A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
     """
+
     def __init__(
-        self,
-        model,
-        tokenizer,
-        max_seq_length,
-        input_prep_func=None,
-        device="cuda"
+        self, model, tokenizer, max_seq_length, input_prep_func=None, device="cuda"
     ):
         super().__init__(tokenizer, None)
         self._model = model
@@ -181,23 +175,22 @@ def __init__(
         # need to take inps and convert to corrent input
        # for model
         self.input_prep_func = (
-            input_prep_func if input_prep_func is not None
-            else lambda x: (x,)
+            input_prep_func if input_prep_func is not None else lambda x: (x,)
         )
 
     def _model_call(self, inps):
         # print("Entering _model_call")
         # print(f"Input shape: {inps.shape}")
-        
+
         input = self.input_prep_func(inps)
         # print(f"Processed input shapes: {[x.shape for x in input]}")
-        
+
         input = [x.to(self._device) for x in input]
         # print(f"Inputs moved to device: {self._device}")
-        
+
         max_seq_length = min(max(inps.size()), self.max_length)
         # print(f"Max sequence length: {max_seq_length}")
-        
+
         # print("Setting up caches")
         with torch.device(self._device):
             # print(f"Device: {self._device}")
@@ -205,17 +198,15 @@ def _model_call(self, inps):
             # print(f"Max sequence length: {max_seq_length}")
             self._model.setup_caches(self.batch_size, max_seq_length)
         # print("Caches set up")
-        
+
         # print("Running model")
         # torch.save(input, "input.pt")
         logits = self._model(*input)
         # print(f"Model run complete. Logits shape: {logits.shape}")
         return logits
-    
-
+
     def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception('unimplemented')
+        raise Exception("unimplemented")
 
     def run_eval(self, tasks, limit):
         logger.info(f"Starting evaluation on tasks: {tasks}")
@@ -238,26 +229,21 @@ def run_eval(self, tasks, limit):
 
         logger.info("Starting evaluation")
         start_time = time.time()
-        
+
         try:
             with torch.no_grad():
-                result = evaluate(
-                    self,
-                    task_dict,
-                    limit=limit,
-                    verbosity= "DEBUG"
-                )
+                result = evaluate(self, task_dict, limit=limit, verbosity="DEBUG")
         except Exception as e:
             logger.error(f"Evaluation failed: {e}")
             raise
-        
+
         end_time = time.time()
         logger.info(f"Evaluation completed in {end_time - start_time:.2f} seconds")
 
         logger.info("Evaluation results:")
         for task, res in result["results"].items():
             print(f"{task}: {res}")
-        
+
         return result
@@ -289,34 +275,41 @@ def run_eval(self, tasks, limit):
 input_prep_func = prepare_inputs_for_model
 pad_calibration_inputs = False
 print("Recording inputs")
-inputs = InputRecorder(
+inputs = (
+    InputRecorder(
         tokenizer,
         calibration_seq_length,
         input_prep_func,
         pad_calibration_inputs,
         model.config.vocab_size,
         device="cpu",
-    ).record_inputs(
+    )
+    .record_inputs(
         calibration_tasks,
         calibration_limit,
-    ).get_inputs()
+    )
+    .get_inputs()
+)
 print("Inputs recorded")
 
 quantizer = Int4WeightOnlyGPTQQuantizer(
-        blocksize,
-        percdamp,
-        groupsize,
-    )
-
+    blocksize,
+    percdamp,
+    groupsize,
+)
+
 model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
-multi = [MultiTensor([ inp for inp, _ in inputs]),MultiTensor([ inds for _, inds in inputs])]
+multi = [
+    MultiTensor([inp for inp, _ in inputs]),
+    MultiTensor([inds for _, inds in inputs]),
+]
 print("Quantizing model")
 model = quantizer.quantize(model, multi).cuda()
 print("Model quantized")
 print("Saving model and fixing state dict")
-regular_state_dict = model.state_dict()#defaultdict(torch.tensor)
+regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
 for key, value in model.state_dict().items():
     if isinstance(value, MultiTensor):
-        regular_state_dict[key] = value.values[0]
+        regular_state_dict[key] = value.values[0]
     else:
         regular_state_dict[key] = value
@@ -326,16 +319,16 @@ def run_eval(self, tasks, limit):
         del regular_state_dict[k]
 
 model.load_state_dict(regular_state_dict, assign=True)
-torch.save(model.state_dict(), 'model.pth')
+torch.save(model.state_dict(), "model.pth")
 print("Running evaluation")
 result = TransformerEvalWrapper(
-        model.to(device), # quantized model needs to run on cuda
-        tokenizer,
-        model.config.block_size,
-        prepare_inputs_for_model,
-    ).run_eval(
-        ["wikitext"],
-        None,
-    )
+    model.to(device),  # quantized model needs to run on cuda
+    tokenizer,
+    model.config.block_size,
+    prepare_inputs_for_model,
+).run_eval(
+    ["wikitext"],
+    None,
+)
 # wikitext: {'word_perplexity,none': 12.523175352665858, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6042723245990418, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.681919059499152, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
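
Note: with this patch the module lives under test/quantization/ next to the other quantization tests. A minimal sketch of a direct invocation is below (the snippet is an illustration, not part of the patch): it assumes pytest plus the optional lm_eval and hqq dependencies and a CUDA device, and, since the file runs the GPTQ MultiTensor quantization and wikitext evaluation at module level, even importing it is a long-running operation.

    # Hypothetical direct run of the relocated test module; the path is the one
    # introduced by this patch. Assumes pytest, lm_eval, hqq, and a CUDA device.
    # The module does its quantization work at import time, so expect a long run.
    import pytest

    raise SystemExit(pytest.main(["-v", "test/quantization/test_gptq_mt.py"]))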