diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index c4577d14f6d89b..e203d6a3ff642b 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -175,7 +175,8 @@ class OptimizerNames(ExplicitEnum): # Sometimes users will pass in a `str` repr of a dict in the CLI # We need to track what fields those can be. Each time a new arg -# has a dict type, it must be added to this list +# has a dict type, it must be added to this list. +# Important: These should be typed with Optional[Union[dict,str,...]] VALID_DICT_FIELDS = [ "accelerator_config", "fsdp_config", @@ -1148,7 +1149,7 @@ class TrainingArguments: ) }, ) - accelerator_config: Optional[Union[AcceleratorConfig, dict, str]] = field( + accelerator_config: Optional[Union[dict, str, AcceleratorConfig]] = field( default=None, metadata={ "help": (