Skip tests on fbcode

Differential Revision: D67982501

Pull Request resolved: #1532
jainapurva authored Jan 9, 2025
1 parent 08cd260 commit b5b739b
Showing 1 changed file with 90 additions and 81 deletions.
171 changes: 90 additions & 81 deletions test/quantization/test_gptq_mt.py
@@ -3,11 +3,16 @@
 import pytest
 import torch
 import torch.nn.functional as F
+from torch.testing._internal.common_utils import run_tests
 
 from torchao._models.llama.model import Transformer, prepare_inputs_for_model
 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
 from torchao.quantization.utils import _lm_eval_available
+from torchao.utils import is_fbcode
+
+if is_fbcode():
+    pytest.skip("Skipping the test in fbcode due to missing model and tokenizer files")
 
 if _lm_eval_available:
     hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")
@@ -247,88 +252,92 @@ def run_eval(self, tasks, limit):
         return result
 
 
-precision = torch.bfloat16
-device = "cuda"
-print("Loading model")
-checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
-model = Transformer.from_name(checkpoint_path.parent.name)
-checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-model.load_state_dict(checkpoint, assign=True)
-model = model.to(dtype=precision, device="cpu")
-model.eval()
-print("Model loaded")
-tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-assert tokenizer_path.is_file(), tokenizer_path
-tokenizer = get_tokenizer(  # pyre-ignore[28]
-    tokenizer_path,
-    "Llama-2-7b-chat-hf",
-)
-print("Tokenizer loaded")
-
-
-blocksize = 128
-percdamp = 0.01
-groupsize = 64
-calibration_tasks = ["wikitext"]
-calibration_limit = None
-calibration_seq_length = 100
-input_prep_func = prepare_inputs_for_model
-pad_calibration_inputs = False
-print("Recording inputs")
-inputs = (
-    InputRecorder(
-        tokenizer,
-        calibration_seq_length,
-        input_prep_func,
-        pad_calibration_inputs,
-        model.config.vocab_size,
-        device="cpu",
-    )
-    .record_inputs(
-        calibration_tasks,
-        calibration_limit,
-    )
-    .get_inputs()
-)
-print("Inputs recorded")
-quantizer = Int4WeightOnlyGPTQQuantizer(
-    blocksize,
-    percdamp,
-    groupsize,
-)
-
-model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
-multi = [
-    MultiTensor([inp for inp, _ in inputs]),
-    MultiTensor([inds for _, inds in inputs]),
-]
-print("Quantizing model")
-model = quantizer.quantize(model, multi).cuda()
-print("Model quantized")
-print("Saving model and fixing state dict")
-regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
-for key, value in model.state_dict().items():
-    if isinstance(value, MultiTensor):
-        regular_state_dict[key] = value.values[0]
-    else:
-        regular_state_dict[key] = value
-
-model = Transformer.from_name(checkpoint_path.parent.name)
-remove = [k for k in regular_state_dict if "kv_cache" in k]
-for k in remove:
-    del regular_state_dict[k]
-
-model.load_state_dict(regular_state_dict, assign=True)
-torch.save(model.state_dict(), "model.pth")
-print("Running evaluation")
-result = TransformerEvalWrapper(
-    model.to(device),  # quantized model needs to run on cuda
-    tokenizer,
-    model.config.block_size,
-    prepare_inputs_for_model,
-).run_eval(
-    ["wikitext"],
-    None,
-)
+def test_gptq_mt():
+    precision = torch.bfloat16
+    device = "cuda"
+    print("Loading model")
+    checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+    model = Transformer.from_name(checkpoint_path.parent.name)
+    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    model.load_state_dict(checkpoint, assign=True)
+    model = model.to(dtype=precision, device="cpu")
+    model.eval()
+    print("Model loaded")
+    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+    assert tokenizer_path.is_file(), tokenizer_path
+    tokenizer = get_tokenizer(  # pyre-ignore[28]
+        tokenizer_path,
+        "Llama-2-7b-chat-hf",
+    )
+    print("Tokenizer loaded")
+
+    blocksize = 128
+    percdamp = 0.01
+    groupsize = 64
+    calibration_tasks = ["wikitext"]
+    calibration_limit = None
+    calibration_seq_length = 100
+    input_prep_func = prepare_inputs_for_model
+    pad_calibration_inputs = False
+    print("Recording inputs")
+    inputs = (
+        InputRecorder(
+            tokenizer,
+            calibration_seq_length,
+            input_prep_func,
+            pad_calibration_inputs,
+            model.config.vocab_size,
+            device="cpu",
+        )
+        .record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        )
+        .get_inputs()
+    )
+    print("Inputs recorded")
+    quantizer = Int4WeightOnlyGPTQQuantizer(
+        blocksize,
+        percdamp,
+        groupsize,
+    )
+
+    model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+    multi = [
+        MultiTensor([inp for inp, _ in inputs]),
+        MultiTensor([inds for _, inds in inputs]),
+    ]
+    print("Quantizing model")
+    model = quantizer.quantize(model, multi).cuda()
+    print("Model quantized")
+    print("Saving model and fixing state dict")
+    regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
+    for key, value in model.state_dict().items():
+        if isinstance(value, MultiTensor):
+            regular_state_dict[key] = value.values[0]
+        else:
+            regular_state_dict[key] = value
+
+    model = Transformer.from_name(checkpoint_path.parent.name)
+    remove = [k for k in regular_state_dict if "kv_cache" in k]
+    for k in remove:
+        del regular_state_dict[k]
+
+    model.load_state_dict(regular_state_dict, assign=True)
+    torch.save(model.state_dict(), "model.pth")
+    print("Running evaluation")
+    TransformerEvalWrapper(
+        model.to(device),  # quantized model needs to run on cuda
+        tokenizer,
+        model.config.block_size,
+        prepare_inputs_for_model,
+    ).run_eval(
+        ["wikitext"],
+        None,
+    )
+
+
+if __name__ == "__main__":
+    run_tests()
 
 # wikitext: {'word_perplexity,none': 12.523175352665858, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6042723245990418, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.681919059499152, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
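Note on the skip guard added in this commit: the sketch below is not part of the diff; it only illustrates the module-level skip pattern in isolation, with a hypothetical running_in_fbcode() stand-in for torchao.utils.is_fbcode. In pytest, a bare pytest.skip() call made at import time has to pass allow_module_level=True, and a pytestmark with skipif is an equivalent, non-raising way to skip every test in the module.

# Illustrative sketch only (not from the torchao repo).
import pytest


def running_in_fbcode() -> bool:
    # Placeholder for the real environment check provided by torchao.utils.is_fbcode.
    return False


if running_in_fbcode():
    # pytest.skip() outside of a test must opt in to skipping the whole module.
    pytest.skip(
        "model and tokenizer files are unavailable in fbcode",
        allow_module_level=True,
    )

# Equivalent declarative form: mark every test in this module as skipped
# instead of raising at collection time.
pytestmark = pytest.mark.skipif(
    running_in_fbcode(),
    reason="model and tokenizer files are unavailable in fbcode",
)

Either form keeps this GPU-heavy GPTQ test from running in environments where the Llama-2 checkpoint under checkpoints/meta-llama/ and its tokenizer are not available.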
