transformers incompatible with master (head of trunk) tensorflow & keras 3 #28296
Here's what needs to be done:

Some of this may break Keras 2 operation; I began adding version checks but have not had time to do it properly. I had to disable jit_compile because I was getting XLA-related errors and this was an easy way out; I need to investigate and fix that problem as well. This gets me as far as being able to train for at least one epoch. Loss values seem to be off, but the model loads and trains and the loss goes down over time.
I'll just keep talking to myself here, never mind me. It trains, apparently correctly, on multiple GPUs (using tf.distribute.MirroredStrategy) and with XLA enabled. Reported loss is multiplied by the number of GPUs, and I can't quite work out why (see the note at the end of this comment). The bigger issue, however, is that mixed precision is broken:
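(On the loss-scaling remark: a common cause with tf.distribute.MirroredStrategy is that each replica averages its loss over its local batch and the per-replica means are then summed, giving num_replicas × the true mean; whether that is what happens here is unverified. A minimal sketch of the usual fix, scaling by the global batch size:)

```python
import tensorflow as tf

# Sketch of the standard distributed-loss scaling pattern; the batch size and loss choice
# are illustrative. This would normally live inside a strategy.scope()/custom train step.
GLOBAL_BATCH_SIZE = 64
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"  # keep per-example losses, reduce manually below
)

def compute_loss(labels, logits):
    per_example_loss = loss_obj(labels, logits)
    # Divide by the *global* batch size (not the per-replica one) so that summing the
    # per-replica results does not inflate the reported loss by the replica count.
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
```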
Going to tackle this one today.
Thanks! cc @Rocketknight1
Ok, the mixed precision issue from my last post was actually fixed in Keras, and I was only seeing it because I had a somewhat outdated version of Keras 3 in the system (2023-10-17 instead of 2023-12-31). There was another issue with mixed precision which only affected testing (I fixed it in the GPT-2 pathway; it may affect other models). Saving was also broken. Here's a patch that fixes everything I found, except the loss being multiplied by the number of GPUs:
You should open a PR with the patch! 🤗 (linking this issue)
Will do once I'm satisfied that I've resolved all the issues.
Hi @ekuznetsov139, thanks for the investigation here - this looks really good! Just to give you some context on why the errors change with the latest version:

I think the plan from here is that in our TensorFlow code, we're going to completely remove all direct Keras imports. Our primary goal is to ensure that Keras 3 doesn't break backward compatibility for TF code, even if we don't fully support other frameworks with Keras 3. Once backward compatibility is secure, we have plans to fully support Keras 3, which will probably require a community push to make full Keras ports of all of our models that don't use any TensorFlow ops - there's a partial PR at #26224, but it's on hold because of the number of other backward compatibility issues that need to be resolved first.
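For illustration, the kind of single import shim being described might look like this (a sketch only, not the actual transformers code; the tf-keras fallback is one known way to keep Keras 2 behaviour available when Keras 3 is installed):

```python
# Sketch of a compatibility shim: every TF file imports `keras` from one place, so the
# Keras 2 / Keras 3 decision is made exactly once instead of in every module.
try:
    import tf_keras as keras  # backwards-compatible Keras 2 fork published alongside TF
except ImportError:
    import keras

    if int(keras.__version__.split(".")[0]) > 2:
        raise ImportError(
            "Keras 3 was found, but this TF code path still needs Keras 2 semantics. "
            "Installing the tf-keras package is one way to restore them."
        )
```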
Hi @ekuznetsov139, I also ran into the same problems when using TensorFlow with Keras 3 to load transformers models. Did you fix it?
Hi @lingluodlut @ekuznetsov139, I believe this is the last PR we need: #28588. Note that we still won't have full Keras 3 support, but once that PR is merged Transformers will at least keep working when Keras 3 is installed.
I am trying to get transformers working with head-of-trunk tensorflow, which requires keras 3 (I'm using keras-nightly (3.0.3.dev2023123103)), and I'm running into issues that seem to be caused by changes in internal behavior of keras. Neither 4.36.2 nor head-of-trunk transformers work.
My test script is simply:
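(A minimal sketch of the kind of load-and-train script being described - the model choice and toy data below are illustrative assumptions, not the author's actual script:)

```python
# Illustrative sketch only - the model and the toy data are assumptions, not the original script.
import tensorflow as tf
from transformers import GPT2TokenizerFast, TFGPT2LMHeadModel

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token          # GPT-2 has no pad token by default
model = TFGPT2LMHeadModel.from_pretrained("gpt2")  # the errors discussed below surface here

enc = tokenizer(["hello world"] * 8, return_tensors="tf", padding=True)
features = dict(enc)
features["labels"] = enc["input_ids"]              # let the model compute its own LM loss
dataset = tf.data.Dataset.from_tensor_slices(features).batch(4)

model.compile(optimizer="adam")                    # no explicit loss: use the internal loss
model.fit(dataset, epochs=1)
```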
This works with transformers 4.36.2, tensorflow 2.14, keras 2.14.
With head of trunk TF and 4.36.2, I get:
This is evidently because https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/keras_deps.py#L40 is no longer being called from keras 3.0.x and so https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_tf_utils.py#L1133 returns None.
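(A purely illustrative sketch of the register/get hand-off that breaks - the names below are made up to mirror the pattern in keras_deps.py and are not the real functions involved:)

```python
# Made-up names that mirror the keras_deps.py pattern; this is not the actual code.
_registered_hook = None

def register_hook(fn):
    # Keras 2 calls the real register_* functions at import time.
    global _registered_hook
    _registered_hook = fn

def get_hook():
    # transformers later retrieves the hook via the matching get_* function.
    return _registered_hook

# Keras 3 never registers anything, so the getter hands back None and the
# unguarded caller in modeling_tf_utils.py breaks.
assert get_hook() is None
```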
I can bypass this, but then I run into a new problem:
I did some tracing, and the cause is that, when the code hits https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_tf_utils.py#L2905, tf_model.trainable_weights is empty, so transformers can't load any weights into it. I tried moving the block at lines 2915-2919 above the load call, but it has no effect.
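(For reference, a hedged sketch of the kind of build-before-load step being discussed - whether it matches the block at lines 2915-2919 exactly is an assumption, and as noted above, reordering it did not help under Keras 3:)

```python
# Illustrative only: force variable creation by running the model once on its dummy inputs,
# then check that there is something for the weight-loading code to assign into.
if not tf_model.built:
    tf_model(tf_model.dummy_inputs)           # forward pass so layers create their variables

# Under Keras 2 this holds; under Keras 3 (at the time of this issue) the list stayed empty.
assert len(tf_model.trainable_weights) > 0
```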
Then I tried head of trunk transformers. It fails too, but it fails with different symptoms. First, there is:
The problem is that, at https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L1150, you're calling

`self._set_save_spec(self.input_signature)`
and hitting https://github.com/keras-team/keras/blob/v3.0.2/keras/backend/tensorflow/layer.py#L16
`def _set_save_spec(self, inputs, args=None, kwargs=None)`

which is declared with the default parameter `kwargs=None` but really expects `kwargs` to be a dict. The logical workaround is

`self._set_save_spec(self.input_signature, kwargs={})`
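A version-guarded form of that workaround might look like this (a sketch only - `_set_save_spec` is private API in both Keras generations, and the exact version check is an assumption):

```python
import tensorflow as tf

# Sketch: inside the model's __init__ (around the line linked above), call the private
# _set_save_spec with whichever arguments the installed Keras expects.
if tf.keras.__version__.startswith("3."):
    # Keras 3 declares kwargs=None but indexes into it, so pass an empty dict explicitly.
    self._set_save_spec(self.input_signature, kwargs={})
else:
    self._set_save_spec(self.input_signature)
```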
This gets me to problem number 2:
This happens because Keras has reordered the arguments of `Layer.add_weight()`:
https://github.com/keras-team/keras/blob/v2.15.0/keras/engine/base_layer.py#L553
https://github.com/keras-team/keras/blob/v3.0.2/keras/layers/layer.py#L448
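For illustration, a minimal sketch of the difference (the layer below is an arbitrary example, not transformers code):

```python
import tensorflow as tf

class ExampleLayer(tf.keras.layers.Layer):
    def build(self, input_shape):
        # Keras 2.15: add_weight(name=None, shape=None, ...) - the name may be passed positionally.
        # Keras 3:    add_weight(shape=None, initializer=None, ..., name=None) - a positional
        #             string now lands in `shape` and fails.
        # Passing everything by keyword works under both signatures:
        self.bias = self.add_weight(
            name="bias", shape=(input_shape[-1],), initializer="zeros", trainable=True
        )
```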
so you need to pass `name=` explicitly in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L3217 and again in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L3220.

Unfortunately, even that does not let me load the model, because there's some kind of glitch that prevents the TF model from correctly setting its weight names, so I get this error: