Skip to content

Commit

Permalink
I think working version now, testing
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Apr 12, 2024
1 parent a8e132c commit 5d9a39a
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ class TrainingArguments:
},
)
fsdp_config: Optional[Union[dict, str]] = field(
default_factory=dict,
default=None,
metadata={
"help": (
"Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a "
Expand All @@ -1149,7 +1149,7 @@ class TrainingArguments:
},
)
accelerator_config: Optional[Union[AcceleratorConfig, dict, str]] = field(
default_factory=dict,
default=None,
metadata={
"help": (
"Config to be used with the internal Accelerator object initializtion. The value is either a "
Expand All @@ -1158,7 +1158,7 @@ class TrainingArguments:
},
)
deepspeed: Optional[Union[dict, str]] = field(
default_factory=dict,
default=None,
metadata={
"help": (
"Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already"
Expand Down Expand Up @@ -1262,7 +1262,7 @@ class TrainingArguments:
},
)
gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
default_factory=dict,
default=None,
metadata={
"help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
},
Expand Down Expand Up @@ -1391,10 +1391,15 @@ class TrainingArguments:
def __post_init__(self):
# Parse in args that could be `dict` sent in from the CLI as a string
for field in VALID_DICT_FIELDS:
passed_value = getattr(self, field)
# We only want to do this if the str starts with a bracket to indiciate a `dict`
# else its likely a filename if supported
if isinstance(getattr(self, field), str) and getattr(self, field).startswith("{"):
setattr(self, field, json.loads(getattr(self, field)))
if isinstance(passed_value, str) and passed_value.startswith("{"):
setattr(self, field, json.loads(passed_value))
# Since we default to a blank dict, set it to `None` after parsing
elif isinstance(passed_value, dict):
if passed_value == {}:
setattr(self, field, None)

# expand paths, if not os.makedirs("~/bar") will make directory
# in the current directory instead of the actual home
Expand Down

0 comments on commit 5d9a39a

Please sign in to comment.