Skip to content

Commit

Permalink
Add BCE loss (#920)
Browse files Browse the repository at this point in the history
Add support for BCE loss and BCE with logits loss

Closes #872, closes #873
  • Loading branch information
pglusacTT authored Dec 20, 2024
1 parent bcc0548 commit 0da336f
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 1 deletion.
70 changes: 69 additions & 1 deletion forge/forge/op/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from forge.op.tm import Broadcast, Unsqueeze
from ..module import ForgeModule
from .constant import Constant
from .eltwise_unary import Log, Abs
from .eltwise_unary import Log, Abs, Sigmoid
from .eltwise_binary import Add, GreaterEqual, Less, Subtract, Multiply
from .nn import Softmax
from .reduce import ReduceSum, ReduceAvg
Expand Down Expand Up @@ -223,3 +223,71 @@ def forward(self, prediction, labels):
# combine masks to get the final loss
loss = Add("loss", loss_lt, loss_ge)
return reduce_loss(self.reduction, loss)


def align_shape(target, reference, name):
    """Reshape *target* so it matches *reference* dim-for-dim.

    First inserts leading singleton dimensions until the ranks agree, then
    broadcasts every dimension whose size differs from the reference. All
    generated op names are prefixed with *name* to keep them unique.
    """
    dims_added = 0
    while target.ndim() < reference.ndim():
        target = Unsqueeze(f"unsqueeze_{name}_{dims_added}", target, dims_added)
        dims_added += 1
    for dim in range(target.ndim()):
        if target.shape[dim] == reference.shape[dim]:
            continue
        target = Broadcast(f"broadcast_{name}_{dim}", target, dim, reference.shape[dim])
    return target


class BCELoss(ForgeModule):
    """
    Binary Cross-Entropy loss.

    loss = reduce(-(y * log(p) + (1 - y) * log(1 - p)))

    where p = prediction and y = labels. Predictions are assumed to
    already lie in (0, 1); feed raw logits to BCEWithLogitsLoss instead.
    """

    def __init__(self, name: str, reduction: str = "mean"):
        super().__init__(name)
        # Reduction applied to the per-element loss ("mean" or "sum").
        self.reduction = reduction
        self.is_loss = True

    @validate_shapes(min_dim=1, max_dim=2)
    def forward(self, prediction, labels):
        # Positive-class term: y * log(p)
        positive_term = Multiply("mul_lab_pred", labels, Log("log", prediction))

        # Constant 1 expanded to the labels' shape for elementwise use.
        one = align_shape(Constant("one", constant=1.0), labels, "one")

        # Negative-class term: (1 - y) * log(1 - p)
        negative_term = Multiply(
            "second_term",
            Subtract("one_minus_labels", one, labels),
            Log("log_one_minus_prediction", Subtract("one_minus_prediction", one, prediction)),
        )

        # Negate the combined terms (multiply by -1 broadcast to shape),
        # then apply the configured reduction.
        summed = Add("sum_terms", positive_term, negative_term)
        neg_one = align_shape(Constant("neg_one", constant=-1.0), summed, "neg_one")
        return reduce_loss(self.reduction, Multiply("negative_sum_terms", summed, neg_one))


class BCEWithLogitsLoss(ForgeModule):
    """
    Binary Cross-Entropy loss on raw logits.

    loss = BCELoss(Sigmoid(prediction), labels)

    NOTE(review): implemented as sigmoid followed by plain BCE rather than
    the numerically stabler fused log-sum-exp form — may lose precision for
    extreme logits; confirm this is acceptable for the target hardware.
    """

    def __init__(self, name: str, reduction: str = "mean"):
        super().__init__(name)
        # Reduction applied to the per-element loss ("mean" or "sum").
        self.reduction = reduction
        self.is_loss = True
        # Delegate the actual loss computation to a nested BCELoss.
        self.bce_loss = BCELoss("bce_loss", reduction=self.reduction)

    @validate_shapes(min_dim=1, max_dim=2)
    def forward(self, prediction, labels):
        # Squash logits into (0, 1), then reuse the plain BCE computation.
        return self.bce_loss(Sigmoid("sigmoid", prediction), labels)
62 changes: 62 additions & 0 deletions forge/test/mlir/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,65 @@ def test_huber_loss(prediction_shape, reduction):
torch_loss_out = torch_loss(prediction, target)

assert torch.allclose(torch_loss_out, forge_loss_out[0], rtol=5e-2)


@pytest.mark.parametrize(
    "prediction_shape",
    [
        (33,),
        (128,),
        (2, 2),
        (3, 5),
        (32, 32),
        (33, 127),
        (128, 20),
        (128, 128),
    ],
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
def test_bce_loss(prediction_shape, reduction):
    """Compare forge BCELoss against torch.nn.BCELoss across shapes and reductions."""
    forge_loss = forge.op.loss.BCELoss("bce_loss", reduction=reduction)
    torch_loss = torch.nn.BCELoss(reduction=reduction)

    # BCELoss expects probabilities, so squash random logits into (0, 1).
    # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
    prediction = torch.sigmoid(torch.randn(prediction_shape, requires_grad=True))
    target = torch.rand(prediction_shape)

    prediction_forge = forge.tensor.Tensor.create_from_torch(prediction)
    target_forge = forge.tensor.Tensor.create_from_torch(target)

    forge_loss = forge.compile(forge_loss, sample_inputs=[prediction_forge, target_forge])
    forge_loss_out = forge_loss(prediction, target)
    torch_loss_out = torch_loss(prediction, target)

    # Loose tolerances: device arithmetic differs slightly from the torch reference.
    assert torch.allclose(torch_loss_out, forge_loss_out[0], rtol=5e-2, atol=5e-3)


@pytest.mark.parametrize(
    "prediction_shape",
    [
        (33,),
        (128,),
        (2, 2),
        (3, 5),
        (32, 32),
        (33, 127),
        (128, 20),
        (128, 128),
    ],
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
def test_bce_with_logits_loss(prediction_shape, reduction):
    """Compare forge BCEWithLogitsLoss against the torch reference implementation."""
    loss_module = forge.op.loss.BCEWithLogitsLoss("bce_with_logits_loss", reduction=reduction)
    reference = torch.nn.BCEWithLogitsLoss(reduction=reduction)

    # Raw logits go straight in: the module applies the sigmoid internally.
    logits = torch.randn(prediction_shape, requires_grad=True)
    labels = torch.rand(prediction_shape)

    sample_inputs = [
        forge.tensor.Tensor.create_from_torch(logits),
        forge.tensor.Tensor.create_from_torch(labels),
    ]
    compiled = forge.compile(loss_module, sample_inputs=sample_inputs)
    forge_out = compiled(logits, labels)
    torch_out = reference(logits, labels)

    # Loose tolerances absorb device-vs-torch numeric differences.
    assert torch.allclose(torch_out, forge_out[0], rtol=5e-2, atol=5e-3)

0 comments on commit 0da336f

Please sign in to comment.