diff --git a/tests/models/deci/test_modeling_deci.py b/tests/models/deci/test_modeling_deci.py
index 3ba2b4e4f16ba4..19ecc0b6991cef 100644
--- a/tests/models/deci/test_modeling_deci.py
+++ b/tests/models/deci/test_modeling_deci.py
@@ -457,25 +457,6 @@ def test_flash_attn_2_generate_use_cache(self):
     def test_flash_attn_2_inference_padding_right(self):
         self.skipTest("Deci flash attention does not support right padding")
 
-    # Ignore copy
-    def test_load_balancing_loss(self):
-        r"""
-        Let's make sure we can actually compute the loss and do a backward on it.
-        """
-
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.output_router_logits = True
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        model = DeciForCausalLM(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask)
-        self.assertEqual(result.router_logits[0].shape, (91, config.num_experts_per_tok))
-        torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(1, dtype=torch.float32))
-
-
 @require_torch
 class DeciIntegrationTest(unittest.TestCase):
     @slow