De-serializing a custom model in transformers v4.37.0 onwards doesn't work; weights aren't loaded from the saved checkpoint #29321

Closed
2 of 4 tasks
ashishu007 opened this issue Feb 27, 2024 · 4 comments

Comments

@ashishu007

System Info

  • transformers version: 4.37.0
  • Platform: Linux-5.10.192-183.736.amzn2.x86_64-x86_64-with-glibc2.17
  • Python version: 3.8.11
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): not installed (NA)
  • Tensorflow version (GPU?): 2.12.0 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: False
  • Using distributed or parallel set-up in script?: False

Who can help?

@ArthurZucker @Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I have a custom model that worked fine with transformers==4.33.1 but now fails with transformers==4.37.0. De-serializing the model from a checkpoint doesn't load the weights correctly; instead, the model is randomly initialized and a long warning message is printed. This appears to be a breaking change across minor versions that isn't backwards compatible.

A compact example of the custom model that raises the same warning is given below.

Q: What can I do to make my code work with recent transformers versions?

Custom model class

import transformers
import tensorflow as tf
import numpy as np
from typing import List, Union, Tuple, Optional


class CustomBertClassifier(
    transformers.TFBertPreTrainedModel,
    transformers.modeling_tf_utils.TFSequenceClassificationLoss
):
    def __init__(self, config: transformers.BertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = transformers.TFBertMainLayer(config, name="bert_lm")
        self.fc_layer = tf.keras.layers.Dense(units=768, name="fully_connected_layer")
        self.classifier = tf.keras.layers.Dense(
            units=2,
            name="classifier_head",
        )
        self.config = config

    def call(
        self,
        input_ids: Optional[transformers.modeling_tf_utils.TFModelInputType] = None,
        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
        training: Optional[bool] = False,
    ) -> Union[
        transformers.modeling_tf_outputs.TFSequenceClassifierOutput,
        Tuple[tf.Tensor]
    ]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        pooled_output = outputs[1]
        fc_output = self.fc_layer(inputs=pooled_output)
        logits = self.classifier(inputs=fc_output)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return transformers.modeling_tf_outputs.TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
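
Save and reload (illustrative)

The exact training and serialization code isn't shown here; a minimal round trip along the lines below reproduces the warning. The "bert-base-uncased" config name and the custom_model directory are placeholders for illustration.

# Illustrative sketch only: the original training code isn't included in this issue.
config = transformers.BertConfig.from_pretrained("bert-base-uncased")
model = CustomBertClassifier(config)

# ... build the model (e.g. by running a forward pass) and fine-tune it ...

model.save_pretrained("custom_model")

# With transformers==4.33.1 this restored the saved weights without complaint;
# with transformers==4.37.0 it instead emits the long warning shown below.
reloaded = CustomBertClassifier.from_pretrained("custom_model")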

Warning Message

The following warning message appears while de-serializing the model:

Some layers from the model checkpoint at custom_model were not used when initializing CustomBertClassifier: ['bert_lm/encoder/layer_._7/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._2/output/dense/bias:0', 'bert_lm/encoder/layer_._8/attention/self/key/bias:0', 'bert_lm/encoder/layer_._9/attention/self/query/bias:0', 'bert_lm/encoder/layer_._9/output/dense/kernel:0', 'bert_lm/encoder/layer_._2/attention/self/key/bias:0', 'bert_lm/encoder/layer_._2/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._6/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._6/attention/self/value/bias:0', 'bert_lm/encoder/layer_._10/attention/self/value/bias:0', 'bert_lm/encoder/layer_._5/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._5/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._3/intermediate/dense/bias:0', 'classifier_head/bias:0', 'bert_lm/encoder/layer_._9/output/dense/bias:0', 'bert_lm/encoder/layer_._9/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._6/attention/self/query/bias:0', 'bert_lm/encoder/layer_._8/output/dense/bias:0', 'bert_lm/encoder/layer_._2/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._5/attention/self/value/bias:0', 'bert_lm/encoder/layer_._2/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._6/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._8/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._3/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._9/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._9/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._11/attention/self/query/bias:0', 'bert_lm/encoder/layer_._5/output/dense/kernel:0', 'bert_lm/encoder/layer_._3/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._11/output/dense/kernel:0', 'bert_lm/encoder/layer_._6/attention/self/key/bias:0', 'bert_lm/embeddings/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._10/attention/self/query/bias:0', 'bert_lm/encoder/layer_._4/output/dense/bias:0', 'bert_lm/encoder/layer_._3/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._1/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._6/output/dense/bias:0', 'bert_lm/encoder/layer_._1/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._4/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._7/output/dense/kernel:0', 'bert_lm/encoder/layer_._7/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._6/intermediate/dense/kernel:0', 'bert_lm/embeddings/LayerNorm/beta:0', 'bert_lm/encoder/layer_._0/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._10/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._2/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._5/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._7/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._8/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._11/intermediate/dense/kernel:0', 'fully_connected_layer/kernel:0', 'bert_lm/embeddings/token_type_embeddings/embeddings:0', 'bert_lm/encoder/layer_._10/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._0/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._3/attention/self/query/bias:0', 'bert_lm/encoder/layer_._9/attention/self/key/bias:0', 'bert_lm/encoder/layer_._4/attention/self/query/bias:0', 'bert_lm/encoder/layer_._0/attention/self/query/bias:0', 'bert_lm/encoder/layer_._10/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._1/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._9/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._7/attention/output/dense/bias:0', 
'bert_lm/encoder/layer_._10/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._11/attention/self/key/bias:0', 'bert_lm/encoder/layer_._2/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._8/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._1/attention/self/query/bias:0', 'bert_lm/encoder/layer_._5/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._5/attention/self/key/bias:0', 'bert_lm/encoder/layer_._0/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._11/attention/self/value/bias:0', 'bert_lm/encoder/layer_._0/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._1/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._3/attention/self/value/bias:0', 'bert_lm/encoder/layer_._8/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._6/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._2/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._4/output/dense/kernel:0', 'bert_lm/encoder/layer_._6/output/dense/kernel:0', 'bert_lm/encoder/layer_._9/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._0/output/dense/bias:0', 'bert_lm/encoder/layer_._8/output/dense/kernel:0', 'bert_lm/encoder/layer_._2/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._11/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._7/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._11/output/dense/bias:0', 'bert_lm/encoder/layer_._8/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._11/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._7/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._0/attention/self/value/bias:0', 'bert_lm/encoder/layer_._6/attention/output/dense/kernel:0', 'bert_lm/embeddings/position_embeddings/embeddings:0', 'bert_lm/encoder/layer_._4/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._7/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._1/output/dense/bias:0', 'bert_lm/encoder/layer_._0/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._10/output/dense/kernel:0', 'bert_lm/encoder/layer_._9/attention/self/value/bias:0', 'bert_lm/encoder/layer_._10/output/dense/bias:0', 'bert_lm/encoder/layer_._9/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._7/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._10/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._3/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._5/attention/output/dense/bias:0', 'bert_lm/pooler/dense/bias:0', 'bert_lm/encoder/layer_._0/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._2/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._3/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._7/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._1/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._0/output/dense/kernel:0', 'bert_lm/encoder/layer_._1/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._1/attention/self/value/bias:0', 'bert_lm/encoder/layer_._3/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._4/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._3/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._4/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._8/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._0/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._8/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._8/attention/self/value/bias:0', 'bert_lm/encoder/layer_._9/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._9/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._10/intermediate/dense/kernel:0', 
'bert_lm/encoder/layer_._4/attention/self/value/bias:0', 'bert_lm/encoder/layer_._10/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._1/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._11/output/LayerNorm/gamma:0', 'classifier_head/kernel:0', 'bert_lm/encoder/layer_._5/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._7/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._8/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._5/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._5/attention/self/query/bias:0', 'bert_lm/encoder/layer_._1/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._11/attention/output/dense/kernel:0', 'bert_lm/pooler/dense/kernel:0', 'bert_lm/encoder/layer_._5/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._10/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._4/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._11/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._9/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._11/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._9/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._6/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._6/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._2/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._10/attention/self/key/bias:0', 'bert_lm/encoder/layer_._0/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._2/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._4/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._1/output/dense/kernel:0', 'bert_lm/encoder/layer_._3/output/dense/bias:0', 'bert_lm/encoder/layer_._3/output/dense/kernel:0', 'bert_lm/encoder/layer_._4/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._8/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._10/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._2/attention/self/query/bias:0', 'bert_lm/encoder/layer_._5/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._8/attention/output/LayerNorm/gamma:0', 'bert_lm/embeddings/word_embeddings/weight:0', 'bert_lm/encoder/layer_._9/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._7/attention/self/query/bias:0', 'bert_lm/encoder/layer_._0/attention/self/key/bias:0', 'bert_lm/encoder/layer_._0/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._0/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._7/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._2/output/dense/kernel:0', 'bert_lm/encoder/layer_._8/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._5/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._1/attention/self/key/bias:0', 'bert_lm/encoder/layer_._2/attention/self/value/bias:0', 'bert_lm/encoder/layer_._6/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._7/attention/self/key/bias:0', 'bert_lm/encoder/layer_._6/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._4/attention/self/key/bias:0', 'bert_lm/encoder/layer_._7/output/dense/bias:0', 'bert_lm/encoder/layer_._7/attention/self/value/bias:0', 'bert_lm/encoder/layer_._3/attention/self/key/bias:0', 'bert_lm/encoder/layer_._10/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._11/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._1/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._1/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._3/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._11/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._0/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._4/attention/output/dense/bias:0', 
'bert_lm/encoder/layer_._11/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._4/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._3/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._6/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._1/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._6/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._4/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._8/attention/self/query/bias:0', 'bert_lm/encoder/layer_._3/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._5/output/dense/bias:0', 'bert_lm/encoder/layer_._10/output/LayerNorm/beta:0', 'fully_connected_layer/bias:0', 'bert_lm/encoder/layer_._4/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._11/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._2/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._5/attention/output/LayerNorm/beta:0']
- This IS expected if you are initializing CustomBertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CustomBertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of CustomBertClassifier were initialized from the model checkpoint at custom_model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use CustomBertClassifier for predictions without further training.

Expected behavior

Model de-serialization should load the saved weights correctly, without the long warning message above.

Instead of the long warning, only the following short message should appear (confirming that everything is working fine):

All model checkpoint layers were used when initializing CustomBertClassifier.

All the layers of CustomBertClassifier were initialized from the model checkpoint at custom_model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use CustomBertClassifier for predictions without further training.
Rocketknight1 (Member) commented Feb 27, 2024

Hi @ashishu007, the cause here is almost certainly this PR: #27794. In that PR, we transitioned from building our models by doing a forward pass with dummy inputs to building them with explicit build() methods on all layers. We did this to improve loading times, but also to fix a broad range of bugs caused by the dummy-input build process.

The problem, though, is that custom models without a build() method no longer create their weights correctly, so when transformers tries to load the checkpoint, the TF weight variables don't exist yet and the saved values can't be loaded into them.

The simple solution is to add a short build() method to your custom model. Note that it's acceptable to pass an input_shape of None to build() for most transformers models and layers, because they can usually work out their weight shapes from the config.

faph commented Feb 28, 2024

@Rocketknight1 thanks!

Not being too close to the details here, is the requirement to add a build() method documented somewhere? Just so we have a reliable reference for what this should actually look like.

Rocketknight1 (Member) commented

Hi @faph - I don't think so, but let me explain what I think is happening here and why the bug arose in this case.

When we load model weights in transformers, we first initialize the model with random weights and then load weight tensors from the weights file, which can be TF .h5, PyTorch .bin, or .safetensors. The random initialization happens when the model's build() method is called. We call build() as part of from_pretrained(), because that method needs to load the model weights, but we don't automatically call build() when a model is initialized from a config.

In the past, building a model meant passing dummy inputs through the network; now it relies on explicit build() methods on each layer. This difference should be invisible to users for the most part, but it showed up here because of a few factors:

  • Your custom model inherits from TFBertPreTrainedModel. This means you were implicitly depending on the old build() behaviour to create your weights, maybe without realizing it! You were also depending on our weight-loading code.
  • The BERT layer in your model is initialized from a config rather than from an existing pretrained checkpoint, so its weights will not be created until it is passed an input or its build() method is called (see the short sketch after this list).
  • The other layers are standard Keras layers, which follow the same behaviour. If you look at the list of unused weights, you can see that those layers also fail to load, e.g. fully_connected_layer/kernel:0.
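
For illustration, here is a small (untested) sketch of that point, using the same TFBertMainLayer as in the model above:

# Untested sketch: a main layer built from a config has no variables until build() is called.
config = transformers.BertConfig()
layer = transformers.TFBertMainLayer(config, name="bert_lm")
print(len(layer.weights))  # 0 -- no variables have been created yet

layer.build(None)          # shapes come from the config, so None is fine here
print(len(layer.weights))  # now the randomly initialized BERT variables exist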

Writing the build() method isn't too hard, though! You can use any of the existing build() methods in the codebase as a template. Something like this should work:

def build(self, input_shape=None):
    # Guard against building twice, mirroring the build() methods in the
    # transformers codebase.
    if self.built:
        return
    self.built = True
    with tf.name_scope(self.bert.name):
        self.bert.build(None)  # transformers layers can work out shapes from the config
    with tf.name_scope(self.fc_layer.name):
        self.fc_layer.build((self.bert.config.hidden_size,))
    with tf.name_scope(self.classifier.name):
        self.classifier.build((768,))  # 768 matches fc_layer's output units
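
With a build() method like that added to CustomBertClassifier, reloading your saved directory should find all the weights again. As a quick sanity check (using the custom_model path from your warning):

reloaded = CustomBertClassifier.from_pretrained("custom_model")
# You should now see the short "All model checkpoint layers were used when
# initializing CustomBertClassifier." message instead of the long list of
# unused weights.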

github-actions bot commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed Apr 6, 2024