-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Proper build() methods for TF #27794
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
6966265
to
0f16397
Compare
and len(summarizer.model.trainable_weights) > 0 | ||
and "GPU" in summarizer.model.trainable_weights[0].device |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quick note to any reviewers so they can make sense of this one: The problem is that get_gpu_count()
is supposed to be framework-neutral, but actually checks the frameworks in order using is_torch_available()
etc. and returns the GPU count for the first framework that matches.
This is very risky for TensorFlow, because TF environments will often have Torch as well, and if Torch is present then the Torch GPU count is returned instead, which may be 0 even if TF is running on GPU.
I just refactored the check to not use that function.
|
||
# Copied from: | ||
# https://gist.github.com/Rocketknight1/43abbe6e73f1008e6e459486e01e0ceb | ||
class TFAdaptiveAvgPool1D(tf.keras.layers.Layer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another note for reviewers: TF doesn't have an AdaptivePool layer like Torch does. I realized while updating the class in this PR that we were still using my old TF version of the layer, which is very inefficient. I wrote a much more performant version later, and so I took the opportunity to do the replacement here (it also fixed some naming issues that the old layer had)
oh my god the tests pass i didn't think this was ever going to happen |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice and clean 🔥 I can't imagine the time you took to find this "simple" solution 👀
Thanks! There's one thing left to add - some layers are buildable with int shapes in Keras 2, but that always fails in Keras 3. I'm going to do a quick replacement so that those become actual shapes (with extra |
Quick update: The build shapes all have proper ranks instead of just being ints now, but our old method of controlling names with |
Got a solution, but I think it fits better in another PR! I'm gonna merge this one for now and see what shakes out in the nightly CI, while I work on the next phase of Keras 3 compatibility. |
Not sure I get why this was merged? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As promised, a quick follow up review :) Thanks for this massive piece of work! Layer naming / building is definitely easier to follow like this ❤️
Just a few comments, mostly nits. Main ones are about saving the config in all of the layer modules and where the early return happens in some methods.
if self.built: | ||
return | ||
self.built = True | ||
if getattr(self, "summary", None) is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the significance of "summary" here?
@@ -146,7 +146,7 @@ def __init__(self, config: AlbertConfig, **kwargs): | |||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") | |||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) | |||
|
|||
def build(self, input_shape: tf.TensorShape): | |||
def build(self, input_shape=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit - shouldn't this become an optional?
def build(self, input_shape=None): | |
def build(self, input_shape: Optional[tf.TensorShape] = None): |
@@ -246,6 +251,7 @@ def __init__(self, config: AlbertConfig, **kwargs): | |||
# Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993 | |||
self.attention_dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) | |||
self.output_dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) | |||
self.config = config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Storing the whole config in each layer can start to make things a lot bigger if it's large e.g. id2label
with many classes. In the case of TF with safetensors, are we just instantiating the class from the library and loading the weight e.g. we don't store all the class attributes when saving?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a good point, but I think it's okay! Our save_pretrained
methods should just save the weights, and not save attributes of the layers like this. Also, setting self.config = config
just creates a reference to the same underlying config
object, so it doesn't use extra memory when the model is initialized either.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, a lot of our models do self.config = config
in the __init__
anyway! I just needed to add a lot more of it so that the build()
methods could see the config vars they need.
@@ -965,10 +1086,18 @@ def __init__(self, config, input_embeddings, **kwargs): | |||
# an output-only bias for each token. | |||
self.decoder = input_embeddings | |||
|
|||
def build(self, input_shape): | |||
def build(self, input_shape=None): | |||
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this go after the check of if self.built
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably! In general, this happens when the class had an existing build method, in which case my new method is appended to the end of it. Existing build methods in the codebase don't always have an if built:
clause. It shouldn't actually cause too many problems, but I could consider a pass to fix it up if it does.
@@ -169,7 +169,12 @@ def build(self, input_shape: tf.TensorShape = None): | |||
name="embeddings", | |||
) | |||
|
|||
super().build(input_shape) | |||
if self.built: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here - can we return quickly without calling self.add_weight
if self.built
is True?
@@ -766,7 +877,21 @@ def build(self, input_shape: tf.TensorShape = None): | |||
name="logit_scale", | |||
) | |||
|
|||
super().build(input_shape) | |||
if self.built: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same q here. I'm not going to message any time I see it in case it's nothing but if it's not then you'll have to go through the PR to find all the cases!
# If a specific input shape is passed in, we need to modify it to account for padding | ||
# Not necessary if those portions of the shape are None | ||
if input_shape[-2] is not None: | ||
input_shape[-2] += self.explicit_padding * 2 | ||
super().build(input_shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we still want this super().build() call here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably not, but it should be harmless - I don't even think the super()
method really does anything in TF anymore!
* Add a convenience method for building in your own name scope * Second attempt at auto layer building * Revert "Second attempt at auto layer building" This reverts commit e03a3aa. * Attempt poedator#3 * Revert "Attempt poedator#3" This reverts commit b9df7a0. * Add missing attributes that we're going to need later * Add some attributes we're going to need later * A fourth attempt! Feel the power flow through you! * Revert "A fourth attempt! Feel the power flow through you!" This reverts commit 6bf4aaf. * Add more values we'll need later * TF refactor that we'll need later * Revert "TF refactor that we'll need later" This reverts commit ca07202. * Revert "Revert "TF refactor that we'll need later"" This reverts commit 1beb0f3. * make fixup * Attempt five! * Revert "Attempt five!" This reverts commit 3302207. * Attempt six - this time don't add empty methods * Revert "Attempt six - this time don't add empty methods" This reverts commit 67d6012. * Attempt seven - better base model class detection! * Revert "Attempt seven - better base model class detection!" This reverts commit 5f14845. * Another attribute we'll need later * Try again with the missing attribute! * Revert "Try again with the missing attribute!" This reverts commit 760c6f3. * This is the attempt that will pierce the heavens! * Revert "This is the attempt that will pierce the heavens!" This reverts commit c868bb6. * Attempt seven - snag list is steadily decreasing * Revert "Attempt seven - snag list is steadily decreasing" This reverts commit 46fbd97. * Attempt eight - will an empty snag list do it? * Revert "Attempt eight - will an empty snag list do it?" This reverts commit 7c8a3c2. * Fixes to Hubert issues that cause problems later * Trying again with Conv1D/SeparableConv fixes * Revert "Trying again with Conv1D/SeparableConv fixes" This reverts commit 55092bc. * Apply the build shape fixes to Wav2Vec2 as well * One more attempt! * Revert "One more attempt!" This reverts commit 5ac3e4c. * Another attempt! * Revert "Another attempt!" This reverts commit ea16d89. * Let's see how many failures we get without the internal build method * Fix OpenAI * Fix MobileBERT * (Mostly) fix GroupVIT * Fix BLIP * One more BLIP fix * One more BLIP fix! * Fix Regnet * Finally fully fix GroupViT * Fix Data2Vec and add the new AdaptivePool * Fix Segformer * Fix Albert * Fix Deberta/DebertaV2 * Fix XLM * Actually fix XLM * Fix Flaubert * Fix lxmert * Fix Resnet * Fix ConvBERT * Fix ESM * Fix Convnext / ConvnextV2 * Fix SAM * Fix Efficientformer * Fix LayoutLMv3 * Fix speech_to_text * Fix mpnet and mobilevit * Fix Swin * Fix CTRL * Fix CVT * Fix DPR * Fix Wav2Vec2 * Fix T5 * Fix Hubert * Fix GPT2 * Fix Whisper * Fix DeiT * Fix the encoder-decoder / dual-encoder classes * make fix-copies * build in name scope * Fix summarization test * Fix tied weight names for BART + Blenderbot * Fix tied weight name building * Fix to TFESM weight building * Update TF SAM * Expand all the shapes out into Big Boy Shapes
* Add a convenience method for building in your own name scope * Second attempt at auto layer building * Revert "Second attempt at auto layer building" This reverts commit e03a3aa. * Attempt huggingface#3 * Revert "Attempt huggingface#3" This reverts commit b9df7a0. * Add missing attributes that we're going to need later * Add some attributes we're going to need later * A fourth attempt! Feel the power flow through you! * Revert "A fourth attempt! Feel the power flow through you!" This reverts commit 6bf4aaf. * Add more values we'll need later * TF refactor that we'll need later * Revert "TF refactor that we'll need later" This reverts commit ca07202. * Revert "Revert "TF refactor that we'll need later"" This reverts commit 1beb0f3. * make fixup * Attempt five! * Revert "Attempt five!" This reverts commit 3302207. * Attempt six - this time don't add empty methods * Revert "Attempt six - this time don't add empty methods" This reverts commit 67d6012. * Attempt seven - better base model class detection! * Revert "Attempt seven - better base model class detection!" This reverts commit 5f14845. * Another attribute we'll need later * Try again with the missing attribute! * Revert "Try again with the missing attribute!" This reverts commit 760c6f3. * This is the attempt that will pierce the heavens! * Revert "This is the attempt that will pierce the heavens!" This reverts commit c868bb6. * Attempt seven - snag list is steadily decreasing * Revert "Attempt seven - snag list is steadily decreasing" This reverts commit 46fbd97. * Attempt eight - will an empty snag list do it? * Revert "Attempt eight - will an empty snag list do it?" This reverts commit 7c8a3c2. * Fixes to Hubert issues that cause problems later * Trying again with Conv1D/SeparableConv fixes * Revert "Trying again with Conv1D/SeparableConv fixes" This reverts commit 55092bc. * Apply the build shape fixes to Wav2Vec2 as well * One more attempt! * Revert "One more attempt!" This reverts commit 5ac3e4c. * Another attempt! * Revert "Another attempt!" This reverts commit ea16d89. * Let's see how many failures we get without the internal build method * Fix OpenAI * Fix MobileBERT * (Mostly) fix GroupVIT * Fix BLIP * One more BLIP fix * One more BLIP fix! * Fix Regnet * Finally fully fix GroupViT * Fix Data2Vec and add the new AdaptivePool * Fix Segformer * Fix Albert * Fix Deberta/DebertaV2 * Fix XLM * Actually fix XLM * Fix Flaubert * Fix lxmert * Fix Resnet * Fix ConvBERT * Fix ESM * Fix Convnext / ConvnextV2 * Fix SAM * Fix Efficientformer * Fix LayoutLMv3 * Fix speech_to_text * Fix mpnet and mobilevit * Fix Swin * Fix CTRL * Fix CVT * Fix DPR * Fix Wav2Vec2 * Fix T5 * Fix Hubert * Fix GPT2 * Fix Whisper * Fix DeiT * Fix the encoder-decoder / dual-encoder classes * make fix-copies * build in name scope * Fix summarization test * Fix tied weight names for BART + Blenderbot * Fix tied weight name building * Fix to TFESM weight building * Update TF SAM * Expand all the shapes out into Big Boy Shapes
This fixes security issues #274, #275, #276. Can't upgrade to a higher version because this change seems to break model loading and some layers are failing to load: huggingface/transformers#27794
This fixes security issues #274, #275, #276. Can't upgrade to a higher version because this change seems to break model loading and some layers are failing to load: huggingface/transformers#27794
TensorFlow builds weights lazily. This means that layers do not have an
input_dim
argument and do not create weight tensors in the model__init__()
. Instead, the layers wait until theirbuild()
method is called, which usually happens implicitly the first time the layer receives an input. Layers use the shape of the first input they see, or the value explicitly passed to theirbuild()
method, to infer their input dim and build their weight tensors.Up until now, almost none of our TF models had explicit
build()
methods. This meant that weights were built implicitly when the model was called, which required lots of tiny hacks all over the codebase:from_pretrained()
to prepare the model weights so that we could load a checkpointtf.name_scope()
inside their forward pass (!) to control their weight names, which only worked because the weights were always built thereThis had always been a big chunk of tech debt that I'd wanted to fix, but it was such a large task that I never really had time. However, with Keras 3 approaching, it became quite urgent. I tried getting GPT-4 to figure out the
build()
shapes automatically, but it generally failed, so I had to resort to usingast
and static analysis of the PyTorch and TF modeling files to cross-match layers from TF to PyTorch code, using the input size arguments from PyTorch to automatically create and populate newbuild()
methods, and then did a manual pass afterwards to fix up the remaining issues.As a result, after this PR:
build()
methodsbuild()
hierarchyfrom_pretrained()
! Should make model loading significantly faster, especially on CPU, and should help in the CI.While I was working on this PR, I also encountered some other issues that I fixed in passing:
build_in_name_scope()
method and refactored some tests/methods to use it instead. Calling this method yields the same name hierarchy as implicitly callingbuild()
when doing a forward pass, whereas directly callingmodel.build()
does not (because TF enters a name_scope in__call__()
)TFSequenceSummary
andTFConv1D
classes. These are mostly used by older models.Note to reviewers: Most of this PR was generated automatically, and just consists of walls of new
build()
methods. You can generally trust that these methods are correct so long as the CI is green, so you hopefully don't have to read them all - there's >11,000 lines of them! The main things to review are the changes in core files likemodeling_tf_utils.py
, the newbuild_in_name_scope()
method, etc.