From 3a1508216250d2a32f910ee1e32e0e39845689e0 Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Thu, 16 May 2024 12:21:35 -0700 Subject: [PATCH] torch dmoe tests gpu oom (#1216) --- tests/models/layers/test_dmoe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py index 8668aa2ec9..a7393674dc 100644 --- a/tests/models/layers/test_dmoe.py +++ b/tests/models/layers/test_dmoe.py @@ -63,10 +63,10 @@ def _get_torch_dtype(fp16: bool, bf16: bool) -> Optional[torch.dtype]: ) @pytest.mark.gpu @pytest.mark.world_size(2) -@pytest.mark.parametrize('moe_num_experts', [1, 2, 8]) +@pytest.mark.parametrize('moe_num_experts', [1, 8]) @pytest.mark.parametrize('mlp_type', ['glu', 'mlp']) @pytest.mark.parametrize('moe_world_size', [1, 2]) -@pytest.mark.parametrize('moe_normalize_expert_weights', [1, 2.0]) +@pytest.mark.parametrize('moe_normalize_expert_weights', [1.0]) @pytest.mark.parametrize('two_d_input', [True, False]) def test_dmoe( moe_num_experts: int,