diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index cdf6325c4b4ae6..e5ac449c6556c7 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -173,6 +173,37 @@ class OptimizerNames(ExplicitEnum): GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise" +# Sometimes users will pass in a `str` repr of a dict in the CLI +# We need to track what fields those can be. Each time a new arg +# has a dict type, it must be added to this list. +# Important: These should be typed with Optional[Union[dict,str,...]] +_VALID_DICT_FIELDS = [ + "accelerator_config", + "fsdp_config", + "deepspeed", + "gradient_checkpointing_kwargs", + "lr_scheduler_kwargs", +] + + +def _convert_str_dict(passed_value: dict): + "Safely checks that a passed value is a dictionary and converts any string values to their appropriate types." + for key, value in passed_value.items(): + if isinstance(value, dict): + passed_value[key] = _convert_str_dict(value) + elif isinstance(value, str): + # First check for bool and convert + if value.lower() in ("true", "false"): + passed_value[key] = value.lower() == "true" + # Check for digit + elif value.isdigit(): + passed_value[key] = int(value) + elif value.replace(".", "", 1).isdigit(): + passed_value[key] = float(value) + + return passed_value + + # TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903 @dataclass class TrainingArguments: @@ -803,11 +834,11 @@ class TrainingArguments: default="linear", metadata={"help": "The scheduler type to use."}, ) - lr_scheduler_kwargs: Optional[Dict] = field( + lr_scheduler_kwargs: Optional[Union[dict, str]] = field( default_factory=dict, metadata={ "help": ( - "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts" + "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts." ) }, ) @@ -1118,7 +1149,6 @@ class TrainingArguments: ) }, ) - # Do not touch this type annotation or it will stop working in CLI fsdp_config: Optional[Union[dict, str]] = field( default=None, metadata={ @@ -1137,8 +1167,7 @@ class TrainingArguments: ) }, ) - # Do not touch this type annotation or it will stop working in CLI - accelerator_config: Optional[str] = field( + accelerator_config: Optional[Union[dict, str]] = field( default=None, metadata={ "help": ( @@ -1147,8 +1176,7 @@ class TrainingArguments: ) }, ) - # Do not touch this type annotation or it will stop working in CLI - deepspeed: Optional[str] = field( + deepspeed: Optional[Union[dict, str]] = field( default=None, metadata={ "help": ( @@ -1252,7 +1280,7 @@ class TrainingArguments: "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." }, ) - gradient_checkpointing_kwargs: Optional[dict] = field( + gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field( default=None, metadata={ "help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`." @@ -1380,6 +1408,17 @@ class TrainingArguments: ) def __post_init__(self): + # Parse in args that could be `dict` sent in from the CLI as a string + for field in _VALID_DICT_FIELDS: + passed_value = getattr(self, field) + # We only want to do this if the str starts with a bracket to indiciate a `dict` + # else its likely a filename if supported + if isinstance(passed_value, str) and passed_value.startswith("{"): + loaded_dict = json.loads(passed_value) + # Convert str values to types if applicable + loaded_dict = _convert_str_dict(loaded_dict) + setattr(self, field, loaded_dict) + # expand paths, if not os.makedirs("~/bar") will make directory # in the current directory instead of the actual home # see https://github.com/huggingface/transformers/issues/10628 diff --git a/tests/utils/test_hf_argparser.py b/tests/utils/test_hf_argparser.py index c0fa748cbfa439..87d1858cc6b7a7 100644 --- a/tests/utils/test_hf_argparser.py +++ b/tests/utils/test_hf_argparser.py @@ -22,12 +22,14 @@ from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import List, Literal, Optional +from typing import Dict, List, Literal, Optional, Union, get_args, get_origin import yaml from transformers import HfArgumentParser, TrainingArguments from transformers.hf_argparser import make_choice_type_function, string_to_bool +from transformers.testing_utils import require_torch +from transformers.training_args import _VALID_DICT_FIELDS # Since Python 3.10, we can use the builtin `|` operator for Union types @@ -405,3 +407,68 @@ def test_parse_yaml(self): def test_integration_training_args(self): parser = HfArgumentParser(TrainingArguments) self.assertIsNotNone(parser) + + def test_valid_dict_annotation(self): + """ + Tests to make sure that `dict` based annotations + are correctly made in the `TrainingArguments`. + + If this fails, a type annotation change is + needed on a new input + """ + base_list = _VALID_DICT_FIELDS.copy() + args = TrainingArguments + + # First find any annotations that contain `dict` + fields = args.__dataclass_fields__ + + raw_dict_fields = [] + optional_dict_fields = [] + + for field in fields.values(): + # First verify raw dict + if field.type in (dict, Dict): + raw_dict_fields.append(field) + # Next check for `Union` or `Optional` + elif get_origin(field.type) == Union: + if any(arg in (dict, Dict) for arg in get_args(field.type)): + optional_dict_fields.append(field) + + # First check: anything in `raw_dict_fields` is very bad + self.assertEqual( + len(raw_dict_fields), + 0, + "Found invalid raw `dict` types in the `TrainingArgument` typings. " + "This leads to issues with the CLI. Please turn this into `typing.Optional[dict,str]`", + ) + + # Next check raw annotations + for field in optional_dict_fields: + args = get_args(field.type) + # These should be returned as `dict`, `str`, ... + # we only care about the first two + self.assertIn(args[0], (Dict, dict)) + self.assertEqual( + str(args[1]), + "", + f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, " + "but `str` not found. Please fix this.", + ) + + # Second check: anything in `optional_dict_fields` is bad if it's not in `base_list` + for field in optional_dict_fields: + self.assertIn( + field.name, + base_list, + f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`", + ) + + @require_torch + def test_valid_dict_input_parsing(self): + with tempfile.TemporaryDirectory() as tmp_dir: + args = TrainingArguments( + output_dir=tmp_dir, + accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}', + ) + self.assertEqual(args.accelerator_config.split_batches, True) + self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)