From 68107bc4c00cbd376594b963e49ecf3cb555edd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 28 Nov 2024 15:10:34 +0000 Subject: [PATCH] Refactor TrlParser class to improve code organization and readability --- trl/commands/cli_utils.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py index b3ce930479..384daf4927 100644 --- a/trl/commands/cli_utils.py +++ b/trl/commands/cli_utils.py @@ -151,26 +151,29 @@ class ChatArguments: class TrlParser(HfArgumentParser): + """ + The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config + parsers for users that pass a valid `config` field and merge the values that are set in the config + with the processed parsers. + + Args: + parsers (`List[argparse.ArgumentParser]`): + List of parsers. + ignore_extra_args (`bool`): + Whether to ignore extra arguments passed by the config + and not raise errors. + """ + def __init__(self, parsers, ignore_extra_args=False): - """ - The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config - parsers for users that pass a valid `config` field and merge the values that are set in the config - with the processed parsers. - - Args: - parsers (`List[argparse.ArgumentParser`]): - List of parsers. - ignore_extra_args (`bool`): - Whether to ignore extra arguments passed by the config - and not raise errors. - """ super().__init__(parsers) self.yaml_parser = YamlConfigParser() self.ignore_extra_args = ignore_extra_args def post_process_dataclasses(self, dataclasses): - # Apply additional post-processing in case some arguments needs a special - # care + """ + Post process dataclasses to merge the TrainingArguments with the SFTScriptArguments or DPOScriptArguments. + """ + training_args = trl_args = None training_args_index = None @@ -192,6 +195,9 @@ def post_process_dataclasses(self, dataclasses): return dataclasses def parse_args_and_config(self, return_remaining_strings=False): + """ + Parse the command line arguments and the config file. + """ yaml_config = None if "--config" in sys.argv: config_index = sys.argv.index("--config")