Skip to content

Commit

Permalink
Extract loss logic to hf_compute_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
joaocmd committed Apr 16, 2024
1 parent 48f7a0d commit 200dbd0
Showing 1 changed file with 33 additions and 27 deletions.
60 changes: 33 additions & 27 deletions src/transformers/models/swiftformer/modeling_tf_swiftformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,38 @@ def __init__(self, config: SwiftFormerConfig, **kwargs) -> None:
else keras.layers.Identity(name="dist_head")
)

def hf_compute_loss(self, labels, logits):
    """Compute the classification/regression loss for `labels` vs `logits`.

    Mirrors the PyTorch `SwiftFormerForImageClassification` loss logic:
    the `problem_type` is inferred once from `num_labels` and the label
    dtype, then cached on `self.config`.

    Args:
        labels: Target tensor. Integer class ids for
            single-label classification, floats for regression,
            multi-hot floats for multi-label classification.
        logits: Unnormalized model outputs of shape `(batch, num_labels)`.

    Returns:
        The per-example loss tensor, or `None` if `problem_type` is
        unrecognized.
    """
    # Infer the problem type lazily on first call, like the PyTorch model.
    if self.config.problem_type is None:
        if self.num_labels == 1:
            self.config.problem_type = "regression"
        elif self.num_labels > 1 and (labels.dtype == tf.int64 or labels.dtype == tf.int32):
            self.config.problem_type = "single_label_classification"
        else:
            self.config.problem_type = "multi_label_classification"

    if self.config.problem_type == "regression":
        loss_fct = keras.losses.MSE
        if self.num_labels == 1:
            # tf.Tensor has no `.squeeze()` method (that's NumPy/PyTorch API);
            # use tf.squeeze to drop the singleton label dimension.
            loss = loss_fct(tf.squeeze(labels), tf.squeeze(logits))
        else:
            loss = loss_fct(labels, logits)
    elif self.config.problem_type == "single_label_classification":
        loss_fct = keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=keras.losses.Reduction.NONE
        )
        loss = loss_fct(labels, logits)
    elif self.config.problem_type == "multi_label_classification":
        # Multi-label targets are multi-hot vectors, not sparse class indices,
        # so SparseCategoricalCrossentropy is wrong here. Use the TF analogue
        # of PyTorch's BCEWithLogitsLoss (see the PT SwiftFormer head).
        loss_fct = keras.losses.BinaryCrossentropy(
            from_logits=True,
            reduction=keras.losses.Reduction.NONE,
        )
        loss = loss_fct(labels, logits)
    else:
        loss = None

    return loss


@unpack_inputs
@add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING)
def call(
Expand Down Expand Up @@ -809,33 +841,7 @@ def call(
logits = (cls_out + distillation_out) / 2

# calculate loss
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == tf.int64 or labels.dtype == tf.int32):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = keras.losses.MSE
if self.num_labels == 1:
loss = loss_fct(labels.squeeze(), logits.squeeze())
else:
loss = loss_fct(labels, logits)
elif self.config.problem_type == "single_label_classification":
loss_fct = keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=keras.losses.Reduction.NONE
)
loss = loss_fct(labels, logits)
elif self.config.problem_type == "multi_label_classification":
loss_fct = keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=keras.losses.Reduction.NONE,
)
loss = loss_fct(labels, logits)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down

0 comments on commit 200dbd0

Please sign in to comment.