
Allow for str versions of dicts based on typing (#30227)
* Bookmark, initial implementation. Need to test

* Clean

* Working fully, woop woop

* I think working version now, testing

* Fin!

* rm cast, could keep None

* Fix typing issue

* rm typehint

* Add test

* Add tests and make more rigid
muellerzr authored Apr 16, 2024
1 parent b86d0f4 commit 487505f
Showing 2 changed files with 115 additions and 9 deletions.
55 changes: 47 additions & 8 deletions src/transformers/training_args.py
@@ -173,6 +173,37 @@ class OptimizerNames(ExplicitEnum):
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"


# Sometimes users will pass in a `str` repr of a dict in the CLI
# We need to track what fields those can be. Each time a new arg
# has a dict type, it must be added to this list.
# Important: These should be typed with Optional[Union[dict,str,...]]
_VALID_DICT_FIELDS = [
"accelerator_config",
"fsdp_config",
"deepspeed",
"gradient_checkpointing_kwargs",
"lr_scheduler_kwargs",
]


def _convert_str_dict(passed_value: dict):
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
for key, value in passed_value.items():
if isinstance(value, dict):
passed_value[key] = _convert_str_dict(value)
elif isinstance(value, str):
# First check for bool and convert
if value.lower() in ("true", "false"):
passed_value[key] = value.lower() == "true"
# Check for digit
elif value.isdigit():
passed_value[key] = int(value)
elif value.replace(".", "", 1).isdigit():
passed_value[key] = float(value)

return passed_value
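
As a quick illustrative sketch of the coercion this helper performs (a hedged example, assuming the private helper is imported from `transformers.training_args`):

from transformers.training_args import _convert_str_dict

# Nested dicts are handled recursively; "true"/"false" become bools,
# digit strings become ints, and plain decimals become floats.
raw = {"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": "2", "scale": "0.5"}}
print(_convert_str_dict(raw))
# {'split_batches': True, 'gradient_accumulation_kwargs': {'num_steps': 2, 'scale': 0.5}}

Note that only simple numeric forms are coerced: a value like "1e-3" or "-2" fails both digit checks and is left as a string.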


# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
@dataclass
class TrainingArguments:
@@ -803,11 +834,11 @@ class TrainingArguments:
default="linear",
metadata={"help": "The scheduler type to use."},
)
-lr_scheduler_kwargs: Optional[Dict] = field(
+lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
default_factory=dict,
metadata={
"help": (
"Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts"
"Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts."
)
},
)
@@ -1118,7 +1149,6 @@ class TrainingArguments:
)
},
)
-# Do not touch this type annotation or it will stop working in CLI
fsdp_config: Optional[Union[dict, str]] = field(
default=None,
metadata={
@@ -1137,8 +1167,7 @@
)
},
)
-# Do not touch this type annotation or it will stop working in CLI
-accelerator_config: Optional[str] = field(
+accelerator_config: Optional[Union[dict, str]] = field(
default=None,
metadata={
"help": (
@@ -1147,8 +1176,7 @@
)
},
)
-# Do not touch this type annotation or it will stop working in CLI
-deepspeed: Optional[str] = field(
+deepspeed: Optional[Union[dict, str]] = field(
default=None,
metadata={
"help": (
@@ -1252,7 +1280,7 @@
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
-gradient_checkpointing_kwargs: Optional[dict] = field(
+gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
default=None,
metadata={
"help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
@@ -1380,6 +1408,17 @@
)

def __post_init__(self):
# Parse any args that may have been passed from the CLI as the string repr of a `dict`
for field in _VALID_DICT_FIELDS:
passed_value = getattr(self, field)
# We only want to do this if the str starts with a bracket to indicate a `dict`
# else it's likely a filename if supported
if isinstance(passed_value, str) and passed_value.startswith("{"):
loaded_dict = json.loads(passed_value)
# Convert str values to types if applicable
loaded_dict = _convert_str_dict(loaded_dict)
setattr(self, field, loaded_dict)

# expand paths; otherwise os.makedirs("~/bar") will make the directory
# in the current working directory instead of the actual home
# see https://github.com/huggingface/transformers/issues/10628
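Taken together: any field named in `_VALID_DICT_FIELDS` can now be passed as a JSON-style string, and `__post_init__` parses and coerces it. A minimal sketch of the effect (assuming `torch` is installed, since `TrainingArguments` requires it; the output directory is illustrative):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",  # illustrative path
    lr_scheduler_kwargs='{"num_cycles": "3"}',  # str repr of a dict, as the CLI would pass it
)
print(args.lr_scheduler_kwargs)  # {'num_cycles': 3} -- json.loads, then _convert_str_dict
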
69 changes: 68 additions & 1 deletion tests/utils/test_hf_argparser.py
@@ -22,12 +22,14 @@
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
-from typing import List, Literal, Optional
+from typing import Dict, List, Literal, Optional, Union, get_args, get_origin

import yaml

from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import make_choice_type_function, string_to_bool
from transformers.testing_utils import require_torch
from transformers.training_args import _VALID_DICT_FIELDS


# Since Python 3.10, we can use the builtin `|` operator for Union types
@@ -405,3 +407,68 @@ def test_parse_yaml(self):
def test_integration_training_args(self):
parser = HfArgumentParser(TrainingArguments)
self.assertIsNotNone(parser)

def test_valid_dict_annotation(self):
"""
Tests to make sure that `dict` based annotations
are correctly made in the `TrainingArguments`.
If this fails, a type annotation change is
needed on a new input
"""
base_list = _VALID_DICT_FIELDS.copy()
args = TrainingArguments

# First find any annotations that contain `dict`
fields = args.__dataclass_fields__

raw_dict_fields = []
optional_dict_fields = []

for field in fields.values():
# First verify raw dict
if field.type in (dict, Dict):
raw_dict_fields.append(field)
# Next check for `Union` or `Optional`
elif get_origin(field.type) == Union:
if any(arg in (dict, Dict) for arg in get_args(field.type)):
optional_dict_fields.append(field)

# First check: anything in `raw_dict_fields` is very bad
self.assertEqual(
len(raw_dict_fields),
0,
"Found invalid raw `dict` types in the `TrainingArgument` typings. "
"This leads to issues with the CLI. Please turn this into `typing.Optional[dict,str]`",
)

# Next check the Union arguments of each annotation
for field in optional_dict_fields:
args = get_args(field.type)
# These should be returned as `dict`, `str`, ...
# we only care about the first two
self.assertIn(args[0], (Dict, dict))
self.assertEqual(
str(args[1]),
"<class 'str'>",
f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, "
"but `str` not found. Please fix this.",
)

# Second check: anything in `optional_dict_fields` is bad if it's not in `base_list`
for field in optional_dict_fields:
self.assertIn(
field.name,
base_list,
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`",
)
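
The introspection above relies on `typing` flattening nested unions: `Optional[Union[dict, str]]` is just `Union[dict, str, None]`, so `get_origin` reports `Union` and `get_args` yields the flattened members. A quick sanity sketch:

from typing import Optional, Union, get_args, get_origin

annotation = Optional[Union[dict, str]]
print(get_origin(annotation) is Union)  # True
print(get_args(annotation))  # (<class 'dict'>, <class 'str'>, <class 'NoneType'>)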

@require_torch
def test_valid_dict_input_parsing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}',
)
self.assertEqual(args.accelerator_config.split_batches, True)
self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
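
The same round trip works through `HfArgumentParser`, which is the code path a real CLI invocation exercises; a hedged sketch under the same assumptions (`torch` and `accelerate` installed, argv values illustrative):

from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--accelerator_config", '{"split_batches": "true"}']
)
print(args.accelerator_config.split_batches)  # True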
