Commit 216213e

style

SunMarc committed Sep 27, 2023
1 parent 0c53c2f

Showing 1 changed file with 16 additions and 7 deletions.

tests/benchmark/benchmark_gptq.py (23 changes: 16 additions & 7 deletions)
@@ -5,6 +5,7 @@
 
 import numpy as np
 import torch
+from auto_gptq.utils import Perplexity
 from memory_tracker import MemoryTracker
 from tqdm import tqdm
 from transformers import (
@@ -14,11 +15,11 @@
     AutoTokenizer,
     BitsAndBytesConfig,
     GenerationConfig,
-    GPTQConfig
+    GPTQConfig,
 )
 
 from optimum.exporters import TasksManager
-from auto_gptq.utils import Perplexity
 
 
 def get_parser():
     parser = argparse.ArgumentParser()
@@ -102,7 +103,7 @@ def get_parser():
         default=None,
         help="Revision of the model to benchmark",
     )
 
     return parser
 
 
@@ -279,7 +280,7 @@ def benchmark_memory(
     device = torch.device("cuda:0")
     memory_tracker = MemoryTracker()
 
-    tokenizer = AutoTokenizer.from_pretrained(args.model,revision=args.revision, use_fast=False)
+    tokenizer = AutoTokenizer.from_pretrained(args.model, revision=args.revision, use_fast=False)
 
     if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
@@ -303,8 +304,16 @@ def benchmark_memory(
 
     load_start = time.time_ns()
     if args.gptq:
-        quantization_config = GPTQConfig(bits=4, disable_exllama=args.disable_exllama, disable_exllamav2=args.disable_exllamav2)
-        model = autoclass.from_pretrained(args.model,revision=args.revision, quantization_config=quantization_config, torch_dtype=torch.float16, device_map="auto")
+        quantization_config = GPTQConfig(
+            bits=4, disable_exllama=args.disable_exllama, disable_exllamav2=args.disable_exllamav2
+        )
+        model = autoclass.from_pretrained(
+            args.model,
+            revision=args.revision,
+            quantization_config=quantization_config,
+            torch_dtype=torch.float16,
+            device_map="auto",
+        )
     elif args.bitsandbytes:
         quantization_config = BitsAndBytesConfig(
             load_in_4bit=True, bnb_4bit_quant_type="fp4", bnb_4bit_compute_dtype=torch.float16
@@ -338,7 +347,7 @@ def benchmark_memory(
         kernel = "autogptq-cuda"
     else:
         kernel = "autogptq-cuda-old"
 
     load_time = (load_end - load_start) * 1e-9
     print(f"Model load time: {load_time:.1f} s")
