Skip to content

Commit

Permalink
Refactor TrlParser class to improve code organization and readability
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Nov 28, 2024
1 parent 037f3fd commit 68107bc
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions trl/commands/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,26 +151,29 @@ class ChatArguments:


class TrlParser(HfArgumentParser):
"""
The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config
parsers for users that pass a valid `config` field and merge the values that are set in the config
with the processed parsers.
Args:
parsers (`List[argparse.ArgumentParser]`):
List of parsers.
ignore_extra_args (`bool`):
Whether to ignore extra arguments passed by the config
and not raise errors.
"""

def __init__(self, parsers, ignore_extra_args=False):
"""
The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config
parsers for users that pass a valid `config` field and merge the values that are set in the config
with the processed parsers.
Args:
parsers (`List[argparse.ArgumentParser`]):
List of parsers.
ignore_extra_args (`bool`):
Whether to ignore extra arguments passed by the config
and not raise errors.
"""
super().__init__(parsers)
self.yaml_parser = YamlConfigParser()
self.ignore_extra_args = ignore_extra_args

def post_process_dataclasses(self, dataclasses):
# Apply additional post-processing in case some arguments needs a special
# care
"""
Post process dataclasses to merge the TrainingArguments with the SFTScriptArguments or DPOScriptArguments.
"""

training_args = trl_args = None
training_args_index = None

Expand All @@ -192,6 +195,9 @@ def post_process_dataclasses(self, dataclasses):
return dataclasses

def parse_args_and_config(self, return_remaining_strings=False):
"""
Parse the command line arguments and the config file.
"""
yaml_config = None
if "--config" in sys.argv:
config_index = sys.argv.index("--config")
Expand Down

0 comments on commit 68107bc

Please sign in to comment.