Skip to content

Commit

Permalink
Raise relevent err when wrong type is passed in as the accelerator_co…
Browse files Browse the repository at this point in the history
…nfig (#29997)

* Raise relevent err

* Use type instead
  • Loading branch information
muellerzr authored Apr 16, 2024
1 parent 0eaef0c commit e27d930
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1815,6 +1815,13 @@ def __post_init__(self):
self.accelerator_config = AcceleratorConfig()
elif isinstance(self.accelerator_config, dict):
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
# Check that a user didn't pass in the class instantiator
# such as `accelerator_config = AcceleratorConfig`
elif isinstance(self.accelerator_config, type):
raise NotImplementedError(
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
"Please pass in a fully constructed `AcceleratorConfig` object instead."
)
else:
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
if self.dispatch_batches is not None:
Expand Down
29 changes: 29 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3104,6 +3104,35 @@ def test_accelerator_config_from_dict_grad_accum_num_steps(self):
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))

def test_accelerator_config_not_instantiated(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(NotImplementedError) as context:
_ = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config=AcceleratorConfig,
)
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))

# Now test with a custom subclass
@dataclasses.dataclass
class CustomAcceleratorConfig(AcceleratorConfig):
pass

@dataclasses.dataclass
class CustomTrainingArguments(TrainingArguments):
accelerator_config: dict = dataclasses.field(
default=CustomAcceleratorConfig,
)

with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(NotImplementedError) as context:
_ = CustomTrainingArguments(
output_dir=tmp_dir,
)
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))


@require_torch
@is_staging_test
Expand Down

0 comments on commit e27d930

Please sign in to comment.