diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index e5ac449c6556c7..8c226459fc3db5 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1815,6 +1815,13 @@ def __post_init__(self):
             self.accelerator_config = AcceleratorConfig()
         elif isinstance(self.accelerator_config, dict):
             self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
+        # Check that a user didn't pass in the class instantiator
+        # such as `accelerator_config = AcceleratorConfig`
+        elif isinstance(self.accelerator_config, type):
+            raise NotImplementedError(
+                "Tried passing in a callable to `accelerator_config`, but this is not supported. "
+                "Please pass in a fully constructed `AcceleratorConfig` object instead."
+            )
         else:
             self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
         if self.dispatch_batches is not None:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 483022a09e6fab..6eede8b447cdcd 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -3104,6 +3104,35 @@ def test_accelerator_config_from_dict_grad_accum_num_steps(self):
                 trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
             self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))

+    def test_accelerator_config_not_instantiated(self):
+        # Checks that a helpful error is raised when the config is passed in
+        # as the class itself rather than an instantiated object
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self.assertRaises(NotImplementedError) as context:
+                _ = RegressionTrainingArguments(
+                    output_dir=tmp_dir,
+                    accelerator_config=AcceleratorConfig,
+                )
+            self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
+
+        # Now test with a custom subclass
+        @dataclasses.dataclass
+        class CustomAcceleratorConfig(AcceleratorConfig):
+            pass
+
+        @dataclasses.dataclass
+        class CustomTrainingArguments(TrainingArguments):
+            accelerator_config: dict = dataclasses.field(
+                default=CustomAcceleratorConfig,
+            )
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self.assertRaises(NotImplementedError) as context:
+                _ = CustomTrainingArguments(
+                    output_dir=tmp_dir,
+                )
+            self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
+

 @require_torch
 @is_staging_test
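For context, a minimal sketch of the misuse this change guards against, alongside the intended usage. The `transformers.trainer_pt_utils` import path and the `split_batches` field are assumptions about the library layout, not part of this diff:

from transformers import TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig  # assumed import path

# Incorrect: passing the class itself. With this change, TrainingArguments.__post_init__
# raises NotImplementedError up front instead of failing later in a confusing way.
try:
    TrainingArguments(output_dir="out", accelerator_config=AcceleratorConfig)
except NotImplementedError as err:
    print(err)

# Correct: pass a fully constructed instance (a dict or a path to a JSON file also works).
args = TrainingArguments(output_dir="out", accelerator_config=AcceleratorConfig(split_batches=True))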