Fix flaky test_beam_search_low_memory (#35611)
* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <[email protected]>
ydshieh authored Jan 10, 2025
1 parent b02828e commit 04eae98
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion tests/generation/test_utils.py
@@ -204,6 +204,8 @@ def _get_logits_processor_kwargs(self, do_sample=False, config=None):
                 "vision_start_token_id",
             ]:
                 token_index = getattr(config, key, None)
+                if token_index is None and hasattr(self, "model_tester"):
+                    token_index = getattr(self.model_tester, key, None)
                 if token_index is not None and token_index < config.get_text_config().vocab_size:
                     logits_processor_kwargs["bad_words_ids"].append([token_index])
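The hunk above lets _get_logits_processor_kwargs fall back to the model tester when a special-token id (for example image_token_id) is defined on self.model_tester rather than on the generated config, so the token still lands in bad_words_ids. A minimal, self-contained sketch of that lookup pattern; the SimpleNamespace stand-ins and the vocab_size value are hypothetical, not the objects the real test mixin uses:

from types import SimpleNamespace

# Hypothetical stand-ins for the generated config and the model tester.
config = SimpleNamespace()                         # the config does not define the id...
model_tester = SimpleNamespace(image_token_id=7)   # ...but the model tester does
vocab_size = 99

bad_words_ids = []
for key in ["image_token_index", "image_token_id", "video_token_index"]:
    token_index = getattr(config, key, None)
    # The new fallback: when the config does not carry the id, try the model tester.
    if token_index is None:
        token_index = getattr(model_tester, key, None)
    if token_index is not None and token_index < vocab_size:
        bad_words_ids.append([token_index])

print(bad_words_ids)  # [[7]] -- the placeholder token is banned during generation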

@@ -1077,14 +1079,20 @@ def test_beam_search_low_memory(self):
             ):
                 self.skipTest(reason="May fix in the future: need model-specific fixes")
 
+            set_model_tester_for_less_flaky_test(self)
+
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+            set_config_for_less_flaky_test(config)
             # batch_size=1 is ok, but batch_size>1 will cause non-identical output
 
             config.use_cache = True
             config.is_decoder = True
 
             # test output equality of low versus high memory
             model = model_class(config).to(torch_device).eval()
+            set_model_for_less_flaky_test(model)
+
+            logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
 
             low_output = model.generate(
                 **inputs_dict,
@@ -1093,6 +1101,10 @@ def test_beam_search_low_memory(self):
                 early_stopping=True,
                 low_memory=True,
                 use_cache=True,
+                output_scores=True,
+                output_logits=True,
+                return_dict_in_generate=True,
+                **logits_processor_kwargs,
             )
 
             high_output = model.generate(
@@ -1102,8 +1114,13 @@ def test_beam_search_low_memory(self):
                 early_stopping=True,
                 low_memory=False,
                 use_cache=True,
+                output_scores=True,
+                output_logits=True,
+                return_dict_in_generate=True,
+                **logits_processor_kwargs,
             )
-            self.assertListEqual(low_output.tolist(), high_output.tolist())
+            # The two outputs must match and their shape must be as expected
+            self._check_similar_generate_outputs(low_output, high_output)
 
     @pytest.mark.generate
     @parameterized.expand([("random",), ("same",)])
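Why this makes the test less flaky: the low-memory and regular beam-search paths can accumulate floating-point error differently, so bitwise-identical token ids (the old assertListEqual) are not guaranteed even when both paths behave correctly. The set_*_for_less_flaky_test helpers reduce numerical sensitivity in the test setup, and the score-returning flags (return_dict_in_generate=True, output_scores=True, output_logits=True) give _check_similar_generate_outputs what it needs to compare the two runs with a tolerance instead of exact equality. Below is a minimal sketch of that kind of tolerance-based comparison; the function name, tolerances, and exact rules are illustrative assumptions and may differ from the actual helper in GenerationTesterMixin:

import torch

def check_similar_generate_outputs(low, high, atol=1e-3, rtol=1e-3):
    # Hedged sketch: treat two generate() runs as equivalent when their sequences
    # match exactly, or when they only diverge after a step up to which the
    # per-step scores were numerically near-identical (a floating-point near-tie
    # between beams rather than a real behavioural difference).
    assert low.sequences.shape == high.sequences.shape

    if torch.equal(low.sequences, high.sequences):
        return

    # `scores` holds one (batch * num_beams, vocab_size) tensor per generated step;
    # the prompt tokens at the start of `sequences` have no score entry.
    prompt_len = low.sequences.shape[1] - len(low.scores)
    first_mismatch = (low.sequences != high.sequences).any(dim=0).nonzero()[0].item()
    divergent_step = max(first_mismatch - prompt_len, 0)

    for step in range(divergent_step):
        torch.testing.assert_close(low.scores[step], high.scores[step], atol=atol, rtol=rtol)

In the test, the same idea is applied to low_output and high_output, which are the ModelOutput objects generate returns once return_dict_in_generate=True is passed.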
