diff --git a/src/transformers/models/swiftformer/modeling_tf_swiftformer.py b/src/transformers/models/swiftformer/modeling_tf_swiftformer.py
index d22f229c04471d..4381fab936ad24 100644
--- a/src/transformers/models/swiftformer/modeling_tf_swiftformer.py
+++ b/src/transformers/models/swiftformer/modeling_tf_swiftformer.py
@@ -769,6 +769,38 @@ def __init__(self, config: SwiftFormerConfig, **kwargs) -> None:
             else keras.layers.Identity(name="dist_head")
         )
 
+    def hf_compute_loss(self, labels, logits):
+        if self.config.problem_type is None:
+            if self.num_labels == 1:
+                self.config.problem_type = "regression"
+            elif self.num_labels > 1 and (labels.dtype == tf.int64 or labels.dtype == tf.int32):
+                self.config.problem_type = "single_label_classification"
+            else:
+                self.config.problem_type = "multi_label_classification"
+
+        if self.config.problem_type == "regression":
+            loss_fct = keras.losses.MSE
+            if self.num_labels == 1:
+                loss = loss_fct(labels.squeeze(), logits.squeeze())
+            else:
+                loss = loss_fct(labels, logits)
+        elif self.config.problem_type == "single_label_classification":
+            loss_fct = keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True, reduction=keras.losses.Reduction.NONE
+            )
+            loss = loss_fct(labels, logits)
+        elif self.config.problem_type == "multi_label_classification":
+            loss_fct = keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True,
+                reduction=keras.losses.Reduction.NONE,
+            )
+            loss = loss_fct(labels, logits)
+        else:
+            loss = None
+
+        return loss
+
+
     @unpack_inputs
     @add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING)
     def call(
@@ -809,33 +841,7 @@ def call(
         logits = (cls_out + distillation_out) / 2
 
         # calculate loss
-        loss = None
-        if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == tf.int64 or labels.dtype == tf.int32):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = keras.losses.MSE
-                if self.num_labels == 1:
-                    loss = loss_fct(labels.squeeze(), logits.squeeze())
-                else:
-                    loss = loss_fct(labels, logits)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = keras.losses.SparseCategoricalCrossentropy(
-                    from_logits=True, reduction=keras.losses.Reduction.NONE
-                )
-                loss = loss_fct(labels, logits)
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = keras.losses.SparseCategoricalCrossentropy(
-                    from_logits=True,
-                    reduction=keras.losses.Reduction.NONE,
-                )
-                loss = loss_fct(labels, logits)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
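
For anyone who wants to exercise the extracted loss logic in isolation, here is a minimal sketch (not part of this diff): the helper name `compute_loss_sketch` and the dummy `labels`/`logits` tensors are invented for illustration, and it mirrors only the regression and single-label branches of the new `hf_compute_loss`, using `tf.squeeze` so the snippet runs on plain `tf.Tensor` inputs.

```python
import tensorflow as tf


def compute_loss_sketch(labels, logits, problem_type, num_labels):
    """Standalone rendition of the problem-type dispatch shown in the diff above."""
    if problem_type == "regression":
        # Mean-squared error; collapse the trailing dimension when there is one label.
        if num_labels == 1:
            return tf.keras.losses.MSE(tf.squeeze(labels), tf.squeeze(logits))
        return tf.keras.losses.MSE(labels, logits)
    if problem_type == "single_label_classification":
        # Per-example (unreduced) cross-entropy computed directly from logits.
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        return loss_fct(labels, logits)
    return None


# Two images, three classes, integer class ids -> one loss value per image.
logits = tf.constant([[2.0, 0.5, -1.0], [0.1, 1.2, 0.3]])
labels = tf.constant([0, 2])
print(compute_loss_sketch(labels, logits, "single_label_classification", num_labels=3))
```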