diff --git a/losses.py b/losses.py index 17117d42..21b4dde9 100644 --- a/losses.py +++ b/losses.py @@ -86,7 +86,7 @@ def forward(self, features, labels=None, mask=None): # compute log_prob exp_logits = torch.exp(logits) * logits_mask - log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6) # prevent computing log(0), which will produce Nan in the loss # compute mean of log-likelihood over positive mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)