De-serializing a custom model in transformers v4.37.0 onwards doesn't work; weights aren't loaded from the saved checkpoint #29321
Comments
Hi @ashishu007, the cause here is almost certainly this PR: #27794. In that PR, we transitioned from building our models by doing a forward pass with dummy inputs to building them with explicit `build()` methods. The problem, though, is that custom models that don't define a `build()` method no longer get built correctly. The simple solution here is just to add a short `build()` method to your class.
@Rocketknight1 thanks! Not being too close to the details here, is the requirement to add a `build()` method a breaking change to the public API?
Hi @faph - I don't think so, but let me explain what I think is happening here and why the bug arose in this case. When we load model weights in TensorFlow, the model has to be built first so that its weight tensors actually exist, and the weights in the checkpoint are then matched to the model's layers by name. In the past, we built models implicitly by running a forward pass with dummy inputs; now we build them with explicit `build()` methods, so a custom model that never defines one ends up with unbuilt (and therefore randomly initialized) layers.
Writing the `build()` method for this model should look something like:

```python
def build(self, input_shape=None):
    # Build each sublayer inside its own name scope so that the weight
    # names match the names used when the checkpoint was saved.
    with tf.name_scope(self.bert.name):
        self.bert.build(None)
    with tf.name_scope(self.fc_layer.name):
        self.fc_layer.build((self.bert.config.hidden_size,))
    with tf.name_scope(self.classifier.name):
        self.classifier.build((768,))
```
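With that `build()` method added, reloading from the checkpoint should restore the saved weights silently. A quick check, assuming the checkpoint directory is `custom_model` as in the report below:

```python
# Hypothetical verification: loading should now print the short
# "All model checkpoint layers were used..." message instead of the
# long random-initialization warning.
model = CustomBertClassifier.from_pretrained("custom_model")
```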
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info

`transformers` version: 4.37.0

Who can help?
@ArthurZucker @Rocketknight1
Information

Tasks

An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

Reproduction
I have a custom model that worked fine with `transformers==4.33.1` but now fails with `transformers==4.37.0`. De-serializing the model from a checkpoint doesn't load the weights correctly; instead it randomly initializes the model and prints a long warning message. This appears to be a breaking change across minor versions that isn't backwards compatible.

A compact example of the custom model that raises the same warning is given below.

Q: What can I do to make my code work with recent transformers versions?
Custom model class
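The class body is collapsed in the rendered issue. What follows is a minimal sketch, not the reporter's exact code, reconstructed to be consistent with the `build()` method shown earlier in the thread (a BERT backbone, a 768-unit fully connected layer, and a classification head; all names, shapes, and the label count are assumptions):

```python
import tensorflow as tf
from transformers import BertConfig, TFBertModel, TFPreTrainedModel


class CustomBertClassifier(TFPreTrainedModel):
    """Hypothetical custom classifier matching the build() method above."""

    config_class = BertConfig

    def __init__(self, config, num_labels=2, **kwargs):
        super().__init__(config, **kwargs)
        self.bert = TFBertModel(config, name="bert")
        # Shapes assumed from the build() method: fc_layer consumes the
        # pooled BERT output (hidden_size) and classifier consumes 768.
        self.fc_layer = tf.keras.layers.Dense(768, activation="relu", name="fc_layer")
        self.classifier = tf.keras.layers.Dense(num_labels, name="classifier")

    def call(self, input_ids=None, attention_mask=None, training=False):
        outputs = self.bert(input_ids, attention_mask=attention_mask, training=training)
        hidden = self.fc_layer(outputs.pooler_output)
        return self.classifier(hidden)
```

A save/load round trip that reproduces the warning might then look like this (checkpoint path assumed):

```python
config = BertConfig.from_pretrained("bert-base-uncased")
model = CustomBertClassifier(config)
model.save_pretrained("custom_model")

# Under transformers==4.33.1 this restored the saved weights; under
# 4.37.0 it emits the long random-initialization warning instead.
reloaded = CustomBertClassifier.from_pretrained("custom_model")
```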
Warning Message
The following warning message appears while de-serializing the model:
Expected behavior
The model should de-serialize properly, without the long warning message above. Instead, the following short message should appear, confirming that everything is working:
All model checkpoint layers were used when initializing CustomBertClassifier. All the layers of CustomBertClassifier were initialized from the model checkpoint at custom_model. If your task is similar to the task the model of the checkpoint was trained on, you can already use CustomBertClassifier for predictions without further training.