Skip to content

Commit

Permalink
remove unwanted test
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Dec 16, 2023
1 parent afc56bb commit 400d129
Showing 1 changed file with 0 additions and 19 deletions.
19 changes: 0 additions & 19 deletions tests/models/deci/test_modeling_deci.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,25 +457,6 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_padding_right(self):
self.skipTest("Deci flash attention does not support right padding")

# Ignore copy
def test_load_balancing_loss(self):
r"""
Let's make sure we can actually compute the loss and do a backward on it.
"""

config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.output_router_logits = True
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
model = DeciForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask)
self.assertEqual(result.router_logits[0].shape, (91, config.num_experts_per_tok))
torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(1, dtype=torch.float32))


@require_torch
class DeciIntegrationTest(unittest.TestCase):
@slow
Expand Down

0 comments on commit 400d129

Please sign in to comment.