
Add KLDiv loss (#878)
Add support for KL divergence loss.

Closes #870
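
For reference, the quantity being implemented is the Kullback-Leibler divergence between the target distribution P (labels) and the predicted distribution Q (prediction, supplied in log space):

$D_{\mathrm{KL}}(P \,\|\, Q) = \sum_i p_i \left( \log p_i - \log q_i \right)$

which matches the pointwise term labels * (log(labels) - prediction) computed in the diff below.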
pglusacTT authored Dec 19, 2024
1 parent 12ce5de commit 9314040
Showing 2 changed files with 54 additions and 0 deletions.
23 changes: 23 additions & 0 deletions forge/forge/op/loss.py
@@ -143,3 +143,26 @@ def forward(self, prediction, labels):
        loss = ReduceSum("r_sum", loss, -1)
        loss = reduce_loss(self.reduction, loss)
        return loss


class KLDivLoss(ForgeModule):
    """
    KLDivLoss

    Computes labels * (log(labels) - prediction) elementwise, then reduces the
    result with ReduceAvg (default, reduction="mean") or ReduceSum (reduction="sum").
    Note: this loss expects the prediction to contain log probabilities.
    """

    def __init__(self, name: str, reduction: str = "mean"):
        super().__init__(name)
        self.reduction = reduction
        self.is_loss = True

    @validate_shapes(min_dim=1, max_dim=2)
    def forward(self, prediction, labels):
        log_labels = Log("log", labels)
        diff = Subtract("sub", log_labels, prediction)
        product = Multiply("mul", labels, diff)
        loss = reduce_loss(self.reduction, product)
        return loss
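
As an aside (not part of the diff): a minimal plain-PyTorch sketch of the same computation, handy for sanity-checking the log-probability convention the docstring calls out.

import torch
import torch.nn.functional as F

# Pointwise KL term from the module above: labels * (log(labels) - prediction),
# where `prediction` already holds log probabilities.
prediction = F.log_softmax(torch.randn(4, 8), dim=-1)  # log probabilities
labels = F.softmax(torch.randn(4, 8), dim=-1)          # probabilities

pointwise = labels * (torch.log(labels) - prediction)

# With reduction="mean", torch.nn.KLDivLoss averages over all elements,
# so the manual mean should agree with the built-in loss.
reference = torch.nn.KLDivLoss(reduction="mean")(prediction, labels)
assert torch.allclose(pointwise.mean(), reference)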
31 changes: 31 additions & 0 deletions forge/test/mlir/test_loss.py
@@ -65,6 +65,37 @@ def test_cross_entropy_loss(prediction_shape):
    assert torch.allclose(torch_loss_out, forge_loss_out[0], rtol=11e-3)


@pytest.mark.parametrize(
    "prediction_shape",
    [
        (33,),
        (128,),
        (2, 2),
        (3, 5),
        (32, 32),
        (33, 127),
        (128, 20),
        (128, 128),
    ],
)
@pytest.mark.parametrize("reduction", ["mean", "sum"])
def test_kl_div_loss(prediction_shape, reduction):
    forge_loss = forge.op.loss.KLDivLoss("kl_div_loss", reduction=reduction)
    torch_loss = torch.nn.KLDivLoss(reduction=reduction)

    prediction = nn.functional.log_softmax(torch.randn(prediction_shape, requires_grad=True), dim=-1)
    prediction_forge = forge.tensor.Tensor.create_from_torch(prediction)
    target = torch.randn(prediction_shape)
    # softmax the target so it forms a valid probability distribution
    target = nn.functional.softmax(target, dim=-1)
    target_forge = forge.tensor.Tensor.create_from_torch(target)

    forge_loss = forge.compile(forge_loss, sample_inputs=[prediction_forge, target_forge])
    forge_loss_out = forge_loss(prediction, target)[0]
    torch_loss_out = torch_loss(prediction, target)
    assert torch.allclose(torch_loss_out, forge_loss_out, rtol=5e-2)
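
A caveat worth noting as an aside (based on PyTorch's documented semantics, not this diff): reduction="mean" averages over every element, whereas the per-sample textbook KL divergence corresponds to reduction="batchmean". A quick sketch of the relationship:

import torch
import torch.nn.functional as F

pred = F.log_softmax(torch.randn(8, 16), dim=-1)
target = F.softmax(torch.randn(8, 16), dim=-1)

mean_loss = torch.nn.KLDivLoss(reduction="mean")(pred, target)            # total / (8 * 16)
batchmean_loss = torch.nn.KLDivLoss(reduction="batchmean")(pred, target)  # total / 8
assert torch.allclose(mean_loss * 16, batchmean_loss)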


@pytest.mark.parametrize(
    "prediction_shape",
    [
