From f36f4f7c4abaada34c7939e471afd8e7b0b0bc28 Mon Sep 17 00:00:00 2001 From: dilip patlolla Date: Thu, 19 Dec 2024 17:43:57 -0800 Subject: [PATCH] reduce mixtral dims to reduce vram req --- .../model_benchmarks/test_pytorch_mixtral.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py index ce0a1e3b8..db3bbbae9 100644 --- a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py +++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py @@ -15,14 +15,14 @@ @decorator.cuda_test @decorator.pytorch_test -@decorator.python_eol_test def test_pytorch_mixtral_8x7b(): - """Test pytorch-mixtral-8x7b benchmark for fp8 inference.""" + """Test pytorch-mixtral-8x7b benchmark for fp8 train and inference.""" context = BenchmarkRegistry.create_benchmark_context( 'mixtral-8x7b', platform=Platform.CUDA, parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision fp8_e4m3 \ - --model_action inference', + --hidden_size 1024 --max_position_embeddings 2048 --intermediate_size 3584 \ + --model_action train inference', framework=Framework.PYTORCH ) @@ -36,13 +36,13 @@ def test_pytorch_mixtral_8x7b(): assert (benchmark.name == 'pytorch-mixtral-8x7b') assert (benchmark.type == BenchmarkType.MODEL) - # Check predefined parameters of mixtral2 7b model. - assert (benchmark._args.hidden_size == 4096) + # Check predefined parameters of mixtral-8x7b model. + assert (benchmark._args.hidden_size == 1024) assert (benchmark._args.num_hidden_layers == 32) assert (benchmark._args.num_attention_heads == 32) assert (benchmark._args.num_key_value_heads == 8) - assert (benchmark._args.intermediate_size == 14336) - assert (benchmark._args.max_position_embeddings == 32768) + assert (benchmark._args.intermediate_size == 3584) + assert (benchmark._args.max_position_embeddings == 2048) assert (benchmark._args.router_aux_loss_coef == 0.02) # Check parameters specified in BenchmarkContext.