Fix load balancing loss func for mixtral (huggingface#28256)
* Correct the implementation of auxiliary loss of mixtral

* correct the implementation of auxiliary loss of mixtral

* Implement a simpler calculation method

---------

Co-authored-by: zhangliangxu3 <[email protected]>
2 people authored and AjayP13 committed Jan 22, 2024
1 parent c10861b commit 7bfd7f4
Showing 2 changed files with 2 additions and 6 deletions.
6 changes: 1 addition & 5 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -103,19 +103,15 @@ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor

     _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
 
-    # treat `top_k` as tokens (shape is `top_k X [batch_size X sequence_length]`)
-    selected_experts = selected_experts.reshape(-1)
-
     expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
-    expert_mask = torch.max(expert_mask, dim=-2).values
 
     # Compute the percentage of tokens routed to each experts
     tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
 
     # Compute the average probability of routing to these experts
     router_prob_per_expert = torch.mean(routing_weights, dim=0)
 
-    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(-1))
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
     return overall_loss * num_experts


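For reference, below is a minimal, self-contained sketch of the corrected computation, assuming a single `[num_tokens, num_experts]` tensor of router logits (the actual function also accepts the tuple of per-layer router logits returned by the model and concatenates them first). The shape comments show why the `reshape`/`torch.max` collapsing steps are gone and why the broadcast becomes `unsqueeze(0)`; `load_balancing_loss_sketch` is a hypothetical name, not part of the library.

```python
import torch
import torch.nn.functional as F


def load_balancing_loss_sketch(gate_logits: torch.Tensor, num_experts: int, top_k: int = 2) -> torch.Tensor:
    # gate_logits: [num_tokens, num_experts] router logits for one layer.
    routing_weights = F.softmax(gate_logits, dim=-1)                   # [num_tokens, num_experts]

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)   # [num_tokens, top_k]

    # Keep the top-1 and top-2 choices separated along dim 1 instead of
    # flattening them and collapsing with torch.max as the old code did.
    expert_mask = F.one_hot(selected_experts, num_experts)             # [num_tokens, top_k, num_experts]

    # Fraction of tokens whose k-th routing choice is each expert.
    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)         # [top_k, num_experts]

    # Average router probability assigned to each expert.
    router_prob_per_expert = torch.mean(routing_weights, dim=0)        # [num_experts]

    # Broadcast [1, num_experts] against [top_k, num_experts] and sum.
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
```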
2 changes: 1 addition & 1 deletion tests/models/mixtral/test_modeling_mixtral.py
@@ -474,7 +474,7 @@ def test_load_balancing_loss(self):
         model.eval()
         result = model(input_ids, attention_mask=attention_mask)
         self.assertEqual(result.router_logits[0].shape, (91, config.num_local_experts))
-        torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(8, dtype=torch.float32))
+        torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)


 @require_torch
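The expected value drops from 8 to 2 because, with the corrected formula, a near-uniform (untrained) router gives a loss close to `top_k`: when every expert receives roughly `1/num_experts` of the router probability, the sum of `tokens_per_expert * router_prob_per_expert` is about `top_k / num_experts`, and scaling by `num_experts` leaves `top_k = 2`; the `rtol`/`atol` absorb the remaining deviation from exact uniformity. A quick sanity check, reusing the hypothetical sketch above (the token count 91 only mirrors the shape asserted in the test; the result does not depend on it):

```python
# With perfectly uniform router logits the corrected loss is exactly top_k,
# regardless of which experts topk picks to break the ties.
num_tokens, num_experts, top_k = 91, 8, 2
uniform_logits = torch.zeros(num_tokens, num_experts)
print(load_balancing_loss_sketch(uniform_logits, num_experts, top_k))  # tensor(2.)
```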
