diff --git a/docs/source/script_utils.md b/docs/source/script_utils.md index 344d13aaef..aba81bf9f3 100644 --- a/docs/source/script_utils.md +++ b/docs/source/script_utils.md @@ -7,3 +7,6 @@ ## TrlParser [[autodoc]] TrlParser + - parse_args_and_config + - parse_args_into_dataclasses + - set_defaults_with_config diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py new file mode 100644 index 0000000000..a2343a2930 --- /dev/null +++ b/tests/test_cli_utils.py @@ -0,0 +1,165 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from dataclasses import dataclass +from unittest.mock import mock_open, patch + +from trl import TrlParser + + +@dataclass +class MyDataclass: + arg1: int + arg2: str = "default" + + +@dataclass +class InvalidDataclass: + config: str # This should raise an error in the TrlParser + + +class TestTrlParser(unittest.TestCase): + def test_init_without_config_field(self): + """Test initialization without 'config' field in the dataclasses.""" + parser = TrlParser(dataclass_types=[MyDataclass]) + self.assertIsInstance(parser, TrlParser) + + def test_init_with_config_field(self): + """Test initialization with a 'config' field in the dataclass (should raise ValueError).""" + with self.assertRaises(ValueError) as context: + TrlParser(dataclass_types=[InvalidDataclass]) + self.assertTrue("has a field named 'config'" in str(context.exception)) + + @patch("builtins.open", mock_open(read_data="env:\n VAR1: value1\n VAR2: value2\narg1: 2")) + @patch("yaml.safe_load") + @patch("os.environ", new_callable=dict) # Mock os.environ as a dictionary + def test_parse_args_and_config_with_valid_config(self, mock_environ, mock_yaml_load): + """Test parse_args_and_config method with valid arguments and config.""" + mock_yaml_load.return_value = {"env": {"VAR1": "value1", "VAR2": "value2"}, "arg1": 2} + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg2", "value", "--config", "config.yaml"] # don't set arg1 to test default value + + # Simulate the config being loaded and environment variables being set + result_args = parser.parse_args_and_config(args) + + # Set the environment variables using the mock + mock_environ["VAR1"] = "value1" + mock_environ["VAR2"] = "value2" + + # Ensure that the environment variables were set correctly + self.assertEqual(mock_environ.get("VAR1"), "value1") + self.assertEqual(mock_environ.get("VAR2"), "value2") + + # Check the parsed arguments + self.assertEqual(len(result_args), 1) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 2) + self.assertEqual(result_args[0].arg2, "value") + + @patch("builtins.open", mock_open(read_data="arg1: 2")) + @patch("yaml.safe_load") + def test_parse_args_and_arg_override_config(self, mock_yaml_load): + """Test parse_args_and_config method and check that arguments override the config.""" + mock_yaml_load.return_value = {"arg1": 2} # this arg is meant to be overridden + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "3", "--config", "config.yaml"] # override arg1 default with 3 + + # Simulate the config being loaded and arguments being passed + result_args = parser.parse_args_and_config(args) + + # Check the parsed arguments + self.assertEqual(len(result_args), 1) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 3) + + @patch("builtins.open", mock_open(read_data="env: not_a_dict")) + @patch("yaml.safe_load") + def test_parse_args_and_config_with_invalid_env(self, mock_yaml_load): + """Test parse_args_and_config method when the 'env' field is not a dictionary.""" + mock_yaml_load.return_value = {"env": "not_a_dict"} + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--arg2", "value", "--config", "config.yaml"] + + with self.assertRaises(ValueError) as context: + parser.parse_args_and_config(args) + + self.assertEqual(str(context.exception), "`env` field should be a dict in the YAML file.") + + def test_parse_args_and_config_without_config(self): + """Test parse_args_and_config without the `--config` argument.""" + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--arg2", "value"] + + # Simulate no config, just parse args normally + result_args = parser.parse_args_and_config(args) + + # Check that the arguments are parsed as is + self.assertEqual(len(result_args), 1) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 2) + self.assertEqual(result_args[0].arg2, "value") + + def test_set_defaults_with_config(self): + """Test set_defaults_with_config updates the defaults.""" + parser = TrlParser(dataclass_types=[MyDataclass]) + + # Update defaults + parser.set_defaults_with_config(arg1=42) + + # Ensure the default value is updated + result_args = parser.parse_args_and_config([]) + self.assertEqual(len(result_args), 1) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 42) + + def test_parse_args_and_config_with_remaining_strings(self): + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--arg2", "value", "remaining"] + + # Simulate no config, just parse args normally + result_args = parser.parse_args_and_config(args, return_remaining_strings=True) + + # Check that the arguments are parsed as is + self.assertEqual(len(result_args), 2) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 2) + self.assertEqual(result_args[0].arg2, "value") + self.assertEqual(result_args[1], ["remaining"]) + + @patch("builtins.open", mock_open(read_data="remaining_string_in_config: abc")) + @patch("yaml.safe_load") + def test_parse_args_and_config_with_remaining_strings_in_config_and_args(self, mock_yaml_load): + mock_yaml_load.return_value = {"remaining_string_in_config": "abc"} + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--remaining_string_in_args", "def", "--config", "config.yaml"] + + # Simulate the config being loaded and arguments being passed + result_args = parser.parse_args_and_config(args, return_remaining_strings=True) + + # Check that the arguments are parsed as is + self.assertEqual(len(result_args), 2) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 2) + self.assertEqual(result_args[1], ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"]) diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py index 384daf4927..772accf0bf 100644 --- a/trl/commands/cli_utils.py +++ b/trl/commands/cli_utils.py @@ -19,18 +19,30 @@ import os import subprocess import sys -from argparse import Namespace +import warnings from dataclasses import dataclass, field +from typing import Iterable, Optional, Union import yaml from transformers import HfArgumentParser +from transformers.hf_argparser import DataClass, DataClassType +from transformers.utils.deprecation import deprecate_kwarg logger = logging.getLogger(__name__) class YamlConfigParser: - def parse_and_set_env(self, config_path): + """ """ + + def __init__(self) -> None: + warnings.warn( + "The `YamlConfigParser` class is deprecated and will be removed in version 0.14. " + "If you need to use this class, please copy the code to your own project.", + DeprecationWarning, + ) + + def parse_and_set_env(self, config_path: str) -> dict: with open(config_path) as yaml_file: config = yaml.safe_load(yaml_file) @@ -152,92 +164,146 @@ class ChatArguments: class TrlParser(HfArgumentParser): """ - The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config - parsers for users that pass a valid `config` field and merge the values that are set in the config - with the processed parsers. + A subclass of [`transformers.HfArgumentParser`] designed for parsing command-line arguments with dataclass-backed + configurations, while also supporting configuration file loading and environment variable management. Args: - parsers (`List[argparse.ArgumentParser]`): - List of parsers. - ignore_extra_args (`bool`): - Whether to ignore extra arguments passed by the config - and not raise errors. + dataclass_types (`Union[DataClassType, Iterable[DataClassType]]`): + Dataclass types to use for argument parsing. + **kwargs: + Additional keyword arguments passed to the [`transformers.HfArgumentParser`] constructor. + + Examples: + + ```yaml + # config.yaml + env: + VAR1: value1 + arg1: 23 + ``` + + ```python + # main.py + import os + from dataclasses import dataclass + from trl import TrlParser + + @dataclass + class MyArguments: + arg1: int + arg2: str = "alpha" + + parser = TrlParser(dataclass_types=[MyArguments]) + training_args = parser.parse_args_and_config() + + print(training_args, os.environ.get("VAR1")) + ``` + + ```bash + $ python main.py --config config.yaml + (MyArguments(arg1=23, arg2='alpha'),) value1 + + $ python main.py --arg1 5 --arg2 beta + (MyArguments(arg1=5, arg2='beta'),) None + ``` """ - def __init__(self, parsers, ignore_extra_args=False): - super().__init__(parsers) - self.yaml_parser = YamlConfigParser() - self.ignore_extra_args = ignore_extra_args + @deprecate_kwarg( + "ignore_extra_args", + "0.14.0", + warn_if_greater_or_equal_version=True, + additional_message="Use the `return_remaining_strings` in the `parse_args_and_config` method instead.", + ) + def __init__( + self, + dataclass_types: Union[DataClassType, Iterable[DataClassType]], + ignore_extra_args: Optional[bool] = None, + **kwargs, + ): + super().__init__(dataclass_types=dataclass_types, **kwargs) + self._ignore_extra_args = ignore_extra_args + + # Check that none of the dataclasses have the "config" field + for dataclass_type in dataclass_types: + if "config" in dataclass_type.__dataclass_fields__: + raise ValueError( + f"Dataclass {dataclass_type.__name__} has a field named 'config'. This field is reserved for the " + f"config file path and should not be used in the dataclass." + ) def post_process_dataclasses(self, dataclasses): """ Post process dataclasses to merge the TrainingArguments with the SFTScriptArguments or DPOScriptArguments. """ - - training_args = trl_args = None - training_args_index = None - - for i, dataclass_obj in enumerate(dataclasses): - if dataclass_obj.__class__.__name__ == "TrainingArguments": - training_args = dataclass_obj - training_args_index = i - elif dataclass_obj.__class__.__name__ in ("SFTScriptArguments", "DPOScriptArguments"): - trl_args = dataclass_obj - else: - ... - - if trl_args is not None and training_args is not None: - training_args.gradient_checkpointing_kwargs = dict( - use_reentrant=trl_args.gradient_checkpointing_use_reentrant - ) - dataclasses[training_args_index] = training_args - + warnings.warn( + "The `post_process_dataclasses` method is deprecated and will be removed in version 0.14. " + "It is no longer functional and can be safely removed from your code.", + DeprecationWarning, + ) return dataclasses - def parse_args_and_config(self, return_remaining_strings=False): + def parse_args_and_config( + self, args: Optional[Iterable[str]] = None, return_remaining_strings: bool = False + ) -> tuple[DataClass, ...]: """ - Parse the command line arguments and the config file. - """ - yaml_config = None - if "--config" in sys.argv: - config_index = sys.argv.index("--config") - - _ = sys.argv.pop(config_index) # --config - config_path = sys.argv.pop(config_index) # path to config - yaml_config = self.yaml_parser.parse_and_set_env(config_path) + Parse command-line args and config file into instances of the specified dataclass types. - self.set_defaults_with_config(**yaml_config) + This method wraps [`transformers.HfArgumentParser.parse_args_into_dataclasses`] and also parses the config file + specified with the `--config` flag. The config file (in YAML format) provides argument values that replace the + default values in the dataclasses. Command line arguments can override values set by the config file. The + method also sets any environment variables specified in the `env` field of the config file. + """ + if self._ignore_extra_args is not None: + return_remaining_strings = not self._ignore_extra_args + + args = list(args) if args is not None else sys.argv[1:] + if "--config" in args: + # Get the config file path from + config_index = args.index("--config") + args.pop(config_index) # remove the --config flag + config_path = args.pop(config_index) # get the path to the config file + with open(config_path) as yaml_file: + config = yaml.safe_load(yaml_file) + + # Set the environment variables specified in the config file + if "env" in config: + env_vars = config.pop("env", {}) + if not isinstance(env_vars, dict): + raise ValueError("`env` field should be a dict in the YAML file.") + for key, value in env_vars.items(): + os.environ[key] = str(value) - outputs = self.parse_args_into_dataclasses(return_remaining_strings=return_remaining_strings) + # Set the defaults from the config values + config_remaining_strings = self.set_defaults_with_config(**config) + else: + config_remaining_strings = [] - if yaml_config is None: - return outputs + # Parse the arguments from the command line + output = self.parse_args_into_dataclasses(args=args, return_remaining_strings=return_remaining_strings) + # Merge remaining strings from the config file with the remaining strings from the command line if return_remaining_strings: - # if we have extra yaml config and command line strings - # outputs[-1] is remaining command line strings - # outputs[-2] is remaining yaml config as Namespace - # combine them into remaining strings object - remaining_strings = outputs[-1] + [f"{key}: {value}" for key, value in vars(outputs[-2]).items()] - return outputs[:-2], remaining_strings + args_remaining_strings = output[-1] + return output[:-1] + (config_remaining_strings + args_remaining_strings,) else: - # outputs[-1] is either remaining yaml config as Namespace or parsed config as Dataclass - if isinstance(outputs[-1], Namespace) and not self.ignore_extra_args: - remaining_args = vars(outputs[-1]) - raise ValueError(f"Some specified config arguments are not used by the TrlParser: {remaining_args}") + return output - return outputs + def set_defaults_with_config(self, **kwargs) -> list[str]: + """ + Overrides the parser's default values with those provided via keyword arguments. - def set_defaults_with_config(self, **kwargs): - """Defaults we're setting with config allow us to change to required = False""" - self._defaults.update(kwargs) + Any argument with an updated default will also be marked as not required + if it was previously required. - # if these defaults match any existing arguments, replace - # the previous default on the object with the new one + Returns a list of strings that were not consumed by the parser. + """ + # If an argument is in the kwargs, update its default and set it as not required for action in self._actions: if action.dest in kwargs: - action.default = kwargs[action.dest] + action.default = kwargs.pop(action.dest) action.required = False + remaining_strings = [item for key, value in kwargs.items() for item in [f"--{key}", str(value)]] + return remaining_strings def get_git_commit_hash(package_name): diff --git a/trl/utils.py b/trl/utils.py index eaea8c78aa..6888e56229 100644 --- a/trl/utils.py +++ b/trl/utils.py @@ -13,7 +13,6 @@ # limitations under the License. from dataclasses import dataclass -from typing import Optional @dataclass @@ -28,8 +27,6 @@ class ScriptArguments: Dataset split to use for training. dataset_test_split (`str`, *optional*, defaults to `"test"`): Dataset split to use for evaluation. - config (`str` or `None`, *optional*, defaults to `None`): - Path to the optional config file. gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): Whether to apply `use_reentrant` for gradient_checkpointing. ignore_bias_buffers (`bool`, *optional*, defaults to `False`): @@ -40,6 +37,5 @@ class ScriptArguments: dataset_name: str dataset_train_split: str = "train" dataset_test_split: str = "test" - config: Optional[str] = None gradient_checkpointing_use_reentrant: bool = False ignore_bias_buffers: bool = False