Skip to content

Commit

Permalink
Replace build() with build_in_name_scope() for some TF tests (#28046)
Browse files Browse the repository at this point in the history
Replace build() with build_in_name_scope() for some tests
  • Loading branch information
Rocketknight1 authored Dec 14, 2023
1 parent 050e0b4 commit 3060899
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion tests/models/bart/test_modeling_tf_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def test_save_load_after_resize_token_embeddings(self):
old_total_size = config.vocab_size
new_total_size = old_total_size + new_tokens_size
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
model.build()
model.build_in_name_scope()
model.resize_token_embeddings(new_total_size)

# fetch the output for an input exclusively made of new members of the vocabulary
Expand Down
2 changes: 1 addition & 1 deletion tests/models/ctrl/test_modeling_tf_ctrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_model_common_attributes(self):

for model_class in self.all_model_classes:
model = model_class(config)
model.build() # may be needed for the get_bias() call below
model.build_in_name_scope() # may be needed for the get_bias() call below
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)

if model_class in list_lm_models:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_onnx_compliancy(self):

with tf.Graph().as_default() as g:
model = model_class(config)
model.build()
model.build_in_name_scope()

for op in g.get_operations():
model_op_names.add(op.node_def.op)
Expand Down Expand Up @@ -346,7 +346,7 @@ def test_onnx_runtime_optimize(self):

for model_class in self.all_model_classes[:2]:
model = model_class(config)
model.build()
model.build_in_name_scope()

onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)

Expand Down Expand Up @@ -1088,7 +1088,7 @@ def test_resize_token_embeddings(self):
def _get_word_embedding_weight(model, embedding_layer):
if isinstance(embedding_layer, tf.keras.layers.Embedding):
# builds the embeddings layer
model.build()
model.build_in_name_scope()
return embedding_layer.embeddings
else:
return model._get_word_embedding_weight(embedding_layer)
Expand Down Expand Up @@ -1151,7 +1151,7 @@ def test_save_load_after_resize_token_embeddings(self):
old_total_size = config.vocab_size
new_total_size = old_total_size + new_tokens_size
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
model.build()
model.build_in_name_scope()
model.resize_token_embeddings(new_total_size)

# fetch the output for an input exclusively made of new members of the vocabulary
Expand Down
8 changes: 4 additions & 4 deletions tests/test_modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,8 @@ def test_checkpoint_sharding_local(self):
# Finally, check the model can be reloaded
new_model = TFBertModel.from_pretrained(tmp_dir)

model.build()
new_model.build()
model.build_in_name_scope()
new_model.build_in_name_scope()

for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
Expand Down Expand Up @@ -632,7 +632,7 @@ def test_push_to_hub(self):
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build()
model.build_in_name_scope()

logging.set_verbosity_info()
logger = logging.get_logger("transformers.utils.hub")
Expand Down Expand Up @@ -701,7 +701,7 @@ def test_push_to_hub_in_organization(self):
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build()
model.build_in_name_scope()

model.push_to_hub("valid_org/test-model-tf-org", token=self._token)

Expand Down

0 comments on commit 3060899

Please sign in to comment.