📑 Refactor TrlParser #2412

Merged: 17 commits merged on Dec 2, 2024. Showing changes from 7 commits.
3 changes: 3 additions & 0 deletions docs/source/script_utils.md
@@ -7,3 +7,6 @@
## TrlParser

[[autodoc]] TrlParser
    - parse_args_and_config
    - parse_args_into_dataclasses
    - set_defaults_with_config
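For readers of the docs page, a minimal usage sketch of the three documented methods; the `ScriptConfig` dataclass and its field names below are hypothetical, chosen only for illustration. Passing `--config path/to/config.yaml` on the command line routes the YAML values through `set_defaults_with_config` before the dataclasses are built.

```python
# Illustrative only: ScriptConfig and its fields are made up for this sketch.
from dataclasses import dataclass

from trl import TrlParser


@dataclass
class ScriptConfig:
    learning_rate: float = 1e-4
    output_dir: str = "output"


parser = TrlParser(dataclass_types=[ScriptConfig])

# Override defaults programmatically; the overridden fields also stop being required.
parser.set_defaults_with_config(learning_rate=3e-5)

# Parses the CLI (and an optional `--config config.yaml`), then delegates to
# parse_args_into_dataclasses inherited from HfArgumentParser.
(script_config,) = parser.parse_args_and_config()
print(script_config)
```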
131 changes: 131 additions & 0 deletions tests/test_cli_utils.py
@@ -0,0 +1,131 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from dataclasses import dataclass
from unittest.mock import mock_open, patch

from trl import TrlParser


@dataclass
class MyDataclass:
    arg1: int
    arg2: str = "default"


@dataclass
class InvalidDataclass:
    config: str  # This should raise an error in the TrlParser


class TestTrlParser(unittest.TestCase):
    def test_init_without_config_field(self):
        """Test initialization without 'config' field in the dataclasses."""
        parser = TrlParser(dataclass_types=[MyDataclass])
        self.assertIsInstance(parser, TrlParser)

    def test_init_with_config_field(self):
        """Test initialization with a 'config' field in the dataclass (should raise ValueError)."""
        with self.assertRaises(ValueError) as context:
            TrlParser(dataclass_types=[InvalidDataclass])
        self.assertTrue("has a field named 'config'" in str(context.exception))

    @patch("builtins.open", mock_open(read_data="env:\n VAR1: value1\n VAR2: value2"))
    @patch("yaml.safe_load")
    @patch("os.environ", new_callable=dict)  # Mock os.environ as a dictionary
    def test_parse_args_and_config_with_valid_config(self, mock_environ, mock_yaml_load):
"""Test parse_args_and_config method with valid arguments and config."""
mock_yaml_load.return_value = {"env": {"VAR1": "value1", "VAR2": "value2"}, "arg1": 2}

parser = TrlParser(dataclass_types=[MyDataclass])

args = ["--arg2", "value", "--config", "config.yaml"] # don't set arg1 to test default value

# Simulate the config being loaded and environment variables being set
result_args = parser.parse_args_and_config(args)

# Set the environment variables using the mock
mock_environ["VAR1"] = "value1"
mock_environ["VAR2"] = "value2"

# Ensure that the environment variables were set correctly
self.assertEqual(mock_environ.get("VAR1"), "value1")
self.assertEqual(mock_environ.get("VAR2"), "value2")

# Check the parsed arguments
self.assertEqual(len(result_args), 1)
self.assertIsInstance(result_args[0], MyDataclass)
self.assertEqual(result_args[0].arg1, 2)
self.assertEqual(result_args[0].arg2, "value")

@patch("builtins.open", mock_open(read_data="env: not_a_dict"))
@patch("yaml.safe_load")
def test_parse_args_and_config_with_invalid_env(self, mock_yaml_load):
"""Test parse_args_and_config method when the 'env' field is not a dictionary."""
mock_yaml_load.return_value = {"env": "not_a_dict"}

parser = TrlParser(dataclass_types=[MyDataclass])

args = ["--arg1", "2", "--arg2", "value", "--config", "config.yaml"]

with self.assertRaises(ValueError) as context:
parser.parse_args_and_config(args)

self.assertEqual(str(context.exception), "`env` field should be a dict in the YAML file.")

@patch("builtins.open", mock_open(read_data=""))
@patch("yaml.safe_load")
def test_parse_args_and_config_without_config(self, mock_yaml_load):
"""Test parse_args_and_config without the `--config` argument."""
parser = TrlParser(dataclass_types=[MyDataclass])

args = ["--arg1", "2", "--arg2", "value"]

# Simulate no config, just parse args normally
result_args = parser.parse_args_and_config(args)

# Check that the arguments are parsed as is
self.assertEqual(len(result_args), 1)
self.assertIsInstance(result_args[0], MyDataclass)
self.assertEqual(result_args[0].arg1, 2)
self.assertEqual(result_args[0].arg2, "value")

def test_set_defaults_with_config(self):
"""Test set_defaults_with_config updates the defaults."""
parser = TrlParser(dataclass_types=[MyDataclass])

# Update defaults
parser.set_defaults_with_config(arg1=42)

# Ensure the default value is updated
result_args = parser.parse_args_and_config([])
self.assertEqual(len(result_args), 1)
self.assertIsInstance(result_args[0], MyDataclass)
self.assertEqual(result_args[0].arg1, 42)

def test_parse_args_and_config_with_remaining_strings(self):
parser = TrlParser(dataclass_types=[MyDataclass])

args = ["--arg1", "2", "--arg2", "value", "remaining"]

# Simulate no config, just parse args normally
result_args = parser.parse_args_and_config(args, return_remaining_strings=True)

# Check that the arguments are parsed as is
self.assertEqual(len(result_args), 2)
self.assertIsInstance(result_args[0], MyDataclass)
self.assertEqual(result_args[0].arg1, 2)
self.assertEqual(result_args[0].arg2, "value")
self.assertEqual(result_args[1], ["remaining"])
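For reference, a rough unmocked sketch of what `test_parse_args_and_config_with_valid_config` simulates. The test itself never touches the filesystem (it patches `open`, `yaml.safe_load`, and `os.environ`), so the file written below exists only for illustration.

```python
# Rough unmocked equivalent of the first config test above (illustrative only).
import os
from dataclasses import dataclass

from trl import TrlParser


@dataclass
class MyDataclass:  # same shape as the dataclass defined at the top of the test file
    arg1: int
    arg2: str = "default"


# Write the YAML that the test's mocks pretend to load.
with open("config.yaml", "w") as f:
    f.write("env:\n    VAR1: value1\n    VAR2: value2\narg1: 2\n")

parser = TrlParser(dataclass_types=[MyDataclass])
(result,) = parser.parse_args_and_config(["--arg2", "value", "--config", "config.yaml"])

assert result.arg1 == 2 and result.arg2 == "value"  # arg1 comes from the config default
assert os.environ["VAR1"] == "value1" and os.environ["VAR2"] == "value2"
```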
177 changes: 113 additions & 64 deletions trl/commands/cli_utils.py
@@ -19,18 +19,29 @@
import os
import subprocess
import sys
from argparse import Namespace
import warnings
from dataclasses import dataclass, field
from typing import Iterable, Optional, Union

import yaml
from transformers import HfArgumentParser
from transformers.hf_argparser import DataClass, DataClassType


logger = logging.getLogger(__name__)


class YamlConfigParser:
    def parse_and_set_env(self, config_path):
        """ """

    def __init__(self) -> None:
        warnings.warn(
            "The `YamlConfigParser` class is deprecated and will be removed in version 0.14. "
            "If you need to use this class, please copy the code to your own project.",
            DeprecationWarning,
        )

    def parse_and_set_env(self, config_path: str) -> dict:
        with open(config_path) as yaml_file:
            config = yaml.safe_load(yaml_file)

@@ -152,88 +163,126 @@ class ChatArguments:

class TrlParser(HfArgumentParser):
"""
The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config
parsers for users that pass a valid `config` field and merge the values that are set in the config
with the processed parsers.
A subclass of [`transformers.HfArgumentParser`] designed for parsing command-line arguments with dataclass-backed
configurations, while also supporting configuration file loading and environment variable management.

Args:
parsers (`List[argparse.ArgumentParser]`):
List of parsers.
ignore_extra_args (`bool`):
Whether to ignore extra arguments passed by the config
and not raise errors.
dataclass_types (`Union[DataClassType, Iterable[DataClassType]]`):
The dataclass types to use for argument parsing.
**kwargs:
Additional keyword arguments passed to the [`transformers.HfArgumentParser`] constructor.

Examples:

```yaml
# config.yaml
env:
VAR1: value1
arg1: 23
```

```python
# main.py
import os
from dataclasses import dataclass
from trl import TrlParser

@dataclass
class MyArguments:
arg1: int
arg2: str = "alpha"

parser = TrlParser(dataclass_types=[MyArguments])
training_args = parser.parse_args_and_config()

print(training_args, os.environ.get("VAR1"))
```

```bash
$ python main.py --config config.yaml
(MyArguments(arg1=23, arg2='alpha'),) value1

$ python main.py --arg1 5 --arg2 beta
(MyArguments(arg1=5, arg2='beta'),) None
```
"""

    def __init__(self, parsers, ignore_extra_args=False):
        super().__init__(parsers)
        self.yaml_parser = YamlConfigParser()
        self.ignore_extra_args = ignore_extra_args
    def __init__(
        self,
        dataclass_types: Union[DataClassType, Iterable[DataClassType]],
        ignore_extra_args: Optional[bool] = None,
        **kwargs,
    ):
        super().__init__(dataclass_types=dataclass_types, **kwargs)
        if ignore_extra_args is not None:
            warnings.warn(
                "The `ignore_extra_args` parameter is deprecated and will be removed in version 0.14. "
"It is no longer functional and can be safely removed from your code.",
DeprecationWarning,
)

# Check that none of the dataclasses have the "config" field
for dataclass_type in dataclass_types:
if "config" in dataclass_type.__dataclass_fields__:
raise ValueError(
f"Dataclass {dataclass_type.__name__} has a field named 'config'. This field is reserved for the "
f"config file path and should not be used in the dataclass."
)

    def post_process_dataclasses(self, dataclasses):
        """
        Post process dataclasses to merge the TrainingArguments with the SFTScriptArguments or DPOScriptArguments.
        """

        training_args = trl_args = None
        training_args_index = None

        for i, dataclass_obj in enumerate(dataclasses):
            if dataclass_obj.__class__.__name__ == "TrainingArguments":
                training_args = dataclass_obj
                training_args_index = i
            elif dataclass_obj.__class__.__name__ in ("SFTScriptArguments", "DPOScriptArguments"):
                trl_args = dataclass_obj
            else:
                ...

        if trl_args is not None and training_args is not None:
            training_args.gradient_checkpointing_kwargs = dict(
                use_reentrant=trl_args.gradient_checkpointing_use_reentrant
            )
            dataclasses[training_args_index] = training_args

        warnings.warn(
            "The `post_process_dataclasses` method is deprecated and will be removed in version 0.14. "
            "It is no longer functional and can be safely removed from your code.",
            DeprecationWarning,
        )
        return dataclasses

    def parse_args_and_config(self, return_remaining_strings=False):
    def parse_args_and_config(
        self, args: Optional[Iterable[str]] = None, return_remaining_strings: bool = False
    ) -> tuple[DataClass, ...]:
        """
        Parse the command line arguments and the config file.
        """
        yaml_config = None
        if "--config" in sys.argv:
            config_index = sys.argv.index("--config")

            _ = sys.argv.pop(config_index)  # --config
            config_path = sys.argv.pop(config_index)  # path to config
            yaml_config = self.yaml_parser.parse_and_set_env(config_path)

            self.set_defaults_with_config(**yaml_config)

        outputs = self.parse_args_into_dataclasses(return_remaining_strings=return_remaining_strings)
        This method is a wrapper around the `parse_args_into_dataclasses` method that also parses the config file
        specified in the command line arguments with the `--config` flag. The config file should be a YAML file with
        the arguments to be parsed. The method will set the environment variables specified in the `env` field of the
        config file and then parse the arguments from the config file and the command line.
        """
        args = list(args) if args is not None else sys.argv[1:]
        if "--config" in args:
            # Get the config file path from the command-line arguments
            config_index = args.index("--config")
            args.pop(config_index)  # remove the --config flag
            config_path = args.pop(config_index)  # get the path to the config file
            with open(config_path) as yaml_file:
                config = yaml.safe_load(yaml_file)

            # Set the environment variables specified in the config file
            if "env" in config:
                env_vars = config.pop("env", {})
                if not isinstance(env_vars, dict):
                    raise ValueError("`env` field should be a dict in the YAML file.")
                for key, value in env_vars.items():
                    os.environ[key] = str(value)

        if yaml_config is None:
            return outputs
            # Set the defaults from the config values
            self.set_defaults_with_config(**config)

        if return_remaining_strings:
            # if we have extra yaml config and command line strings
            # outputs[-1] is remaining command line strings
            # outputs[-2] is remaining yaml config as Namespace
            # combine them into remaining strings object
            remaining_strings = outputs[-1] + [f"{key}: {value}" for key, value in vars(outputs[-2]).items()]
            return outputs[:-2], remaining_strings
        else:
            # outputs[-1] is either remaining yaml config as Namespace or parsed config as Dataclass
            if isinstance(outputs[-1], Namespace) and not self.ignore_extra_args:
                remaining_args = vars(outputs[-1])
                raise ValueError(f"Some specified config arguments are not used by the TrlParser: {remaining_args}")
        return self.parse_args_into_dataclasses(args=args, return_remaining_strings=return_remaining_strings)

        return outputs
    def set_defaults_with_config(self, **kwargs) -> None:
        """
        Overrides the parser's default values with those provided via keyword arguments.

    def set_defaults_with_config(self, **kwargs):
        """Defaults we're setting with config allow us to change to required = False"""
        Any argument with an updated default will also be marked as not required
        if it was previously required.
        """
        self._defaults.update(kwargs)

        # if these defaults match any existing arguments, replace
        # the previous default on the object with the new one
        # If an argument is in the kwargs, update its default and set it as not required
        for action in self._actions:
            if action.dest in kwargs:
                action.default = kwargs[action.dest]
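Two behaviors introduced above, shown as a small usage sketch (the `Args` and `BadArgs` dataclasses are hypothetical): leftover command-line strings are handed back when `return_remaining_strings=True`, and a dataclass that defines a `config` field is rejected because `--config` is now reserved for the YAML path.

```python
# Usage sketch for the refactored TrlParser (dataclass names are made up).
from dataclasses import dataclass

from trl import TrlParser


@dataclass
class Args:
    arg1: int = 0


parser = TrlParser(dataclass_types=[Args])

# Leftover CLI strings are returned instead of raising an error.
parsed, remaining = parser.parse_args_and_config(
    ["--arg1", "1", "leftover"], return_remaining_strings=True
)
print(parsed.arg1, remaining)  # 1 ['leftover']


@dataclass
class BadArgs:
    config: str = ""  # reserved name: TrlParser refuses this dataclass


try:
    TrlParser(dataclass_types=[BadArgs])
except ValueError as err:
    print(err)
```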
4 changes: 0 additions & 4 deletions trl/utils.py
@@ -13,7 +13,6 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Optional


@dataclass
@@ -28,8 +27,6 @@ class ScriptArguments:
            Dataset split to use for training.
        dataset_test_split (`str`, *optional*, defaults to `"test"`):
            Dataset split to use for evaluation.
        config (`str` or `None`, *optional*, defaults to `None`):
            Path to the optional config file.
        gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`):
            Whether to apply `use_reentrant` for gradient_checkpointing.
        ignore_bias_buffers (`bool`, *optional*, defaults to `False`):
@@ -40,6 +37,5 @@
    dataset_name: str
    dataset_train_split: str = "train"
    dataset_test_split: str = "test"
    config: Optional[str] = None
    gradient_checkpointing_use_reentrant: bool = False
    ignore_bias_buffers: bool = False
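With the `config` field dropped from `ScriptArguments`, the YAML path now goes to the parser rather than the dataclass. A hedged sketch of the updated calling pattern, assuming `ScriptArguments` is exported from `trl` and using a placeholder dataset name:

```python
# Sketch: ScriptArguments no longer carries a `config` field; pass the YAML path
# via TrlParser's --config flag instead (e.g. `--config run_config.yaml` on the CLI).
from trl import ScriptArguments, TrlParser

parser = TrlParser(dataclass_types=[ScriptArguments])
(script_args,) = parser.parse_args_and_config(["--dataset_name", "my-org/my-dataset"])
print(script_args.dataset_train_split)  # "train" unless overridden on the CLI or via a config file
```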