Skip to content

Commit

Permalink
mixtral uni test to float16 instead of fp8 (worker 8.9 or higher for …
Browse files Browse the repository at this point in the history
…fp8)
  • Loading branch information
dpower4 committed Dec 20, 2024
1 parent f36f4f7 commit 83c9981
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_mixtral_8x7b():
"""Test pytorch-mixtral-8x7b benchmark for fp8 train and inference."""
"""Test pytorch-mixtral-8x7b benchmark for float16 train and inference."""
context = BenchmarkRegistry.create_benchmark_context(
'mixtral-8x7b',
platform=Platform.CUDA,
parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision fp8_e4m3 \
parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision float16 \
--hidden_size 1024 --max_position_embeddings 2048 --intermediate_size 3584 \
--model_action train inference',
framework=Framework.PYTORCH
Expand Down Expand Up @@ -59,7 +59,9 @@ def test_pytorch_mixtral_8x7b():
assert (benchmark.run_count == 1)
assert (benchmark.return_code == ReturnCode.SUCCESS)

for metric in ['fp8_e4m3_inference_step_time', 'fp8_e4m3_inference_throughput']:
for metric in [
'fp16_train_step_time', 'fp16_train_throughput', 'fp16_inference_step_time', 'fp16_inference_throughput'
]:
assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
assert (len(benchmark.result[metric]) == benchmark.run_count)

0 comments on commit 83c9981

Please sign in to comment.