diff --git a/losses.py b/losses.py index 17117d42..911fd51b 100644 --- a/losses.py +++ b/losses.py @@ -89,7 +89,13 @@ def forward(self, features, labels=None, mask=None): log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) # compute mean of log-likelihood over positive - mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) + # avoid nan loss when there's one sample for a certain class, e.g., 0,1,...1 for bin-cls , this produce nan for 1st in Batch + # which also results in batch total loss as nan. such row should be dropped + pos_per_sample=mask.sum(1) #B + pos_per_sample[pos_per_sample<1e-6]=1.0 + mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_sample #mask.sum(1) + + #mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) # loss loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos