Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for str versions of dicts based on typing #30227

Merged
merged 10 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 47 additions & 8 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,37 @@ class OptimizerNames(ExplicitEnum):
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"


# Sometimes users will pass in a `str` repr of a dict in the CLI
# We need to track what fields those can be. Each time a new arg
# has a dict type, it must be added to this list.
# Important: These should be typed with Optional[Union[dict,str,...]]
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
_VALID_DICT_FIELDS = [
"accelerator_config",
"fsdp_config",
"deepspeed",
"gradient_checkpointing_kwargs",
"lr_scheduler_kwargs",
]


def _convert_str_dict(passed_value: dict):
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
for key, value in passed_value.items():
if isinstance(value, dict):
passed_value[key] = _convert_str_dict(value)
elif isinstance(value, str):
# First check for bool and convert
if value.lower() in ("true", "false"):
passed_value[key] = value.lower() == "true"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice :)

# Check for digit
elif value.isdigit():
passed_value[key] = int(value)
elif value.replace(".", "", 1).isdigit():
passed_value[key] = float(value)

return passed_value


# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
@dataclass
class TrainingArguments:
Expand Down Expand Up @@ -803,11 +834,11 @@ class TrainingArguments:
default="linear",
metadata={"help": "The scheduler type to use."},
)
lr_scheduler_kwargs: Optional[Dict] = field(
lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
default_factory=dict,
metadata={
"help": (
"Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts"
"Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts."
)
},
)
Expand Down Expand Up @@ -1118,7 +1149,6 @@ class TrainingArguments:
)
},
)
# Do not touch this type annotation or it will stop working in CLI
fsdp_config: Optional[Union[dict, str]] = field(
default=None,
metadata={
Expand All @@ -1137,8 +1167,7 @@ class TrainingArguments:
)
},
)
# Do not touch this type annotation or it will stop working in CLI
accelerator_config: Optional[str] = field(
accelerator_config: Optional[Union[dict, str]] = field(
default=None,
metadata={
"help": (
Expand All @@ -1147,8 +1176,7 @@ class TrainingArguments:
)
},
)
# Do not touch this type annotation or it will stop working in CLI
deepspeed: Optional[str] = field(
deepspeed: Optional[Union[dict, str]] = field(
default=None,
metadata={
"help": (
Expand Down Expand Up @@ -1252,7 +1280,7 @@ class TrainingArguments:
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
gradient_checkpointing_kwargs: Optional[dict] = field(
gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
default=None,
metadata={
"help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
Expand Down Expand Up @@ -1380,6 +1408,17 @@ class TrainingArguments:
)

def __post_init__(self):
# Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS:
passed_value = getattr(self, field)
# We only want to do this if the str starts with a bracket to indiciate a `dict`
# else its likely a filename if supported
if isinstance(passed_value, str) and passed_value.startswith("{"):
loaded_dict = json.loads(passed_value)
# Convert str values to types if applicable
loaded_dict = _convert_str_dict(loaded_dict)
setattr(self, field, loaded_dict)

# expand paths, if not os.makedirs("~/bar") will make directory
# in the current directory instead of the actual home
# see https://github.com/huggingface/transformers/issues/10628
Expand Down
69 changes: 68 additions & 1 deletion tests/utils/test_hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import List, Literal, Optional
from typing import Dict, List, Literal, Optional, Union, get_args, get_origin

import yaml

from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import make_choice_type_function, string_to_bool
from transformers.testing_utils import require_torch
from transformers.training_args import _VALID_DICT_FIELDS


# Since Python 3.10, we can use the builtin `|` operator for Union types
Expand Down Expand Up @@ -405,3 +407,68 @@ def test_parse_yaml(self):
def test_integration_training_args(self):
parser = HfArgumentParser(TrainingArguments)
self.assertIsNotNone(parser)

def test_valid_dict_annotation(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know on offline discussion I said we probably don't need a test to check the parsing. Seeing the implementation, i.e. TrainingArguments can be created with a field as a string representation of a dict, and not having to include the CLI, I think we can add a simple for at least one of the fields in VALID_DICT_FIELDS

e.g. something along the lines of:

def test_valid_dict_input_parsing(self):
    args = TrainingArguments(
        field_name='{"key": value}'
    )
    # Or however it's assigned in the args
    self.assertEqual(args.field_name, {key: value})

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. This also made me notice that we cast int and bools as str still, so added a helper for this (and does so to avoid literal_eval, which can be exploited)

"""
Tests to make sure that `dict` based annotations
are correctly made in the `TrainingArguments`.

If this fails, a type annotation change is
needed on a new input
"""
base_list = _VALID_DICT_FIELDS.copy()
args = TrainingArguments

# First find any annotations that contain `dict`
fields = args.__dataclass_fields__

raw_dict_fields = []
optional_dict_fields = []

for field in fields.values():
# First verify raw dict
if field.type in (dict, Dict):
raw_dict_fields.append(field)
# Next check for `Union` or `Optional`
elif get_origin(field.type) == Union:
if any(arg in (dict, Dict) for arg in get_args(field.type)):
optional_dict_fields.append(field)

# First check: anything in `raw_dict_fields` is very bad
self.assertEqual(
len(raw_dict_fields),
0,
"Found invalid raw `dict` types in the `TrainingArgument` typings. "
"This leads to issues with the CLI. Please turn this into `typing.Optional[dict,str]`",
)

# Next check raw annotations
for field in optional_dict_fields:
args = get_args(field.type)
# These should be returned as `dict`, `str`, ...
# we only care about the first two
self.assertIn(args[0], (Dict, dict))
self.assertEqual(
str(args[1]),
"<class 'str'>",
f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, "
"but `str` not found. Please fix this.",
)

# Second check: anything in `optional_dict_fields` is bad if it's not in `base_list`
for field in optional_dict_fields:
self.assertIn(
field.name,
base_list,
f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`",
)

@require_torch
def test_valid_dict_input_parsing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}',
)
self.assertEqual(args.accelerator_config.split_batches, True)
self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Loading