Move test file to right location (#1503)
jainapurva authored Jan 7, 2025
1 parent de3f812 commit 5b8c308
Showing 1 changed file with 49 additions and 56 deletions.
105 changes: 49 additions & 56 deletions test_gptq_mt.py → test/quantization/test_gptq_mt.py
@@ -1,19 +1,18 @@
+from pathlib import Path
+
 import unittest
 import pytest
 import torch
-import os
-from pathlib import Path
-from torchao._models.llama.tokenizer import get_tokenizer
-from torchao._models.llama.model import Transformer, prepare_inputs_for_model
-from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
-import sys
 from safetensors.torch import load_file # Import safetensors loader
 import torch.nn.functional as F
+
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.tokenizer import get_tokenizer
+from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
 from torchao.quantization.utils import _lm_eval_available
 
 if _lm_eval_available:
-
     hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")
     import lm_eval
 
     try:  # lm_eval version 0.4
         from lm_eval.evaluator import evaluate
         from lm_eval.models.huggingface import HFLM as eval_wrapper
@@ -49,8 +48,7 @@ def __init__(
         self.calibration_seq_length = calibration_seq_length
 
         self.input_prep_func = (
-            input_prep_func if input_prep_func is not None
-            else lambda x: (x,)
+            input_prep_func if input_prep_func is not None else lambda x: (x,)
         )
 
         self.pad_calibration_inputs = pad_calibration_inputs
@@ -164,13 +162,9 @@ class TransformerEvalWrapper(InputRecorder):
     """
    A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
     """
+
     def __init__(
-        self,
-        model,
-        tokenizer,
-        max_seq_length,
-        input_prep_func=None,
-        device="cuda"
+        self, model, tokenizer, max_seq_length, input_prep_func=None, device="cuda"
     ):
         super().__init__(tokenizer, None)
         self._model = model
@@ -181,41 +175,38 @@ def __init__(
         # need to take inps and convert to corrent input
         # for model
         self.input_prep_func = (
-            input_prep_func if input_prep_func is not None
-            else lambda x: (x,)
+            input_prep_func if input_prep_func is not None else lambda x: (x,)
         )
 
     def _model_call(self, inps):
         # print("Entering _model_call")
         # print(f"Input shape: {inps.shape}")
 
         input = self.input_prep_func(inps)
         # print(f"Processed input shapes: {[x.shape for x in input]}")
 
         input = [x.to(self._device) for x in input]
         # print(f"Inputs moved to device: {self._device}")
 
         max_seq_length = min(max(inps.size()), self.max_length)
         # print(f"Max sequence length: {max_seq_length}")
 
         # print("Setting up caches")
         with torch.device(self._device):
             # print(f"Device: {self._device}")
             # print(f"Batch size: {self.batch_size}")
             # print(f"Max sequence length: {max_seq_length}")
             self._model.setup_caches(self.batch_size, max_seq_length)
         # print("Caches set up")
 
         # print("Running model")
         # torch.save(input, "input.pt")
         logits = self._model(*input)
         # print(f"Model run complete. Logits shape: {logits.shape}")
         return logits
 
-
-
     def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception('unimplemented')
+        raise Exception("unimplemented")
 
     def run_eval(self, tasks, limit):
         logger.info(f"Starting evaluation on tasks: {tasks}")
@@ -238,26 +229,21 @@ def run_eval(self, tasks, limit):
 
         logger.info("Starting evaluation")
         start_time = time.time()
 
         try:
             with torch.no_grad():
-                result = evaluate(
-                    self,
-                    task_dict,
-                    limit=limit,
-                    verbosity= "DEBUG"
-                )
+                result = evaluate(self, task_dict, limit=limit, verbosity="DEBUG")
         except Exception as e:
             logger.error(f"Evaluation failed: {e}")
             raise
 
         end_time = time.time()
         logger.info(f"Evaluation completed in {end_time - start_time:.2f} seconds")
 
         logger.info("Evaluation results:")
         for task, res in result["results"].items():
             print(f"{task}: {res}")
 
         return result
 
 
@@ -289,34 +275,41 @@ def run_eval(self, tasks, limit):
 input_prep_func = prepare_inputs_for_model
 pad_calibration_inputs = False
 print("Recording inputs")
-inputs = InputRecorder(
+inputs = (
+    InputRecorder(
         tokenizer,
         calibration_seq_length,
         input_prep_func,
         pad_calibration_inputs,
         model.config.vocab_size,
         device="cpu",
-    ).record_inputs(
+    )
+    .record_inputs(
         calibration_tasks,
         calibration_limit,
-    ).get_inputs()
+    )
+    .get_inputs()
+)
 print("Inputs recorded")
 quantizer = Int4WeightOnlyGPTQQuantizer(
-        blocksize,
-        percdamp,
-        groupsize,
-    )
+    blocksize,
+    percdamp,
+    groupsize,
+)
 
 model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
-multi = [MultiTensor([ inp for inp, _ in inputs]),MultiTensor([ inds for _, inds in inputs])]
+multi = [
+    MultiTensor([inp for inp, _ in inputs]),
+    MultiTensor([inds for _, inds in inputs]),
+]
 print("Quantizing model")
 model = quantizer.quantize(model, multi).cuda()
 print("Model quantized")
 print("Saving model and fixing state dict")
-regular_state_dict = model.state_dict()#defaultdict(torch.tensor)
+regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
 for key, value in model.state_dict().items():
     if isinstance(value, MultiTensor):
-            regular_state_dict[key] = value.values[0]
+        regular_state_dict[key] = value.values[0]
     else:
         regular_state_dict[key] = value
 
@@ -326,16 +319,16 @@ def run_eval(self, tasks, limit):
         del regular_state_dict[k]
 
 model.load_state_dict(regular_state_dict, assign=True)
-torch.save(model.state_dict(), 'model.pth')
+torch.save(model.state_dict(), "model.pth")
 print("Running evaluation")
 result = TransformerEvalWrapper(
-        model.to(device), # quantized model needs to run on cuda
-        tokenizer,
-        model.config.block_size,
-        prepare_inputs_for_model,
-    ).run_eval(
-        ["wikitext"],
-        None,
-    )
+    model.to(device),  # quantized model needs to run on cuda
+    tokenizer,
+    model.config.block_size,
+    prepare_inputs_for_model,
+).run_eval(
+    ["wikitext"],
+    None,
+)
 
 # wikitext: {'word_perplexity,none': 12.523175352665858, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6042723245990418, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.681919059499152, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
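For reference, a minimal sketch of running the relocated test through pytest's Python API. Assumptions not stated in this commit: pytest is installed, the snippet runs from the torchao repository root, and the -v flag is an illustrative choice.

    # Invoke the moved test at its new path under test/quantization/.
    # pytest.main returns an exit code; 0 means all collected tests passed.
    import pytest

    raise SystemExit(pytest.main(["test/quantization/test_gptq_mt.py", "-v"]))

Equivalently, running `pytest test/quantization/test_gptq_mt.py` from the shell picks the file up at its new location.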
