diff --git a/tests/benchmark/benchmark_gptq.py b/tests/benchmark/benchmark_gptq.py
index 369e6922e8e..45fdc262cee 100644
--- a/tests/benchmark/benchmark_gptq.py
+++ b/tests/benchmark/benchmark_gptq.py
@@ -5,6 +5,7 @@
 
 import numpy as np
 import torch
+from auto_gptq.utils import Perplexity
 from memory_tracker import MemoryTracker
 from tqdm import tqdm
 from transformers import (
@@ -14,11 +15,11 @@
     AutoTokenizer,
     BitsAndBytesConfig,
     GenerationConfig,
-    GPTQConfig
+    GPTQConfig,
 )
 
 from optimum.exporters import TasksManager
-from auto_gptq.utils import Perplexity
+
 
 def get_parser():
     parser = argparse.ArgumentParser()
@@ -102,7 +103,7 @@ def get_parser():
         default=None,
         help="Revision of the model to benchmark",
     )
-    
+
     return parser
 
 
@@ -279,7 +280,7 @@ def benchmark_memory(
 device = torch.device("cuda:0")
 memory_tracker = MemoryTracker()
 
-tokenizer = AutoTokenizer.from_pretrained(args.model,revision=args.revision, use_fast=False)
+tokenizer = AutoTokenizer.from_pretrained(args.model, revision=args.revision, use_fast=False)
 
 if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
@@ -303,8 +304,16 @@ def benchmark_memory(
 load_start = time.time_ns()
 if args.gptq:
-    quantization_config = GPTQConfig(bits=4, disable_exllama=args.disable_exllama, disable_exllamav2=args.disable_exllamav2)
-    model = autoclass.from_pretrained(args.model,revision=args.revision, quantization_config=quantization_config, torch_dtype=torch.float16, device_map="auto")
+    quantization_config = GPTQConfig(
+        bits=4, disable_exllama=args.disable_exllama, disable_exllamav2=args.disable_exllamav2
+    )
+    model = autoclass.from_pretrained(
+        args.model,
+        revision=args.revision,
+        quantization_config=quantization_config,
+        torch_dtype=torch.float16,
+        device_map="auto",
+    )
 elif args.bitsandbytes:
     quantization_config = BitsAndBytesConfig(
         load_in_4bit=True, bnb_4bit_quant_type="fp4", bnb_4bit_compute_dtype=torch.float16
     )
@@ -338,7 +347,7 @@ def benchmark_memory(
         kernel = "autotogptq-cuda"
     else:
         kernel = "autogptq-cuda-old"
-    
+
 load_time = (load_end - load_start) * 1e-9
 print(f"Model load time: {load_time:.1f} s")