Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgshaw2-neuralmagic committed Jan 7, 2025
1 parent 0d4c3fd commit 57340d2
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tests/tpu/test_quantization_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@ class GSM8KAccuracyTestConfig:

def get_model_args(self) -> str:
return (f"pretrained={self.model_name},"
"max_model_len=4096,max_num_seqs=128,enforce_eager=True")
"max_model_len=4096,max_num_seqs=128")


# NOTE(rob): Accuracy scores measured on GPUs.
# NOTE: Accuracy scores measured on GPUs.
ACCURACY_CONFIGS = [
GSM8KAccuracyTestConfig(
model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
excepted_value=0.76), # no bias
GSM8KAccuracyTestConfig(
model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
excepted_value=0.66), # bias
# NOTE(rob): We cannot re-initialize VLLM in the same process for TPU,
# so only one of these tests can run in a single call to pytest. As
# a follow up, move this into the LM-EVAL section of the CI.
# GSM8KAccuracyTestConfig(
# model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
# excepted_value=0.66), # bias in QKV layers
]


Expand All @@ -37,7 +40,6 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
model_args=config.get_model_args(),
tasks="gsm8k",
batch_size="auto",
limit=1,
)

# EXPECTED_VALUE = config.excepted_value
Expand Down

0 comments on commit 57340d2

Please sign in to comment.