Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Seq2seqTrainerTester::test_bad_generation_config_fail_early #31866

Closed
wants to merge 95 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
95 commits
Select commit Hold shift + click to select a range
9abf017
test
ydshieh Jul 9, 2024
d808d48
[test_all] check
ydshieh Jul 9, 2024
20356e4
test
ydshieh Jul 9, 2024
0ea8423
[test_all] check
ydshieh Jul 9, 2024
d36b301
test
ydshieh Jul 9, 2024
4960219
[test_all] check
ydshieh Jul 9, 2024
4e7ba85
[test_all] check
ydshieh Jul 9, 2024
1d7c622
test
ydshieh Jul 9, 2024
73867dc
[test_all] check
ydshieh Jul 9, 2024
a8505ca
test
ydshieh Jul 9, 2024
7cdddbe
[test_all] check
ydshieh Jul 9, 2024
a1a4b7b
test
ydshieh Jul 9, 2024
3c694a7
[test_all] check
ydshieh Jul 9, 2024
db4a053
test
ydshieh Jul 9, 2024
37b72ef
[test_all] check
ydshieh Jul 9, 2024
84f394f
test
ydshieh Jul 9, 2024
d5261bc
[test_all] check
ydshieh Jul 9, 2024
f7a7dab
test
ydshieh Jul 9, 2024
f585258
[test_all] check
ydshieh Jul 9, 2024
0225fe1
test
ydshieh Jul 9, 2024
84a1eb8
[test_all] check
ydshieh Jul 9, 2024
d81e0f4
test
ydshieh Jul 9, 2024
b72d146
[test_all] check
ydshieh Jul 9, 2024
45f0897
fix
ydshieh Jul 10, 2024
c564165
[test_all] check
ydshieh Jul 10, 2024
224abe7
fix
ydshieh Jul 10, 2024
9dae6b9
[test_all] check
ydshieh Jul 10, 2024
a24d5ff
fix
ydshieh Jul 10, 2024
e539ee2
[test_all] check
ydshieh Jul 10, 2024
08732e9
fix
ydshieh Jul 10, 2024
de711fd
[test_all] check
ydshieh Jul 10, 2024
dc20fc7
fix
ydshieh Jul 10, 2024
7ebb9ca
fix
ydshieh Jul 10, 2024
ca5e4d0
[test_all] check
ydshieh Jul 10, 2024
c35aff9
fix
ydshieh Jul 10, 2024
933bf12
[test_all] check
ydshieh Jul 10, 2024
d5e5610
fix
ydshieh Jul 10, 2024
cc7f414
fix
ydshieh Jul 10, 2024
ab1a5d3
fix
ydshieh Jul 10, 2024
03c39ce
fix
ydshieh Jul 10, 2024
749b660
[test_all] check
ydshieh Jul 10, 2024
5d87dfb
fix
ydshieh Jul 10, 2024
6fff0e6
[test_all] check
ydshieh Jul 10, 2024
a481c6b
fix
ydshieh Jul 10, 2024
43b7492
fix
ydshieh Jul 10, 2024
3bc60cc
fix
ydshieh Jul 10, 2024
a5e9bde
[test_all] check
ydshieh Jul 10, 2024
f243cce
fix
ydshieh Jul 10, 2024
4b0f59b
[test_all] check
ydshieh Jul 10, 2024
0c2f507
fix
ydshieh Jul 10, 2024
d4508b9
[test_all] check
ydshieh Jul 10, 2024
61c1d2f
fix
ydshieh Jul 10, 2024
e5d63e9
[test_all] check
ydshieh Jul 10, 2024
b5159c8
[test_all] check
ydshieh Jul 10, 2024
d4b8706
fix
ydshieh Jul 10, 2024
d791e93
fix
ydshieh Jul 10, 2024
679dd45
[test_all] check
ydshieh Jul 10, 2024
92f9ea4
fix
ydshieh Jul 10, 2024
6a3e3e0
fix
ydshieh Jul 10, 2024
c9ab910
fix
ydshieh Jul 10, 2024
aa04cda
[test_all] check
ydshieh Jul 10, 2024
1733485
fix
ydshieh Jul 10, 2024
98b21be
[test_all] check
ydshieh Jul 10, 2024
b98bf06
fix
ydshieh Jul 10, 2024
faa47ae
fix
ydshieh Jul 10, 2024
120938a
[test_all] check
ydshieh Jul 10, 2024
727e673
fix
ydshieh Jul 10, 2024
7ce8dca
[test_all] check
ydshieh Jul 10, 2024
61bf368
fix
ydshieh Jul 10, 2024
773bb24
[test_all] check
ydshieh Jul 10, 2024
a91e8f6
fix
ydshieh Jul 10, 2024
0c99218
fix
ydshieh Jul 10, 2024
01ca7f8
[test_all] check
ydshieh Jul 10, 2024
37aa99d
fix
ydshieh Jul 10, 2024
ca2d28d
[test_all] check
ydshieh Jul 10, 2024
6f6f3e9
fix
ydshieh Jul 10, 2024
4086c15
fix
ydshieh Jul 10, 2024
26777f4
fix
ydshieh Jul 10, 2024
5903ea9
fix
ydshieh Jul 10, 2024
026286f
fix
ydshieh Jul 10, 2024
4b6ae84
fix
ydshieh Jul 10, 2024
b1f986d
fix
ydshieh Jul 10, 2024
f3777e1
fix
ydshieh Jul 10, 2024
58d7f12
[test_all] check
ydshieh Jul 10, 2024
08ed067
fix
ydshieh Jul 10, 2024
ca45dcc
fix
ydshieh Jul 10, 2024
9eb5399
fix
ydshieh Jul 10, 2024
1f5199a
[test_all] check
ydshieh Jul 10, 2024
90cff9e
fix
ydshieh Jul 10, 2024
525be71
[test_all] check
ydshieh Jul 10, 2024
a9eef7b
fix
ydshieh Jul 10, 2024
8123b27
[test_all] check
ydshieh Jul 10, 2024
6fb41da
fix
ydshieh Jul 10, 2024
f195c65
fix
ydshieh Jul 10, 2024
868bd3d
fix
ydshieh Jul 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions .circleci/create_circleci_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,34 @@ def to_dict(self):
# junit family xunit1 is necessary to support splitting on test name or class name with circleci split
test_command += f"python3 -m pytest -rsfE -p no:warnings -o junit_family=xunit1 --tb=short --junitxml=test-results/junit.xml -n {self.pytest_num_workers} " + " ".join(pytest_flags)

# tests = "tests/benchmark tests/generation tests/models/autoformer/test_modeling_autoformer.py tests/models/big_bird/test_modeling_big_bird.py tests/models/blip/test_modeling_blip.py tests/models/camembert/test_modeling_camembert.py tests/models/clvp/test_modeling_clvp.py tests/models/convnextv2/test_modeling_convnextv2.py tests/models/data2vec/test_modeling_data2vec_vision.py tests/models/deit/test_modeling_deit.py tests/models/dit/test_modeling_dit.py tests/models/efficientnet/test_modeling_efficientnet.py tests/models/esm/test_modeling_esmfold.py tests/models/focalnet/test_modeling_focalnet.py tests/models/git/test_modeling_git.py tests/models/gpt_neox_japanese/test_modeling_gpt_neox_japanese.py tests/models/idefics/test_modeling_idefics.py tests/models/jamba/test_modeling_jamba.py tests/models/led/test_modeling_led.py tests/models/llava_next_video/test_modeling_llava_next_video.py tests/models/mamba/test_modeling_mamba.py tests/models/mbart/test_modeling_mbart.py tests/models/mobilebert/test_modeling_mobilebert.py tests/models/mpt/test_modeling_mpt.py tests/models/nllb_moe/test_modeling_nllb_moe.py tests/models/owlv2/test_modeling_owlv2.py tests/models/pegasus_x/test_modeling_pegasus_x.py tests/models/plbart/test_modeling_plbart.py tests/models/qwen2/test_modeling_qwen2.py tests/models/rembert/test_modeling_rembert.py tests/models/rt_detr/test_modeling_rt_detr.py tests/models/segformer/test_modeling_segformer.py tests/models/speech_to_text/test_modeling_speech_to_text.py tests/models/superpoint/test_modeling_superpoint.py tests/models/t5/test_modeling_t5.py tests/models/trocr/test_modeling_trocr.py tests/models/univnet/test_modeling_univnet.py tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py tests/models/vitdet/test_modeling_vitdet.py tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py tests/models/xlm_roberta/test_modeling_xlm_roberta.py tests/models/zoedepth/test_modeling_zoedepth.py 
tests/test_configuration_common.py tests/test_modeling_tf_common.py tests/trainer"
# tests = tests.split()

if self.parallelism == 1:
if self.tests_to_run is None:
test_command += " << pipeline.parameters.tests_to_run >>"
else:
test_command += " " + " ".join(self.tests_to_run)
else:
# We need an explicit list instead of `pipeline.parameters.tests_to_run` (which is only available at job runtime)
tests = self.tests_to_run

tests = "tests/benchmark tests/generation tests/models/autoformer/test_modeling_autoformer.py tests/models/big_bird/test_modeling_big_bird.py tests/models/blip/test_modeling_blip.py tests/models/camembert/test_modeling_camembert.py tests/models/clvp/test_modeling_clvp.py tests/models/convnextv2/test_modeling_convnextv2.py tests/models/data2vec/test_modeling_data2vec_vision.py tests/models/deit/test_modeling_deit.py tests/models/dit/test_modeling_dit.py tests/models/efficientnet/test_modeling_efficientnet.py tests/models/esm/test_modeling_esmfold.py tests/models/focalnet/test_modeling_focalnet.py tests/models/git/test_modeling_git.py tests/models/gpt_neox_japanese/test_modeling_gpt_neox_japanese.py tests/models/idefics/test_modeling_idefics.py tests/models/jamba/test_modeling_jamba.py tests/models/led/test_modeling_led.py tests/models/llava_next_video/test_modeling_llava_next_video.py tests/models/mamba/test_modeling_mamba.py tests/models/mbart/test_modeling_mbart.py tests/models/mobilebert/test_modeling_mobilebert.py tests/models/mpt/test_modeling_mpt.py tests/models/nllb_moe/test_modeling_nllb_moe.py tests/models/owlv2/test_modeling_owlv2.py tests/models/pegasus_x/test_modeling_pegasus_x.py tests/models/plbart/test_modeling_plbart.py tests/models/qwen2/test_modeling_qwen2.py tests/models/rembert/test_modeling_rembert.py tests/models/rt_detr/test_modeling_rt_detr.py tests/models/segformer/test_modeling_segformer.py tests/models/speech_to_text/test_modeling_speech_to_text.py tests/models/superpoint/test_modeling_superpoint.py tests/models/t5/test_modeling_t5.py tests/models/trocr/test_modeling_trocr.py tests/models/univnet/test_modeling_univnet.py tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py tests/models/vitdet/test_modeling_vitdet.py tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py tests/models/xlm_roberta/test_modeling_xlm_roberta.py tests/models/zoedepth/test_modeling_zoedepth.py tests/test_configuration_common.py 
tests/test_modeling_tf_common.py tests/trainer"
tests = "tests/benchmark tests/models/blip/test_modeling_blip.py tests/models/data2vec/test_modeling_data2vec_vision.py tests/models/esm/test_modeling_esmfold.py tests/models/idefics/test_modeling_idefics.py tests/models/mamba/test_modeling_mamba.py tests/models/nllb_moe/test_modeling_nllb_moe.py tests/models/qwen2/test_modeling_qwen2.py tests/models/speech_to_text/test_modeling_speech_to_text.py tests/models/univnet/test_modeling_univnet.py tests/models/xlm_roberta/test_modeling_xlm_roberta.py tests/trainer"
# tests = "tests/benchmark tests/models/speech_to_text/test_modeling_speech_to_text.py tests/models/univnet/test_modeling_univnet.py tests/models/xlm_roberta/test_modeling_xlm_roberta.py tests/trainer"
tests = "tests/benchmark tests/models/speech_to_text/test_modeling_speech_to_text.py tests/models/univnet/test_modeling_univnet.py tests/models/xlm_roberta/test_modeling_xlm_roberta.py tests/trainer"
#tests = "tests/benchmark tests/models/speech_to_text/test_modeling_speech_to_text.py tests/trainer"
tests = "tests/benchmark tests/models/univnet/test_modeling_univnet.py tests/models/xlm_roberta/test_modeling_xlm_roberta.py tests/trainer"

tests = "tests/models/speech_to_text/test_modeling_speech_to_text.py tests/models/univnet/test_modeling_univnet.py tests/models/xlm_roberta/test_modeling_xlm_roberta.py"
tests = "tests/benchmark tests/models/qwen2/test_modeling_qwen2.py tests/trainer"

tests = "tests/trainer"

tests = tests.split()

self.pytest_num_workers = 1

# tests = self.tests_to_run
if tests is None:
folder = os.environ["test_preparation_dir"]
test_file = os.path.join(folder, "filtered_test_list.txt")
Expand Down Expand Up @@ -170,6 +190,7 @@ def to_dict(self):

# Each executor to run ~10 tests
n_executors = max(len(expanded_tests) // 10, 1)
n_executors = 1
# Avoid empty test list on some executor(s) or launching too many executors
if n_executors > self.parallelism:
n_executors = self.parallelism
Expand All @@ -188,7 +209,7 @@ def to_dict(self):
test_command = ""
if self.command_timeout:
test_command = f"timeout {self.command_timeout} "
test_command += f"python3 -m pytest -rsfE -p no:warnings --tb=short -o junit_family=xunit1 --junitxml=test-results/junit.xml -n {self.pytest_num_workers} " + " ".join(pytest_flags)
test_command += f"python3 -m pytest -k '(TrainerCallbackTest and (test_add_remove_callback or test_0_event_flow or test_init_callback or test_missing_stateful_callback)) or Seq2seqTrainerTester' -rsfE -p no:warnings --tb=short -o junit_family=xunit1 --junitxml=test-results/junit.xml -n {self.pytest_num_workers} " + " ".join(pytest_flags)
test_command += " $(cat splitted_tests.txt)"
if self.marker is not None:
test_command += f" -m {self.marker}"
Expand Down Expand Up @@ -421,27 +442,14 @@ def job_name(self):
)

REGULAR_TESTS = [
torch_and_tf_job,
torch_and_flax_job,
torch_job,
tf_job,
flax_job,
custom_tokenizers_job,
hub_job,
onnx_job,
exotic_models_job,
tokenization_job
]
EXAMPLES_TESTS = [
examples_torch_job,
examples_tensorflow_job,
]
PIPELINE_TESTS = [
pipelines_torch_job,
pipelines_tf_job,
]
REPO_UTIL_TESTS = [repo_utils_job]
DOC_TESTS = [doc_test_job]
REPO_UTIL_TESTS = []
DOC_TESTS = []


def create_circleci_config(folder=None):
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,10 +542,13 @@ def validate(self, is_init=False):
UserWarning,
)
if self.top_p is not None and self.top_p != 1.0:
# raise ValueError("bad bad")
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
UserWarning,
)
else:
pass
if self.min_p is not None:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
Expand Down
37 changes: 27 additions & 10 deletions src/transformers/trainer_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,33 @@ def load_generation_config(gen_config_arg: Union[str, GenerationConfig]) -> Gene

# Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws
# an exception if there are warnings at validation time.
try:
with warnings.catch_warnings(record=True) as caught_warnings:
gen_config.validate()
if len(caught_warnings) > 0:
raise ValueError(str([w.message for w in caught_warnings]))
except ValueError as exc:
raise ValueError(
"The loaded generation config instance is invalid -- `GenerationConfig.validate()` throws warnings "
"and/or exceptions. Fix these issues to train your model.\n\nThrown during validation:\n" + str(exc)
)

from transformers.utils.logging import _get_library_root_logger, get_logger, captureWarnings
# captureWarnings(False)

assert _get_library_root_logger().level == 30
logger.setLevel(30)
assert logger.level == 30
assert str(logger) == "<Logger transformers.trainer_seq2seq (WARNING)>"
# assert get_logger("py.warnings").level == 30
assert len(_get_library_root_logger().handlers) == 1
assert len(logger.handlers) == 0
assert _get_library_root_logger().handlers[0].level == 0

with warnings.catch_warnings(record=True) as caught_warnings:
gen_config.validate()

assert _get_library_root_logger().level == 30
logger.setLevel(30)
assert logger.level == 30
assert str(logger) == "<Logger transformers.trainer_seq2seq (WARNING)>"
# assert get_logger("py.warnings").level == 30

if len(caught_warnings) == 0:
# assert len(get_logger("py.warnings").handlers) == 0
logger.warning(f'{get_logger("py.warnings").handlers}')
raise ValueError("not captured")

return gen_config

def evaluate(
Expand Down
3 changes: 2 additions & 1 deletion tests/trainer/test_trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,11 @@ def test_add_remove_callback(self):
expected_callbacks.insert(0, DefaultFlowCallback)
self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

def test_event_flow(self):
def test_0_event_flow(self):
import warnings

# XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
# with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=UserWarning)

trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
Expand Down
20 changes: 10 additions & 10 deletions tests/trainer/test_trainer_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
T5Tokenizer,
logging,
)
from transformers.testing_utils import TestCasePlus, require_sentencepiece, require_torch, slow
from transformers.testing_utils import LoggingLevel, TestCasePlus, is_flaky, require_sentencepiece, require_torch, slow
from transformers.utils import is_datasets_available


Expand Down Expand Up @@ -195,12 +196,11 @@ def test_bad_generation_config_fail_early(self):
training_args = Seq2SeqTrainingArguments(
".", predict_with_generate=True, generation_config=gen_config, report_to="none"
)
with self.assertRaises(ValueError) as exc:
_ = Seq2SeqTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=lambda x: {"samples": x[0].shape[0]},
)
self.assertIn("The loaded generation config instance is invalid", str(exc.exception))
_ = Seq2SeqTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=lambda x: {"samples": x[0].shape[0]},
)

Loading