Make test_generate_with_static_cache even less flaky (#34995)

Co-authored-by: ydshieh <[email protected]>
ydshieh authored Dec 20, 2024
1 parent 0fc2970 commit 504c4d3
Showing 5 changed files with 93 additions and 32 deletions.
48 changes: 48 additions & 0 deletions src/transformers/testing_utils.py
@@ -14,6 +14,7 @@

import collections
import contextlib
import copy
import doctest
import functools
import gc
@@ -1396,6 +1397,53 @@ def assert_screenout(out, what):
    assert match_str != -1, f"expecting to find {what} in output: {out_pr}"


def set_model_tester_for_less_flaky_test(test_case):
    if hasattr(test_case.model_tester, "num_hidden_layers"):
        test_case.model_tester.num_hidden_layers = 1
    if (
        hasattr(test_case.model_tester, "vision_config")
        and "num_hidden_layers" in test_case.model_tester.vision_config
    ):
        test_case.model_tester.vision_config = copy.deepcopy(test_case.model_tester.vision_config)
        test_case.model_tester.vision_config["num_hidden_layers"] = 1
    if hasattr(test_case.model_tester, "text_config") and "num_hidden_layers" in test_case.model_tester.text_config:
        test_case.model_tester.text_config = copy.deepcopy(test_case.model_tester.text_config)
        test_case.model_tester.text_config["num_hidden_layers"] = 1


def set_config_for_less_flaky_test(config):
    target_attrs = [
        "rms_norm_eps",
        "layer_norm_eps",
        "norm_eps",
        "norm_epsilon",
        "layer_norm_epsilon",
        "batch_norm_eps",
    ]
    for target_attr in target_attrs:
        setattr(config, target_attr, 1.0)

    # Norm layers (layer/group norm, etc.) can cause flaky tests when the tensors have very small variance.
    # (We don't need the original epsilon values to check that eager/sdpa outputs match.)
    attrs = ["text_config", "vision_config", "text_encoder", "audio_encoder", "decoder"]
    for attr in attrs:
        if hasattr(config, attr):
            for target_attr in target_attrs:
                setattr(getattr(config, attr), target_attr, 1.0)


def set_model_for_less_flaky_test(model):
    # Another way to make sure norm layers have the desired epsilon. (Some models don't set it from their config.)
    target_names = ("LayerNorm", "GroupNorm", "BatchNorm", "RMSNorm", "BatchNorm2d", "BatchNorm1d")
    target_attrs = ["eps", "epsilon", "variance_epsilon"]
    if is_torch_available() and isinstance(model, torch.nn.Module):
        for module in model.modules():
            if type(module).__name__.endswith(target_names):
                for attr in target_attrs:
                    if hasattr(module, attr):
                        setattr(module, attr, 1.0)


class CaptureStd:
    """
    Context manager to capture:
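
Taken together, the three new helpers give any test a one-call way to shrink the model under test and neutralize norm-layer epsilons before comparing two implementations. A minimal usage sketch (the tiny Llama config below is illustrative, not part of this commit):

```python
import torch

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.testing_utils import (
    set_config_for_less_flaky_test,
    set_model_for_less_flaky_test,
)

# Hypothetical tiny config; any model family with norm layers works the same way.
config = AutoConfig.for_model(
    "llama", hidden_size=16, intermediate_size=32, num_hidden_layers=1,
    num_attention_heads=4, num_key_value_heads=4,
)
set_config_for_less_flaky_test(config)  # forces rms_norm_eps & friends to 1.0

model = AutoModelForCausalLM.from_config(config).eval()
set_model_for_less_flaky_test(model)  # belt and braces: patches eps on the live modules too

for module in model.modules():
    if type(module).__name__.endswith(("LayerNorm", "RMSNorm")):
        # Llama's RMSNorm stores its epsilon as `variance_epsilon`; nn.LayerNorm uses `eps`.
        assert getattr(module, "variance_epsilon", getattr(module, "eps", None)) == 1.0
```

`set_model_tester_for_less_flaky_test` plays the same role one level up: it trims the shared model tester to a single hidden layer, so every config it produces starts small.
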
7 changes: 7 additions & 0 deletions tests/generation/test_utils.py
@@ -37,6 +37,9 @@
    require_torch_multi_accelerator,
    require_torch_multi_gpu,
    require_torch_sdpa,
    set_config_for_less_flaky_test,
    set_model_for_less_flaky_test,
    set_model_tester_for_less_flaky_test,
    slow,
    torch_device,
)
@@ -1921,11 +1924,13 @@ def test_generate_with_static_cache(self):
        Tests that generating with static cache gives almost the same results as with dynamic cache, and that the
        output cache has the expected shapes.
        """
        set_model_tester_for_less_flaky_test(self)
        for model_class in self.all_generative_model_classes:
            if not model_class._supports_static_cache:
                self.skipTest(reason="This model does not support the static cache format")

            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
            set_config_for_less_flaky_test(config)
            main_input = inputs_dict[model_class.main_input_name]

            if config.is_encoder_decoder:
@@ -1938,6 +1943,8 @@

            for dtype in (torch.float32, torch.float16):
                model = model_class(config).to(torch_device).to(dtype).eval()
                set_model_for_less_flaky_test(model)

                generation_kwargs = {
                    "max_new_tokens": max_new_tokens,
                    "return_dict_in_generate": True,  # Required to return `past_key_values`
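
In `test_generate_with_static_cache`, the helpers wrap the real check: greedy generation with the static cache must match the default dynamic cache. Reduced to a standalone sketch (the tiny checkpoint name is illustrative; the real test loops over every generative model class):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()
inputs = tokenizer("static and dynamic caches should agree", return_tensors="pt")

common = {"max_new_tokens": 8, "do_sample": False, "return_dict_in_generate": True}
dynamic = model.generate(**inputs, **common)  # default dynamic cache
static = model.generate(**inputs, **common, cache_implementation="static")

# With norm epsilons forced to 1.0, tiny-variance activations can no longer
# tip a token choice one way for one cache implementation and not the other.
assert torch.equal(dynamic.sequences, static.sequences)
```
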
15 changes: 15 additions & 0 deletions tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -41,6 +41,9 @@
    require_torch_gpu,
    require_torch_sdpa,
    require_torchaudio,
    set_config_for_less_flaky_test,
    set_model_for_less_flaky_test,
    set_model_tester_for_less_flaky_test,
    slow,
    torch_device,
)
@@ -516,8 +519,11 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        def get_mean_reldiff(failcase, x, ref, atol, rtol):
            return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"

        set_model_tester_for_less_flaky_test(self)

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            set_config_for_less_flaky_test(config)
            model = model_class(config)

            is_encoder_decoder = model.config.is_encoder_decoder
@@ -534,6 +540,9 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
            )
            model_eager = model_eager.eval().to(torch_device)

            set_model_for_less_flaky_test(model_eager)
            set_model_for_less_flaky_test(model_sdpa)

            # We use these for loops instead of parameterized.expand just in the interest of avoiding loading/saving the model 8 times,
            # but it would be nicer to have an efficient way to use parameterized.expand
            fail_cases = []
@@ -1528,8 +1537,11 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        def get_mean_reldiff(failcase, x, ref, atol, rtol):
            return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"

        set_model_tester_for_less_flaky_test(self)

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            set_config_for_less_flaky_test(config)
            model = model_class(config)

            is_encoder_decoder = model.config.is_encoder_decoder
@@ -1546,6 +1558,9 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
            )
            model_eager = model_eager.eval().to(torch_device)

            set_model_for_less_flaky_test(model_eager)
            set_model_for_less_flaky_test(model_sdpa)

            # We use these for loops instead of parameterized.expand just in the interest of avoiding loading/saving the model 8 times,
            # but it would be nicer to have an efficient way to use parameterized.expand
            fail_cases = []
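
Both MusicGen Melody test classes override `test_eager_matches_sdpa_inference`, so each copy gets the same three hooks. The comparison they stabilize, in sketch form (checkpoint and tolerances are illustrative only):

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative: the same weights loaded under two attention backends.
name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
model_eager = AutoModelForCausalLM.from_pretrained(name, attn_implementation="eager").eval()
model_sdpa = AutoModelForCausalLM.from_pretrained(name, attn_implementation="sdpa").eval()

input_ids = torch.randint(0, model_eager.config.vocab_size, (1, 16))
with torch.no_grad():
    logits_eager = model_eager(input_ids).logits
    logits_sdpa = model_sdpa(input_ids).logits

# The test reports this mean relative difference on failure.
mean_reldiff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean()
assert torch.allclose(logits_eager, logits_sdpa, atol=1e-5, rtol=1e-4), f"{mean_reldiff:.3e}"
```
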
16 changes: 16 additions & 0 deletions tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
@@ -840,7 +840,13 @@ def test_generation_languages(self):
    def test_speech_generation(self):
        config, input_speech, input_text = self.prepare_speech_and_text_input()

        from transformers.testing_utils import set_config_for_less_flaky_test, set_model_for_less_flaky_test

        set_config_for_less_flaky_test(config)

        model = SeamlessM4Tv2Model(config=config)
        set_model_for_less_flaky_test(model)

        self.update_generation(model)
        model.save_pretrained(self.tmpdirname)
        model.to(torch_device)
@@ -852,13 +858,23 @@ def test_speech_generation(self):
        state_dict = model.state_dict()

        text_model = SeamlessM4Tv2ForTextToSpeech.from_pretrained(self.tmpdirname)
        # Even though this component is loaded after `model.save_pretrained`, which ran after
        # `set_model_for_less_flaky_test(model)`, we still need to apply `set_model_for_less_flaky_test` here: the
        # `eps` attribute in the model's norm layers is not set from the config.
        set_model_for_less_flaky_test(text_model)

        self.update_generation(text_model)
        text_model.to(torch_device)
        text_model.eval()

        output_text = self.factory_generation_speech_test(model, input_text)

        speech_model = SeamlessM4Tv2ForSpeechToSpeech.from_pretrained(self.tmpdirname)
        # Same reason as above: `from_pretrained` rebuilds the norm layers, so the patched `eps` must be re-applied.
        set_model_for_less_flaky_test(speech_model)

        self.update_generation(speech_model)
        speech_model.to(torch_device)
        speech_model.eval()
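
The repeated comment marks the subtle point of this file: patching `eps` on a live module does not survive `from_pretrained`, which rebuilds every module from the config, and some norm layers never read their epsilon from the config at all. A toy illustration of that failure mode (the class below is hypothetical):

```python
import torch
from torch import nn

class HardcodedNorm(nn.Module):
    """Toy stand-in for a norm layer whose epsilon is not taken from the config."""
    def __init__(self):
        super().__init__()
        self.eps = 1e-6  # hardcoded: config-level patching cannot reach this

    def forward(self, x):
        return x / torch.sqrt(x.var(-1, keepdim=True) + self.eps)

norm = HardcodedNorm()
norm.eps = 1.0  # what set_model_for_less_flaky_test(model) effectively does

reloaded = HardcodedNorm()  # what from_pretrained effectively does: fresh modules
assert reloaded.eps == 1e-6  # the patch is gone, hence the re-application above
```
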
39 changes: 7 additions & 32 deletions tests/test_modeling_common.py
@@ -89,6 +89,9 @@
    require_torch_multi_accelerator,
    require_torch_multi_gpu,
    require_torch_sdpa,
    set_config_for_less_flaky_test,
    set_model_for_less_flaky_test,
    set_model_tester_for_less_flaky_test,
    slow,
    torch_device,
)
@@ -3976,34 +3979,11 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        def get_mean_reldiff(failcase, x, ref, atol, rtol):
            return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"

-        if hasattr(self.model_tester, "num_hidden_layers"):
-            self.model_tester.num_hidden_layers = 1
-        if hasattr(self.model_tester, "vision_config") and "num_hidden_layers" in self.model_tester.vision_config:
-            self.model_tester.vision_config = copy.deepcopy(self.model_tester.vision_config)
-            self.model_tester.vision_config["num_hidden_layers"] = 1
-        if hasattr(self.model_tester, "text_config") and "num_hidden_layers" in self.model_tester.text_config:
-            self.model_tester.text_config = copy.deepcopy(self.model_tester.text_config)
-            self.model_tester.text_config["num_hidden_layers"] = 1
+        set_model_tester_for_less_flaky_test(self)

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

-            config.rms_norm_eps = 1.0
-            config.layer_norm_eps = 1.0
-            config.norm_eps = 1.0
-            config.norm_epsilon = 1.0
-            config.layer_norm_epsilon = 1.0
-
-            # norm layers (layer/group norm, etc.) could cause flaky tests when the tensors have very small variance.
-            # (We don't need the original epsilon values to check eager/sdpa matches)
-            for attr in ["text_config", "vision_config", "text_encoder", "audio_encoder", "decoder"]:
-                if hasattr(config, attr):
-                    getattr(config, attr).rms_norm_eps = 1.0
-                    getattr(config, attr).layer_norm_eps = 1.0
-                    getattr(config, attr).norm_eps = 1.0
-                    getattr(config, attr).norm_epsilon = 1.0
-                    getattr(config, attr).layer_norm_epsilon = 1.0
-
+            set_config_for_less_flaky_test(config)
            model = model_class(config)
            # FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
            # These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
@@ -4029,13 +4009,8 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
            )
            model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)

-            # Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
-            for x in model_eager.modules():
-                if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
-                    x.eps = 1.0
-            for x in model_sdpa.modules():
-                if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
-                    x.eps = 1.0
+            set_model_for_less_flaky_test(model_eager)
+            set_model_for_less_flaky_test(model_sdpa)

            # We use these for loops instead of parameterized.expand just in the interest of avoiding loading/saving the model 16 times,
            # but it would be nicer to have an efficient way to use parameterized.expand
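
Why forcing epsilon to 1.0 de-flakes these comparisons is visible in a few lines: a layer norm divides by `sqrt(var + eps)`, and when both the variance and `eps` are tiny, that denominator is tiny, so rounding noise in the input is amplified; `eps = 1.0` drowns the noise out. A runnable illustration (the numbers are illustrative, not from the test suite):

```python
import torch

torch.manual_seed(0)
x = 1.0 + 1e-3 * torch.randn(4, 64)  # tiny variance: the flaky regime

def layer_norm(x, eps):
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

for eps in (1e-6, 1.0):
    y_ref = layer_norm(x, eps)                  # float32 inputs
    y_fp16 = layer_norm(x.half().float(), eps)  # fp16-rounded inputs, float32 math
    print(f"eps={eps:g}  max deviation = {(y_fp16 - y_ref).abs().max():.2e}")
# The printed gap shrinks by several orders of magnitude at eps=1.0,
# which is exactly what makes the eager/sdpa and cache comparisons stable.
```
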
