From c77503acb94a1a47312cc0d83ca8e392b2759174 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 21 Nov 2024 18:27:44 +0000 Subject: [PATCH 01/42] Refactor main function in dpo.py --- examples/scripts/dpo.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index ed14461725..2efcf290bf 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -63,10 +63,7 @@ from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -if __name__ == "__main__": - parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() - +def main(script_args, training_args, model_config): ################ # Model & Tokenizer ################### @@ -136,3 +133,9 @@ trainer.save_model(training_args.output_dir) if training_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) + script_args, training_args, model_config = parser.parse_args_and_config() + main(script_args, training_args, model_config) From 66254d8ed90226a3bb1f2470a0cfdc7f65b0c13c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 21 Nov 2024 18:39:30 +0000 Subject: [PATCH 02/42] Update setup.py and add cli.py --- setup.py | 6 +++--- trl/cli.py | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) create mode 100644 trl/cli.py diff --git a/setup.py b/setup.py index 1259911368..b427ede26f 100644 --- a/setup.py +++ b/setup.py @@ -121,11 +121,11 @@ ], url="https://github.com/huggingface/trl", entry_points={ - "console_scripts": ["trl=trl.commands.cli:main"], + "console_scripts": ["trl=trl.cli:main"], }, include_package_data=True, - package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*", "templates/*.md"]}, - packages=find_packages(exclude={"tests"}), + package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*", "templates/*.md", "examples/*"]}, + packages=find_packages(exclude={"tests", "examples"}), install_requires=REQUIRED_PKGS, extras_require=EXTRAS, python_requires=">=3.9", diff --git a/trl/cli.py b/trl/cli.py new file mode 100644 index 0000000000..71821ffda6 --- /dev/null +++ b/trl/cli.py @@ -0,0 +1,27 @@ +import argparse + +from transformers import HfArgumentParser, TrainingArguments +import sys +from pathlib import Path + +# Add the root of the project to the python path so that we can import examples scripts +path = Path(__file__).parent.parent +sys.path.append(str(path)) + + +def main(): + parser = argparse.ArgumentParser(prog="trl", description="A CLI tool for training and fine-tuning") + subparsers = parser.add_subparsers(dest="command", required=True, parser_class=HfArgumentParser) + + # 'dpo' subcommand + dpo_parser = subparsers.add_parser("dpo", help="Run the DPO training process", dataclass_types=TrainingArguments) + + args = parser.parse_args() + sys.argv = sys.argv[1:] # Remove 'trl' from sys.argv + + if args.command == "dpo": + from examples.scripts.dpo import main as dpo_main + + (training_args,) = dpo_parser.parse_args_into_dataclasses() + dpo_main(training_args) + From 49f34d16ccf368beab5a74624ea2f26699ac86f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 21 Nov 2024 18:51:58 +0000 Subject: [PATCH 03/42] Add examples to package data --- MANIFEST.in | 3 ++- setup.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 26496e93f1..a108824270 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,4 +3,5 @@ include LICENSE include CONTRIBUTING.md include README.md recursive-exclude * __pycache__ -include trl/templates/*.md \ No newline at end of file +include trl/templates/*.md +include examples/* \ No newline at end of file diff --git a/setup.py b/setup.py index b427ede26f..241cc74607 100644 --- a/setup.py +++ b/setup.py @@ -124,7 +124,10 @@ "console_scripts": ["trl=trl.cli:main"], }, include_package_data=True, - package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*", "templates/*.md", "examples/*"]}, + package_data={ + "trl": ["commands/scripts/config/*", "commands/scripts/*", "templates/*.md", "examples/*"], + "": ["examples/*"], + }, packages=find_packages(exclude={"tests", "examples"}), install_requires=REQUIRED_PKGS, extras_require=EXTRAS, From 35ee4e60c77660c2b77eb35789446af24ffbbec4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 21 Nov 2024 18:55:41 +0000 Subject: [PATCH 04/42] style --- trl/cli.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/trl/cli.py b/trl/cli.py index 71821ffda6..54ae479363 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -1,9 +1,24 @@ -import argparse +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from transformers import HfArgumentParser, TrainingArguments +import argparse import sys from pathlib import Path +from transformers import HfArgumentParser, TrainingArguments + + # Add the root of the project to the python path so that we can import examples scripts path = Path(__file__).parent.parent sys.path.append(str(path)) @@ -24,4 +39,3 @@ def main(): (training_args,) = dpo_parser.parse_args_into_dataclasses() dpo_main(training_args) - From 09e1257240a016c77e37d2375b41d49eddfcab48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 21 Nov 2024 18:56:23 +0000 Subject: [PATCH 05/42] Refactor setup.py file --- setup.py | 85 ++++++++++++++++++++++++++------------------------------ 1 file changed, 39 insertions(+), 46 deletions(-) diff --git a/setup.py b/setup.py index 241cc74607..bb54357d76 100644 --- a/setup.py +++ b/setup.py @@ -68,8 +68,6 @@ Then push the change with a message 'set dev version' """ -import os - from setuptools import find_packages, setup @@ -99,47 +97,42 @@ for reqs in EXTRAS.values(): EXTRAS["dev"].extend(reqs) -try: - file_path = os.path.dirname(os.path.abspath(__file__)) - os.symlink(os.path.join(file_path, "examples/scripts"), os.path.join(file_path, "trl/commands/scripts")) - - setup( - name="trl", - license="Apache 2.0", - classifiers=[ - "Development Status :: 2 - Pre-Alpha", - "Intended Audience :: Developers", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: Apache Software License", - "Natural Language :: English", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - ], - url="https://github.com/huggingface/trl", - entry_points={ - "console_scripts": ["trl=trl.cli:main"], - }, - include_package_data=True, - package_data={ - "trl": ["commands/scripts/config/*", "commands/scripts/*", "templates/*.md", "examples/*"], - "": ["examples/*"], - }, - packages=find_packages(exclude={"tests", "examples"}), - install_requires=REQUIRED_PKGS, - extras_require=EXTRAS, - python_requires=">=3.9", - long_description=open("README.md", encoding="utf-8").read(), - long_description_content_type="text/markdown", - zip_safe=False, - version=__version__, - description="Train transformer language models with reinforcement learning.", - keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf", - author="Leandro von Werra", - author_email="leandro.vonwerra@gmail.com", - ) -finally: - os.unlink(os.path.join(file_path, "trl/commands/scripts")) + +setup( + name="trl", + license="Apache 2.0", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + ], + url="https://github.com/huggingface/trl", + entry_points={ + "console_scripts": ["trl=trl.cli:main"], + }, + include_package_data=True, + package_data={ + "trl": ["commands/scripts/config/*", "commands/scripts/*", "templates/*.md"], + "": ["examples/*"], + }, + packages=find_packages(exclude={"tests", "examples"}), + install_requires=REQUIRED_PKGS, + extras_require=EXTRAS, + python_requires=">=3.9", + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + zip_safe=False, + version=__version__, + description="Train transformer language models with reinforcement learning.", + keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf", + author="Leandro von Werra", + author_email="leandro.vonwerra@gmail.com", +) From 4ccf137407460bfb0bbedbff070bf3467dc73e3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 21 Nov 2024 18:59:54 +0000 Subject: [PATCH 06/42] Add new file t.py --- examples/t.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 examples/t.py diff --git a/examples/t.py b/examples/t.py new file mode 100644 index 0000000000..e69de29bb2 From d397ab805d1cc570dd5730f4bc31561b80d5e493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 21 Nov 2024 19:18:49 +0000 Subject: [PATCH 07/42] Move dpo to package --- {examples => trl}/scripts/dpo.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {examples => trl}/scripts/dpo.py (100%) diff --git a/examples/scripts/dpo.py b/trl/scripts/dpo.py similarity index 100% rename from examples/scripts/dpo.py rename to trl/scripts/dpo.py From dad9ecc0492c502b236f172cd6b2abd4c6ac7f84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 21 Nov 2024 19:20:12 +0000 Subject: [PATCH 08/42] Update MANIFEST.in and setup.py, refactor trl/cli.py --- MANIFEST.in | 3 +-- examples/t.py | 0 setup.py | 3 +-- trl/cli.py | 18 ++++++++---------- 4 files changed, 10 insertions(+), 14 deletions(-) delete mode 100644 examples/t.py diff --git a/MANIFEST.in b/MANIFEST.in index a108824270..26496e93f1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,5 +3,4 @@ include LICENSE include CONTRIBUTING.md include README.md recursive-exclude * __pycache__ -include trl/templates/*.md -include examples/* \ No newline at end of file +include trl/templates/*.md \ No newline at end of file diff --git a/examples/t.py b/examples/t.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/setup.py b/setup.py index bb54357d76..fd8ea87950 100644 --- a/setup.py +++ b/setup.py @@ -121,9 +121,8 @@ include_package_data=True, package_data={ "trl": ["commands/scripts/config/*", "commands/scripts/*", "templates/*.md"], - "": ["examples/*"], }, - packages=find_packages(exclude={"tests", "examples"}), + packages=find_packages(exclude={"tests", "tests.slow"}), install_requires=REQUIRED_PKGS, extras_require=EXTRAS, python_requires=">=3.9", diff --git a/trl/cli.py b/trl/cli.py index 54ae479363..b0ecf3928c 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -14,14 +14,10 @@ import argparse import sys -from pathlib import Path -from transformers import HfArgumentParser, TrainingArguments +from transformers import HfArgumentParser - -# Add the root of the project to the python path so that we can import examples scripts -path = Path(__file__).parent.parent -sys.path.append(str(path)) +from trl import DPOConfig, ModelConfig, ScriptArguments def main(): @@ -29,13 +25,15 @@ def main(): subparsers = parser.add_subparsers(dest="command", required=True, parser_class=HfArgumentParser) # 'dpo' subcommand - dpo_parser = subparsers.add_parser("dpo", help="Run the DPO training process", dataclass_types=TrainingArguments) + dpo_parser = subparsers.add_parser( + "dpo", help="Run the DPO training process", dataclass_types=(ScriptArguments, DPOConfig, ModelConfig) + ) args = parser.parse_args() sys.argv = sys.argv[1:] # Remove 'trl' from sys.argv if args.command == "dpo": - from examples.scripts.dpo import main as dpo_main + from trl.scripts.dpo import main as dpo_main - (training_args,) = dpo_parser.parse_args_into_dataclasses() - dpo_main(training_args) + script_args, training_args, model_config = dpo_parser.parse_args_and_config() + dpo_main(script_args, training_args, model_config) From 3024d5bfb460653d23917a9484b2cf3b551cc96d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 21 Nov 2024 19:25:27 +0000 Subject: [PATCH 09/42] Add __init__.py to trl/scripts directory --- trl/scripts/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 trl/scripts/__init__.py diff --git a/trl/scripts/__init__.py b/trl/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From 60583c2cb874bf99e75b21c84ae6920f38b83648 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 21 Nov 2024 19:28:48 +0000 Subject: [PATCH 10/42] Add license header to __init__.py --- trl/scripts/__init__.py | 14 ++++++++++++++ trl/scripts/dpo.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/trl/scripts/__init__.py b/trl/scripts/__init__.py index e69de29bb2..adfc257a24 100644 --- a/trl/scripts/__init__.py +++ b/trl/scripts/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index 2efcf290bf..312eb1298d 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -13,7 +13,7 @@ # limitations under the License. """ # Full training -python examples/scripts/dpo.py \ +python trl/scripts/dpo.py \ --dataset_name trl-lib/ultrafeedback_binarized \ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ --learning_rate 5.0e-7 \ From 5eb1adfc1556526efbd8d2013c2fb0eff5016243 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 22 Nov 2024 09:16:24 +0000 Subject: [PATCH 11/42] File moved instruction --- trl/trainer/dpo.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 trl/trainer/dpo.py diff --git a/trl/trainer/dpo.py b/trl/trainer/dpo.py new file mode 100644 index 0000000000..b4fbf8ed11 --- /dev/null +++ b/trl/trainer/dpo.py @@ -0,0 +1 @@ +# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py \ No newline at end of file From 793ce44b1b88005a2b478d3994b2e87d99c6de47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 22 Nov 2024 09:19:01 +0000 Subject: [PATCH 12/42] Add Apache License and update file path --- trl/trainer/dpo.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/trl/trainer/dpo.py b/trl/trainer/dpo.py index b4fbf8ed11..4352ec603c 100644 --- a/trl/trainer/dpo.py +++ b/trl/trainer/dpo.py @@ -1 +1,18 @@ -# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py \ No newline at end of file +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +############################################################################################### +# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py # +############################################################################################### From bf27b3688d0704faa6162cc7b5c7795136d42a3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 25 Nov 2024 13:03:42 +0000 Subject: [PATCH 13/42] Move dpo.py to new location --- {trl/trainer => examples/scripts}/dpo.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {trl/trainer => examples/scripts}/dpo.py (100%) diff --git a/trl/trainer/dpo.py b/examples/scripts/dpo.py similarity index 100% rename from trl/trainer/dpo.py rename to examples/scripts/dpo.py From adac644f7fe1a5a37f9cca10044ed81b594eae32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 25 Nov 2024 13:12:03 +0000 Subject: [PATCH 14/42] Refactor CLI and DPO script --- trl/cli.py | 34 ++++++++++++++++++++++------------ trl/scripts/dpo.py | 15 ++++++++++++++- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/trl/cli.py b/trl/cli.py index b0ecf3928c..4722a5c87e 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -13,27 +13,37 @@ # limitations under the License. import argparse +import os import sys -from transformers import HfArgumentParser +from accelerate.commands.launch import launch_command, launch_command_parser -from trl import DPOConfig, ModelConfig, ScriptArguments +from trl.scripts.dpo import make_parser as make_dpo_parser def main(): - parser = argparse.ArgumentParser(prog="trl", description="A CLI tool for training and fine-tuning") - subparsers = parser.add_subparsers(dest="command", required=True, parser_class=HfArgumentParser) + parser = argparse.ArgumentParser("TRL CLI", usage="trl", allow_abbrev=False) - # 'dpo' subcommand - dpo_parser = subparsers.add_parser( - "dpo", help="Run the DPO training process", dataclass_types=(ScriptArguments, DPOConfig, ModelConfig) - ) + # Add the subparsers + subparsers = parser.add_subparsers(help="available commands", dest="command") + # Add the subparsers for every script + make_dpo_parser(subparsers) + + # Parse the arguments args = parser.parse_args() - sys.argv = sys.argv[1:] # Remove 'trl' from sys.argv if args.command == "dpo": - from trl.scripts.dpo import main as dpo_main + # Get the default args for the launch command + dpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "dpo.py") + args = launch_command_parser().parse_args([dpo_training_script]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "dpo" + + # Launch the training + launch_command(args) + - script_args, training_args, model_config = dpo_parser.parse_args_and_config() - dpo_main(script_args, training_args, model_config) +if __name__ == "__main__": + main() diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index 312eb1298d..822048cff6 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -46,6 +46,8 @@ --lora_alpha 16 """ +import argparse + import torch from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer @@ -135,7 +137,18 @@ def main(script_args, training_args, model_config): trainer.push_to_hub(dataset_name=script_args.dataset_name) +def make_parser(subparsers: argparse._SubParsersAction = None): + if subparsers is not None: + parser = subparsers.add_parser("dpo", help="Run the DPO training script") + else: + parser = parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) + + if subparsers is not None: + parser.set_defaults(func=main) + return parser + + if __name__ == "__main__": - parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) + parser = make_parser() script_args, training_args, model_config = parser.parse_args_and_config() main(script_args, training_args, model_config) From a15a41c4237652d0bf6b76eb23f132c13ca37287 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 28 Nov 2024 18:44:04 +0000 Subject: [PATCH 15/42] Refactor import structure in scripts package --- trl/__init__.py | 6 +-- trl/scripts/__init__.py | 15 +++++++ .../cli_utils.py => scripts/utils.py} | 30 +++++++++++++ trl/utils.py | 45 ------------------- 4 files changed, 47 insertions(+), 49 deletions(-) rename trl/{commands/cli_utils.py => scripts/utils.py} (89%) delete mode 100644 trl/utils.py diff --git a/trl/__init__.py b/trl/__init__.py index 25cf4fb291..374384e834 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -20,7 +20,7 @@ _import_structure = { - "commands.cli_utils": ["DPOScriptArguments", "SFTScriptArguments", "TrlParser", "init_zero_verbose"], + "scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"], "core": ["set_seed"], "data_utils": [ "apply_chat_template", @@ -94,7 +94,6 @@ ], "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"], "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], - "utils": ["ScriptArguments"], } try: @@ -114,7 +113,6 @@ _import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"]) if TYPE_CHECKING: - from .commands.cli_utils import DPOScriptArguments, SFTScriptArguments, TrlParser, init_zero_verbose from .core import set_seed from .data_utils import ( apply_chat_template, @@ -136,6 +134,7 @@ create_reference_model, setup_chat_format, ) + from .scripts import ScriptArguments, TrlParser, init_zero_verbose from .trainer import ( AlignPropConfig, AlignPropTrainer, @@ -184,7 +183,6 @@ ) from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config - from .utils import ScriptArguments try: if not is_diffusers_available(): diff --git a/trl/scripts/__init__.py b/trl/scripts/__init__.py index adfc257a24..e79624121d 100644 --- a/trl/scripts/__init__.py +++ b/trl/scripts/__init__.py @@ -12,3 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + +from ..import_utils import _LazyModule + + +_import_structure = { + "utils": ["init_zero_verbose", "ScriptArguments", "TrlParser"], +} + +if TYPE_CHECKING: + from .utils import ScriptArguments, TrlParser, init_zero_verbose +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/commands/cli_utils.py b/trl/scripts/utils.py similarity index 89% rename from trl/commands/cli_utils.py rename to trl/scripts/utils.py index 384daf4927..0d4a9ec41d 100644 --- a/trl/commands/cli_utils.py +++ b/trl/scripts/utils.py @@ -21,6 +21,7 @@ import sys from argparse import Namespace from dataclasses import dataclass, field +from typing import Optional import yaml from transformers import HfArgumentParser @@ -29,6 +30,35 @@ logger = logging.getLogger(__name__) +@dataclass +class ScriptArguments: + """ + Arguments common to all scripts. + + Args: + dataset_name (`str`): + Dataset name. + dataset_train_split (`str`, *optional*, defaults to `"train"`): + Dataset split to use for training. + dataset_test_split (`str`, *optional*, defaults to `"test"`): + Dataset split to use for evaluation. + config (`str` or `None`, *optional*, defaults to `None`): + Path to the optional config file. + gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): + Whether to apply `use_reentrant` for gradient_checkpointing. + ignore_bias_buffers (`bool`, *optional*, defaults to `False`): + Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar + type, inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. + """ + + dataset_name: str + dataset_train_split: str = "train" + dataset_test_split: str = "test" + config: Optional[str] = None + gradient_checkpointing_use_reentrant: bool = False + ignore_bias_buffers: bool = False + + class YamlConfigParser: def parse_and_set_env(self, config_path): with open(config_path) as yaml_file: diff --git a/trl/utils.py b/trl/utils.py deleted file mode 100644 index eaea8c78aa..0000000000 --- a/trl/utils.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Optional - - -@dataclass -class ScriptArguments: - """ - Arguments common to all scripts. - - Args: - dataset_name (`str`): - Dataset name. - dataset_train_split (`str`, *optional*, defaults to `"train"`): - Dataset split to use for training. - dataset_test_split (`str`, *optional*, defaults to `"test"`): - Dataset split to use for evaluation. - config (`str` or `None`, *optional*, defaults to `None`): - Path to the optional config file. - gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): - Whether to apply `use_reentrant` for gradient_checkpointing. - ignore_bias_buffers (`bool`, *optional*, defaults to `False`): - Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar - type, inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. - """ - - dataset_name: str - dataset_train_split: str = "train" - dataset_test_split: str = "test" - config: Optional[str] = None - gradient_checkpointing_use_reentrant: bool = False - ignore_bias_buffers: bool = False From 7a0a4f078e7aea9e40f994a26c4dd1257d845ec9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 4 Dec 2024 18:33:14 +0000 Subject: [PATCH 16/42] env --- trl/cli.py | 7 ++++- trl/scripts/env.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 trl/scripts/env.py diff --git a/trl/cli.py b/trl/cli.py index 4722a5c87e..ffc920a433 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -18,7 +18,8 @@ from accelerate.commands.launch import launch_command, launch_command_parser -from trl.scripts.dpo import make_parser as make_dpo_parser +from .scripts.dpo import make_parser as make_dpo_parser +from .scripts.env import print_env def main(): @@ -29,6 +30,7 @@ def main(): # Add the subparsers for every script make_dpo_parser(subparsers) + subparsers.add_parser("env", help="Print the environment information") # Parse the arguments args = parser.parse_args() @@ -44,6 +46,9 @@ def main(): # Launch the training launch_command(args) + elif args.command == "env": + print_env() + if __name__ == "__main__": main() diff --git a/trl/scripts/env.py b/trl/scripts/env.py new file mode 100644 index 0000000000..c2000d1db4 --- /dev/null +++ b/trl/scripts/env.py @@ -0,0 +1,72 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import platform +from importlib.metadata import version + +import torch +from accelerate.commands.config import default_config_file, load_config_from_file +from transformers import is_bitsandbytes_available +from transformers.utils import is_liger_kernel_available, is_openai_available, is_peft_available + +from .. import __version__ +from ..import_utils import is_deepspeed_available, is_diffusers_available, is_llm_blender_available +from .utils import get_git_commit_hash + + +def print_env(): + if torch.cuda.is_available(): + devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] + + accelerate_config = accelerate_config_str = "not found" + + # Get the default from the config file. + if os.path.isfile(default_config_file): + accelerate_config = load_config_from_file(default_config_file).to_dict() + + accelerate_config_str = ( + "\n" + "\n".join([f" - {prop}: {val}" for prop, val in accelerate_config.items()]) + if isinstance(accelerate_config, dict) + else accelerate_config + ) + + commit_hash = get_git_commit_hash("trl") + + info = { + "Platform": platform.platform(), + "Python version": platform.python_version(), + "PyTorch version": version("torch"), + "CUDA device(s)": ", ".join(devices) if torch.cuda.is_available() else "not available", + "Transformers version": version("transformers"), + "Accelerate version": version("accelerate"), + "Accelerate config": accelerate_config_str, + "Datasets version": version("datasets"), + "HF Hub version": version("huggingface_hub"), + "TRL version": f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__, + "bitsandbytes version": version("bitsandbytes") if is_bitsandbytes_available() else "not installed", + "DeepSpeed version": version("deepspeed") if is_deepspeed_available() else "not installed", + "Diffusers version": version("diffusers") if is_diffusers_available() else "not installed", + "Liger-Kernel version": version("liger_kernel") if is_liger_kernel_available() else "not installed", + "LLM-Blender version": version("llm_blender") if is_llm_blender_available() else "not installed", + "OpenAI version": version("openai") if is_openai_available() else "not installed", + "PEFT version": version("peft") if is_peft_available() else "not installed", + } + + info_str = "\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + print(f"\nCopy-paste the following information when reporting an issue:\n\n{info_str}\n") # noqa + + +if __name__ == "__main__": + print_env() From 167f23f96ff7920989d313434d91ddf1df1525a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 4 Dec 2024 18:33:35 +0000 Subject: [PATCH 17/42] rm config from chat arg --- trl/scripts/utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/trl/scripts/utils.py b/trl/scripts/utils.py index a09bc4f055..6dcf4811fd 100644 --- a/trl/scripts/utils.py +++ b/trl/scripts/utils.py @@ -128,12 +128,6 @@ class ChatArguments: default="cpu", metadata={"help": "device to use for inference."}, ) - config: str = field( - default="default", - metadata={ - "help": "Config file used for setting the configs. If `default` uses examples/scripts/config/default_chat_config.yaml" - }, - ) examples: str = field(default=None, metadata={"help": "Empty placeholder needs to be set via config."}) # generation settings max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate"}) @@ -250,6 +244,8 @@ def __init__( self._ignore_extra_args = ignore_extra_args # Check that none of the dataclasses have the "config" field + if not isinstance(dataclass_types, list): + dataclass_types = [dataclass_types] for dataclass_type in dataclass_types: if "config" in dataclass_type.__dataclass_fields__: raise ValueError( From 084e33aab49e24b37e41eaecc2c75b713bbdfd90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 4 Dec 2024 18:33:59 +0000 Subject: [PATCH 18/42] rm old cli --- trl/commands/cli.py | 129 -------------------------------------------- 1 file changed, 129 deletions(-) diff --git a/trl/commands/cli.py b/trl/commands/cli.py index 9f4266e631..adfc257a24 100644 --- a/trl/commands/cli.py +++ b/trl/commands/cli.py @@ -11,133 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import platform -import subprocess -import sys -from importlib.metadata import version -from subprocess import CalledProcessError -import torch -from accelerate.commands.config import default_config_file, load_config_from_file -from rich.console import Console -from transformers import is_bitsandbytes_available -from transformers.utils import is_liger_kernel_available, is_openai_available, is_peft_available - -from .. import __version__, is_deepspeed_available, is_diffusers_available, is_llm_blender_available -from .cli_utils import get_git_commit_hash - - -SUPPORTED_COMMANDS = ["sft", "dpo", "chat", "kto", "env"] - - -def print_env(): - if torch.cuda.is_available(): - devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] - - accelerate_config = accelerate_config_str = "not found" - - # Get the default from the config file. - if os.path.isfile(default_config_file): - accelerate_config = load_config_from_file(default_config_file).to_dict() - - accelerate_config_str = ( - "\n" + "\n".join([f" - {prop}: {val}" for prop, val in accelerate_config.items()]) - if isinstance(accelerate_config, dict) - else accelerate_config - ) - - commit_hash = get_git_commit_hash("trl") - - info = { - "Platform": platform.platform(), - "Python version": platform.python_version(), - "PyTorch version": version("torch"), - "CUDA device(s)": ", ".join(devices) if torch.cuda.is_available() else "not available", - "Transformers version": version("transformers"), - "Accelerate version": version("accelerate"), - "Accelerate config": accelerate_config_str, - "Datasets version": version("datasets"), - "HF Hub version": version("huggingface_hub"), - "TRL version": f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__, - "bitsandbytes version": version("bitsandbytes") if is_bitsandbytes_available() else "not installed", - "DeepSpeed version": version("deepspeed") if is_deepspeed_available() else "not installed", - "Diffusers version": version("diffusers") if is_diffusers_available() else "not installed", - "Liger-Kernel version": version("liger_kernel") if is_liger_kernel_available() else "not installed", - "LLM-Blender version": version("llm_blender") if is_llm_blender_available() else "not installed", - "OpenAI version": version("openai") if is_openai_available() else "not installed", - "PEFT version": version("peft") if is_peft_available() else "not installed", - } - - info_str = "\n".join([f"- {prop}: {val}" for prop, val in info.items()]) - print(f"\nCopy-paste the following information when reporting an issue:\n\n{info_str}\n") # noqa - - -def train(command_name): - console = Console() - # Make sure to import things locally to avoid verbose from third party libs. - with console.status("[bold purple]Welcome! Initializing the TRL CLI..."): - from trl.commands.cli_utils import init_zero_verbose - - init_zero_verbose() - command_name = sys.argv[1] - trl_examples_dir = os.path.dirname(__file__) - - command = f"accelerate launch {trl_examples_dir}/scripts/{command_name}.py {' '.join(sys.argv[2:])}" - - try: - subprocess.run( - command.split(), - text=True, - check=True, - encoding="utf-8", - cwd=os.getcwd(), - env=os.environ.copy(), - ) - except (CalledProcessError, ChildProcessError) as exc: - console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.") - raise ValueError("TRL CLI failed! Check the traceback above..") from exc - - -def chat(): - console = Console() - # Make sure to import things locally to avoid verbose from third party libs. - with console.status("[bold purple]Welcome! Initializing the TRL CLI..."): - from trl.commands.cli_utils import init_zero_verbose - - init_zero_verbose() - trl_examples_dir = os.path.dirname(__file__) - - command = f"python {trl_examples_dir}/scripts/chat.py {' '.join(sys.argv[2:])}" - - try: - subprocess.run( - command.split(), - text=True, - check=True, - encoding="utf-8", - cwd=os.getcwd(), - env=os.environ.copy(), - ) - except (CalledProcessError, ChildProcessError) as exc: - console.log("TRL - CHAT failed! See the logs above for further details.") - raise ValueError("TRL CLI failed! Check the traceback above..") from exc - - -def main(): - command_name = sys.argv[1] - - if command_name in ["sft", "dpo", "kto"]: - train(command_name) - elif command_name == "chat": - chat() - elif command_name == "env": - print_env() - else: - raise ValueError( - f"Please use one of the supported commands, got {command_name} - supported commands are {SUPPORTED_COMMANDS}" - ) - - -if __name__ == "__main__": - main() From 70dd253b2b669775dec6b5d9afbe8cc1d2d57726 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 4 Dec 2024 18:34:21 +0000 Subject: [PATCH 19/42] chat init --- examples/scripts/chat.py | 2 +- trl/scripts/chat.py | 49 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 trl/scripts/chat.py diff --git a/examples/scripts/chat.py b/examples/scripts/chat.py index d29200055c..89b0e30e5e 100644 --- a/examples/scripts/chat.py +++ b/examples/scripts/chat.py @@ -29,7 +29,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer from trl import TrlParser, init_zero_verbose -from trl.commands.cli_utils import ChatArguments +from trl.scripts.utils import ChatArguments from trl.trainer.utils import get_quantization_config diff --git a/trl/scripts/chat.py b/trl/scripts/chat.py new file mode 100644 index 0000000000..f6118d8409 --- /dev/null +++ b/trl/scripts/chat.py @@ -0,0 +1,49 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time + +from rich.console import Console + + +def chat(): + console = Console() + # Make sure to import things locally to avoid verbose from third party libs. + with console.status("[bold purple]Welcome! Initializing the TRL CLI..."): + time.sleep(1) + ... + # from .utils import init_zero_verbose + + # init_zero_verbose() + # trl_examples_dir = os.path.dirname(__file__) + + # command = f"python {trl_examples_dir}/scripts/chat.py {' '.join(sys.argv[2:])}" + + # try: + # subprocess.run( + # command.split(), + # text=True, + # check=True, + # encoding="utf-8", + # cwd=os.getcwd(), + # env=os.environ.copy(), + # ) + # except (CalledProcessError, ChildProcessError) as exc: + # console.log("TRL - CHAT failed! See the logs above for further details.") + # raise ValueError("TRL CLI failed! Check the traceback above..") from exc + + +if __name__ == "__main__": + chat() From 972f7c662a6ec23aa9fecc5caca28974c55ce3fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 14:03:31 +0000 Subject: [PATCH 20/42] test cli [skip ci] --- tests/test_cli.py | 54 ++++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index d288f0f0bf..6055b66cec 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -11,34 +11,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import sys import unittest +from io import StringIO +from unittest.mock import patch +from trl.cli import main -class CLITester(unittest.TestCase): - @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") - def test_sft_cli(self): - try: - subprocess.run( - "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name stanfordnlp/imdb --learning_rate 1e-4 --lr_scheduler_type cosine", - shell=True, - check=True, - ) - except BaseException: - self.fail("An error occurred while running the CLI, please double check") - @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") - def test_dpo_cli(self): - try: - subprocess.run( - "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine", - shell=True, - check=True, - ) - except BaseException: - self.fail("An error occurred while running the CLI, please double check") +class TestCLI(unittest.TestCase): + @patch("sys.stdout", new_callable=StringIO) + @patch("sys.argv", ["trl", "env"]) + def test_env(self, mock_stdout): + main() + self.assertIn("TRL version: ", mock_stdout.getvalue().strip()) - def test_env_cli(self): - output = subprocess.run("trl env", capture_output=True, text=True, shell=True, check=True) - self.assertIn("- Python version: ", output.stdout) + @patch( + "sys.argv", + [ + "trl", + "dpo", + "--output_dir", + "output_dir", + "--model_name_or_path", + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "--dataset_name trl-internal-testing/zen", + "--report_to none", + ], + ) + def test_dpo(self): + main() + + +if __name__ == "__main__": + unittest.main() From 1386d414934721ba29488e0e821136beb4eaa338 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 14:28:45 +0000 Subject: [PATCH 21/42] Add `datast_config_name` to `ScriptArguments` (#2440) --- examples/scripts/bco.py | 2 +- examples/scripts/cpo.py | 2 +- examples/scripts/dpo_online.py | 2 +- examples/scripts/dpo_vlm.py | 2 +- examples/scripts/gkd.py | 2 +- examples/scripts/kto.py | 2 +- examples/scripts/nash_md.py | 2 +- examples/scripts/orpo.py | 2 +- examples/scripts/ppo/ppo.py | 4 +++- examples/scripts/ppo/ppo_tldr.py | 2 +- examples/scripts/reward_modeling.py | 2 +- examples/scripts/rloo/rloo.py | 4 +++- examples/scripts/rloo/rloo_tldr.py | 2 +- examples/scripts/sft.py | 2 +- examples/scripts/sft_video_llm.py | 2 +- examples/scripts/sft_vlm.py | 2 +- examples/scripts/sft_vlm_smol_vlm.py | 2 +- examples/scripts/xpo.py | 2 +- trl/scripts/dpo.py | 2 +- trl/scripts/utils.py | 2 ++ 20 files changed, 25 insertions(+), 19 deletions(-) diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py index d37a1fb3a9..d1ef3ed465 100644 --- a/examples/scripts/bco.py +++ b/examples/scripts/bco.py @@ -126,7 +126,7 @@ def mean_pooling(model_output, attention_mask): if tokenizer.chat_template is None: model, tokenizer = setup_chat_format(model, tokenizer) - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) accelerator = Accelerator() embedding_model = AutoModel.from_pretrained( diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index 019e9ef714..4460239d99 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -80,7 +80,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index b38e0bc95b..e280148521 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -120,7 +120,7 @@ if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) trainer = OnlineDPOTrainer( model=model, diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index 0781edd9e3..8d9db3527a 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -103,7 +103,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) ################ # Training diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index ed5d07a47e..b397d33250 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -103,7 +103,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) with PartialState().local_main_process_first(): dataset = dataset.map( diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index 56205ce57e..7dcd769a81 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -91,7 +91,7 @@ model, tokenizer = setup_chat_format(model, tokenizer) # Load the dataset - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) # Initialize the KTO trainer trainer = KTOTrainer( diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index d4915e0347..680fafce6c 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -120,7 +120,7 @@ if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) trainer = NashMDTrainer( model=model, diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index ac3f598f88..a984842d42 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -80,7 +80,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 3632b1b40d..8508bb4909 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -119,7 +119,9 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split) + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config_name, split=script_args.dataset_train_split + ) eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 3d8c07ac03..96b08c2f31 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -126,7 +126,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) train_dataset = dataset[script_args.dataset_train_split] eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index ce99964da9..34165c323c 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -106,7 +106,7 @@ ############## # Load dataset ############## - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) ########## # Training diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index a3c685a84b..88491006d0 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -90,7 +90,9 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split) + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config_name, split=script_args.dataset_train_split + ) eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 36f759208a..4a7c17ce86 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -92,7 +92,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) train_dataset = dataset[script_args.dataset_train_split] eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index 422fa89a6f..2b9ce7ee59 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -89,7 +89,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) ################ # Training diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py index 3343c3a302..be2706f421 100644 --- a/examples/scripts/sft_video_llm.py +++ b/examples/scripts/sft_video_llm.py @@ -176,7 +176,7 @@ class CustomScriptArguments(SFTScriptArguments): training_args.dataset_kwargs = {"skip_prepare_dataset": True} # Load dataset - dataset = load_dataset(script_args.dataset_name, split="train") + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name, split="train") # Setup model torch_dtype = ( diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index 49654bc408..dd2934ca32 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -108,7 +108,7 @@ def collate_fn(examples): ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) ################ # Training diff --git a/examples/scripts/sft_vlm_smol_vlm.py b/examples/scripts/sft_vlm_smol_vlm.py index 2cac4f2cac..7d08efea24 100644 --- a/examples/scripts/sft_vlm_smol_vlm.py +++ b/examples/scripts/sft_vlm_smol_vlm.py @@ -120,7 +120,7 @@ def collate_fn(examples): ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) ################ # Training diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index d2d3cb05f8..6d59e01027 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -105,7 +105,7 @@ if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) trainer = XPOTrainer( model=model, diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index 822048cff6..951b2d95bf 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -109,7 +109,7 @@ def main(script_args, training_args, model_config): ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) ########## # Training diff --git a/trl/scripts/utils.py b/trl/scripts/utils.py index 6dcf4811fd..e4ee5db822 100644 --- a/trl/scripts/utils.py +++ b/trl/scripts/utils.py @@ -40,6 +40,8 @@ class ScriptArguments: Args: dataset_name (`str`): Dataset name. + dataset_config_name (`str` or `None`, *optional*, defaults to `None`): + Dataset configuration name. dataset_train_split (`str`, *optional*, defaults to `"train"`): Dataset split to use for training. dataset_test_split (`str`, *optional*, defaults to `"test"`): From bf289d860c18884324c40e99f8b3f231094c9d0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 14:30:04 +0000 Subject: [PATCH 22/42] add missing arg --- trl/scripts/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/scripts/utils.py b/trl/scripts/utils.py index e4ee5db822..2cbb364e8e 100644 --- a/trl/scripts/utils.py +++ b/trl/scripts/utils.py @@ -54,6 +54,7 @@ class ScriptArguments: """ dataset_name: str + dataset_config_name: Optional[str] = None dataset_train_split: str = "train" dataset_test_split: str = "test" gradient_checkpointing_use_reentrant: bool = False From d811b1b7c08e469ff8f0cfef030195a7a6baf139 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 16:41:36 +0000 Subject: [PATCH 23/42] Add test cases for 'trl sft' and 'trl dpo' commands --- tests/test_cli.py | 30 ++++++++++++++---------------- trl/cli.py | 9 ++++----- trl/commands/__init__.py | 29 ----------------------------- trl/commands/cli.py | 14 -------------- trl/scripts/dpo.py | 8 +++----- trl/scripts/utils.py | 16 ++++++++++------ 6 files changed, 31 insertions(+), 75 deletions(-) delete mode 100644 trl/commands/__init__.py delete mode 100644 trl/commands/cli.py diff --git a/tests/test_cli.py b/tests/test_cli.py index 6055b66cec..1646348357 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import tempfile import unittest from io import StringIO from unittest.mock import patch @@ -20,26 +21,23 @@ class TestCLI(unittest.TestCase): @patch("sys.stdout", new_callable=StringIO) - @patch("sys.argv", ["trl", "env"]) def test_env(self, mock_stdout): - main() + command = "trl env" + with patch("sys.argv", command.split(" ")): + main() self.assertIn("TRL version: ", mock_stdout.getvalue().strip()) - @patch( - "sys.argv", - [ - "trl", - "dpo", - "--output_dir", - "output_dir", - "--model_name_or_path", - "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", - "--dataset_name trl-internal-testing/zen", - "--report_to none", - ], - ) + def test_sft(self): + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + command = f"trl sft --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config_name standard_language_modeling --report_to none" + with patch("sys.argv", command.split(" ")): + main() + def test_dpo(self): - main() + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + command = f"trl dpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config_name standard_preference --report_to none" + with patch("sys.argv", command.split(" ")): + main() if __name__ == "__main__": diff --git a/trl/cli.py b/trl/cli.py index ffc920a433..71e4aeb696 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -20,13 +20,14 @@ from .scripts.dpo import make_parser as make_dpo_parser from .scripts.env import print_env +from .scripts.utils import TrlParser def main(): - parser = argparse.ArgumentParser("TRL CLI", usage="trl", allow_abbrev=False) + parser = argparse.ArgumentParser(prog="TRL CLI", usage="trl", allow_abbrev=False) # Add the subparsers - subparsers = parser.add_subparsers(help="available commands", dest="command") + subparsers = parser.add_subparsers(help="available commands", dest="command", parser_class=TrlParser) # Add the subparsers for every script make_dpo_parser(subparsers) @@ -42,9 +43,7 @@ def main(): # Feed the args to the launch command args.training_script_args = sys.argv[2:] # remove "trl" and "dpo" - - # Launch the training - launch_command(args) + launch_command(args) # launch training elif args.command == "env": print_env() diff --git a/trl/commands/__init__.py b/trl/commands/__init__.py deleted file mode 100644 index c4da312cd9..0000000000 --- a/trl/commands/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING - -from ..import_utils import OptionalDependencyNotAvailable, _LazyModule - - -_import_structure = { - "cli_utils": ["DPOScriptArguments", "SFTScriptArguments", "TrlParser", "YamlConfigParser", "init_zero_verbose"], -} - -if TYPE_CHECKING: - from .cli_utils import DPOScriptArguments, SFTScriptArguments, TrlParser, YamlConfigParser, init_zero_verbose -else: - import sys - - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/commands/cli.py b/trl/commands/cli.py deleted file mode 100644 index adfc257a24..0000000000 --- a/trl/commands/cli.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index 951b2d95bf..20a6ef3f98 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -138,13 +138,11 @@ def main(script_args, training_args, model_config): def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (ScriptArguments, DPOConfig, ModelConfig) if subparsers is not None: - parser = subparsers.add_parser("dpo", help="Run the DPO training script") + parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types) else: - parser = parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) - - if subparsers is not None: - parser.set_defaults(func=main) + parser = TrlParser(dataclass_types) return parser diff --git a/trl/scripts/utils.py b/trl/scripts/utils.py index 2cbb364e8e..82a01867a2 100644 --- a/trl/scripts/utils.py +++ b/trl/scripts/utils.py @@ -191,7 +191,7 @@ class TrlParser(HfArgumentParser): configurations, while also supporting configuration file loading and environment variable management. Args: - dataclass_types (`Union[DataClassType, Iterable[DataClassType]]`): + dataclass_types (`Union[DataClassType, Iterable[DataClassType]]` or `None`, *optional*, defaults to `None`): Dataclass types to use for argument parsing. **kwargs: Additional keyword arguments passed to the [`transformers.HfArgumentParser`] constructor. @@ -239,16 +239,17 @@ class MyArguments: ) def __init__( self, - dataclass_types: Union[DataClassType, Iterable[DataClassType]], + dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None, ignore_extra_args: Optional[bool] = None, **kwargs, ): - super().__init__(dataclass_types=dataclass_types, **kwargs) - self._ignore_extra_args = ignore_extra_args + # Make sure dataclass_types is an iterable + if dataclass_types is None: + dataclass_types = [] + elif not isinstance(dataclass_types, Iterable): + dataclass_types = [dataclass_types] # Check that none of the dataclasses have the "config" field - if not isinstance(dataclass_types, list): - dataclass_types = [dataclass_types] for dataclass_type in dataclass_types: if "config" in dataclass_type.__dataclass_fields__: raise ValueError( @@ -256,6 +257,9 @@ def __init__( f"config file path and should not be used in the dataclass." ) + super().__init__(dataclass_types=dataclass_types, **kwargs) + self._ignore_extra_args = ignore_extra_args + def post_process_dataclasses(self, dataclasses): """ Post process dataclasses to merge the TrainingArguments with the SFTScriptArguments or DPOScriptArguments. From 61706afe388f41402c54be876e566b62e11caaa3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 17:17:39 +0000 Subject: [PATCH 24/42] Add sft.py script and update cli.py to include sft command --- trl/cli.py | 11 ++++ trl/scripts/sft.py | 125 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 trl/scripts/sft.py diff --git a/trl/cli.py b/trl/cli.py index 71e4aeb696..19d8783ab1 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -20,6 +20,7 @@ from .scripts.dpo import make_parser as make_dpo_parser from .scripts.env import print_env +from .scripts.sft import make_parser as make_sft_parser from .scripts.utils import TrlParser @@ -32,6 +33,7 @@ def main(): # Add the subparsers for every script make_dpo_parser(subparsers) subparsers.add_parser("env", help="Print the environment information") + make_sft_parser(subparsers) # Parse the arguments args = parser.parse_args() @@ -48,6 +50,15 @@ def main(): elif args.command == "env": print_env() + elif args.command == "sft": + # Get the default args for the launch command + sft_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "sft.py") + args = launch_command_parser().parse_args([sft_training_script]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "sft" + launch_command(args) # launch training + if __name__ == "__main__": main() diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py new file mode 100644 index 0000000000..3533121fe1 --- /dev/null +++ b/trl/scripts/sft.py @@ -0,0 +1,125 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +# Full training +python examples/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --learning_rate 2.0e-5 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 100 \ + --output_dir Qwen2-0.5B-SFT \ + --push_to_hub + +# LoRA +python examples/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --learning_rate 2.0e-4 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 100 \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ + --output_dir Qwen2-0.5B-SFT \ + --push_to_hub +""" + +import argparse + +from datasets import load_dataset +from transformers import AutoTokenizer + +from trl import ( + ModelConfig, + ScriptArguments, + SFTConfig, + SFTTrainer, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +def main(script_args, training_args, model_config): + ################ + # Model init kwargs & Tokenizer + ################ + quantization_config = get_quantization_config(model_config) + model_kwargs = dict( + revision=model_config.model_revision, + trust_remote_code=model_config.trust_remote_code, + attn_implementation=model_config.attn_implementation, + torch_dtype=model_config.torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + training_args.model_init_kwargs = model_kwargs + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True + ) + tokenizer.pad_token = tokenizer.eos_token + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + + ################ + # Training + ################ + trainer = SFTTrainer( + model=model_config.model_name_or_path, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_config), + ) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (ScriptArguments, SFTConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + script_args, training_args, model_config = parser.parse_args_and_config() + main(script_args, training_args, model_config) From d9094e240a390254473cd52e244266b636db38ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 18:14:34 +0000 Subject: [PATCH 25/42] Move sft script --- examples/scripts/sft.py | 100 ++-------------------------------------- 1 file changed, 3 insertions(+), 97 deletions(-) diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index 2b9ce7ee59..61019a0c19 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -11,101 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -# Full training -python examples/scripts/sft.py \ - --model_name_or_path Qwen/Qwen2-0.5B \ - --dataset_name trl-lib/Capybara \ - --learning_rate 2.0e-5 \ - --num_train_epochs 1 \ - --packing \ - --per_device_train_batch_size 2 \ - --gradient_accumulation_steps 8 \ - --gradient_checkpointing \ - --logging_steps 25 \ - --eval_strategy steps \ - --eval_steps 100 \ - --output_dir Qwen2-0.5B-SFT \ - --push_to_hub -# LoRA -python examples/scripts/sft.py \ - --model_name_or_path Qwen/Qwen2-0.5B \ - --dataset_name trl-lib/Capybara \ - --learning_rate 2.0e-4 \ - --num_train_epochs 1 \ - --packing \ - --per_device_train_batch_size 2 \ - --gradient_accumulation_steps 8 \ - --gradient_checkpointing \ - --logging_steps 25 \ - --eval_strategy steps \ - --eval_steps 100 \ - --use_peft \ - --lora_r 32 \ - --lora_alpha 16 \ - --output_dir Qwen2-0.5B-SFT \ - --push_to_hub -""" - -from datasets import load_dataset -from transformers import AutoTokenizer - -from trl import ( - ModelConfig, - ScriptArguments, - SFTConfig, - SFTTrainer, - TrlParser, - get_kbit_device_map, - get_peft_config, - get_quantization_config, -) - - -if __name__ == "__main__": - parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() - - ################ - # Model init kwargs & Tokenizer - ################ - quantization_config = get_quantization_config(model_config) - model_kwargs = dict( - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, - attn_implementation=model_config.attn_implementation, - torch_dtype=model_config.torch_dtype, - use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, - ) - training_args.model_init_kwargs = model_kwargs - tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True - ) - tokenizer.pad_token = tokenizer.eos_token - - ################ - # Dataset - ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) - - ################ - # Training - ################ - trainer = SFTTrainer( - model=model_config.model_name_or_path, - args=training_args, - train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, - processing_class=tokenizer, - peft_config=get_peft_config(model_config), - ) - - trainer.train() - - # Save and push to hub - trainer.save_model(training_args.output_dir) - if training_args.push_to_hub: - trainer.push_to_hub(dataset_name=script_args.dataset_name) +############################################################################################### +# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py # +############################################################################################### \ No newline at end of file From 7d2e62c8a726316f2edcaa72a479ab24406764c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 18:16:07 +0000 Subject: [PATCH 26/42] chat --- examples/scripts/chat.py | 356 +----------------------------- trl/cli.py | 10 +- trl/scripts/chat.py | 453 ++++++++++++++++++++++++++++++++++++--- trl/scripts/utils.py | 67 +----- 4 files changed, 441 insertions(+), 445 deletions(-) diff --git a/examples/scripts/chat.py b/examples/scripts/chat.py index 89b0e30e5e..4340425038 100644 --- a/examples/scripts/chat.py +++ b/examples/scripts/chat.py @@ -13,356 +13,6 @@ # limitations under the License. -import copy -import json -import os -import pwd -import re -import sys -import time -from threading import Thread - -import torch -from rich.console import Console -from rich.live import Live -from rich.markdown import Markdown -from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer - -from trl import TrlParser, init_zero_verbose -from trl.scripts.utils import ChatArguments -from trl.trainer.utils import get_quantization_config - - -init_zero_verbose() - -HELP_STRING = """\ - -**TRL CHAT INTERFACE** - -The chat interface is a simple tool to try out a chat model. - -Besides talking to the model there are several commands: -- **clear**: clears the current conversation and start a new one -- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input -- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';'). -- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set** -- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided -- **exit**: closes the interface -""" - -SUPPORTED_GENERATION_KWARGS = [ - "max_new_tokens", - "do_sample", - "num_beams", - "temperature", - "top_p", - "top_k", - "repetition_penalty", -] - -SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$" - - -class RichInterface: - def __init__(self, model_name=None, user_name=None): - self._console = Console() - if model_name is None: - self.model_name = "assistant" - else: - self.model_name = model_name - if user_name is None: - self.user_name = "user" - else: - self.user_name = user_name - - def stream_output(self, output_stream): - """Stream output from a role.""" - # This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py - # Create a Live context for updating the console output - text = "" - self._console.print(f"[bold blue]<{self.model_name}>:") - with Live(console=self._console, refresh_per_second=4) as live: - # Read lines from the stream - for i, outputs in enumerate(output_stream): - if not outputs or i == 0: - continue - text += outputs - # Render the accumulated text as Markdown - # NOTE: this is a workaround for the rendering "unstandard markdown" - # in rich. The chatbots output treat "\n" as a new line for - # better compatibility with real-world text. However, rendering - # in markdown would break the format. It is because standard markdown - # treat a single "\n" in normal text as a space. - # Our workaround is adding two spaces at the end of each line. - # This is not a perfect solution, as it would - # introduce trailing spaces (only) in code block, but it works well - # especially for console output, because in general the console does not - # care about trailing spaces. - lines = [] - for line in text.splitlines(): - lines.append(line) - if line.startswith("```"): - # Code block marker - do not add trailing spaces, as it would - # break the syntax highlighting - lines.append("\n") - else: - lines.append(" \n") - markdown = Markdown("".join(lines).strip(), code_theme="github-dark") - # Update the Live console output - live.update(markdown) - self._console.print() - return text - - def input(self): - input = self._console.input(f"[bold red]<{self.user_name}>:\n") - self._console.print() - return input - - def clear(self): - self._console.clear() - - def print_user_message(self, text): - self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}") - self._console.print() - - def print_green(self, text): - self._console.print(f"[bold green]{text}") - self._console.print() - - def print_red(self, text): - self._console.print(f"[bold red]{text}") - self._console.print() - - def print_help(self): - self._console.print(Markdown(HELP_STRING)) - self._console.print() - - -def get_username(): - return pwd.getpwuid(os.getuid())[0] - - -def create_default_filename(model_name): - time_str = time.strftime("%Y-%m-%d_%H-%M-%S") - return f"{model_name}/chat_{time_str}.json" - - -def save_chat(chat, args, filename): - output_dict = {} - output_dict["settings"] = vars(args) - output_dict["chat_history"] = chat - - folder = args.save_folder - - if filename is None: - filename = create_default_filename(args.model_name_or_path) - filename = os.path.join(folder, filename) - os.makedirs(os.path.dirname(filename), exist_ok=True) - - with open(filename, "w") as f: - json.dump(output_dict, f, indent=4) - return os.path.abspath(filename) - - -def clear_chat_history(system_prompt): - if system_prompt is None: - chat = [] - else: - chat = [{"role": "system", "content": system_prompt}] - return chat - - -def parse_settings(user_input, current_args, interface): - settings = user_input[4:].strip().split(";") - settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings] - settings = dict(settings) - error = False - - for name in settings: - if hasattr(current_args, name): - try: - if isinstance(getattr(current_args, name), bool): - if settings[name] == "True": - settings[name] = True - elif settings[name] == "False": - settings[name] = False - else: - raise ValueError - else: - settings[name] = type(getattr(current_args, name))(settings[name]) - except ValueError: - interface.print_red( - f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}." - ) - else: - interface.print_red(f"There is no '{name}' setting.") - - if error: - interface.print_red("There was an issue parsing the settings. No settings have been changed.") - return current_args, False - else: - for name in settings: - setattr(current_args, name, settings[name]) - interface.print_green(f"Set {name} to {settings[name]}.") - - time.sleep(1.5) # so the user has time to read the changes - return current_args, True - - -def load_model_and_tokenizer(args): - tokenizer = AutoTokenizer.from_pretrained( - args.model_name_or_path, - revision=args.model_revision, - trust_remote_code=args.trust_remote_code, - ) - - torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype) - quantization_config = get_quantization_config(args) - model_kwargs = dict( - revision=args.model_revision, - attn_implementation=args.attn_implementation, - torch_dtype=torch_dtype, - device_map="auto", - quantization_config=quantization_config, - ) - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs - ) - - if getattr(model, "hf_device_map", None) is None: - model = model.to(args.device) - - return model, tokenizer - - -def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids): - if tokenizer.pad_token_id is None: - pad_token_id = tokenizer.eos_token_id - else: - pad_token_id = tokenizer.pad_token_id - - all_eos_token_ids = [] - - if eos_tokens is not None: - all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(","))) - - if eos_token_ids is not None: - all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")]) - - if len(all_eos_token_ids) == 0: - all_eos_token_ids.append(tokenizer.eos_token_id) - - return pad_token_id, all_eos_token_ids - - -def chat_cli(): - parser = TrlParser(ChatArguments) - - if "--config" not in sys.argv: - sys.argv.append("--config") - sys.argv.append(os.path.join(os.path.dirname(__file__), "config/default_chat_config.yaml")) - args = parser.parse_args_and_config()[0] - if args.examples is None: - args.examples = {} - - current_args = copy.deepcopy(args) - - if args.user is None: - user = get_username() - else: - user = args.user - - model, tokenizer = load_model_and_tokenizer(args) - generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) - - pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids) - - interface = RichInterface(model_name=args.model_name_or_path, user_name=user) - interface.clear() - chat = clear_chat_history(current_args.system_prompt) - while True: - try: - user_input = interface.input() - - if user_input == "clear": - chat = clear_chat_history(current_args.system_prompt) - interface.clear() - continue - - if user_input == "help": - interface.print_help() - continue - - if user_input == "exit": - break - - if user_input == "reset": - interface.clear() - current_args = copy.deepcopy(args) - chat = clear_chat_history(current_args.system_prompt) - continue - - if user_input.startswith("save") and len(user_input.split()) < 2: - split_input = user_input.split() - - if len(split_input) == 2: - filename = split_input[1] - else: - filename = None - filename = save_chat(chat, current_args, filename) - interface.print_green(f"Chat saved in {filename}!") - continue - - if re.match(SETTING_RE, user_input): - current_args, success = parse_settings(user_input, current_args, interface) - if success: - chat = [] - interface.clear() - continue - - if user_input.startswith("example") and len(user_input.split()) == 2: - example_name = user_input.split()[1] - if example_name in current_args.examples: - interface.clear() - chat = [] - interface.print_user_message(current_args.examples[example_name]["text"]) - user_input = current_args.examples[example_name]["text"] - else: - interface.print_red( - f"Example {example_name} not found in list of available examples: {list(current_args.examples.keys())}." - ) - continue - - chat.append({"role": "user", "content": user_input}) - - inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( - model.device - ) - attention_mask = torch.ones_like(inputs) - generation_kwargs = dict( - inputs=inputs, - attention_mask=attention_mask, - streamer=generation_streamer, - max_new_tokens=current_args.max_new_tokens, - do_sample=current_args.do_sample, - num_beams=current_args.num_beams, - temperature=current_args.temperature, - top_k=current_args.top_k, - top_p=current_args.top_p, - repetition_penalty=current_args.repetition_penalty, - pad_token_id=pad_token_id, - eos_token_id=eos_token_ids, - ) - - thread = Thread(target=model.generate, kwargs=generation_kwargs) - thread.start() - model_output = interface.stream_output(generation_streamer) - thread.join() - chat.append({"role": "assistant", "content": model_output}) - - except KeyboardInterrupt: - break - - -if __name__ == "__main__": - chat_cli() +################################################################################################ +# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/chat.py # +################################################################################################ \ No newline at end of file diff --git a/trl/cli.py b/trl/cli.py index 19d8783ab1..a2283d3e88 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse import os import sys from accelerate.commands.launch import launch_command, launch_command_parser +from .scripts.chat import main as chat_main +from .scripts.chat import make_parser as make_chat_parser from .scripts.dpo import make_parser as make_dpo_parser from .scripts.env import print_env from .scripts.sft import make_parser as make_sft_parser @@ -25,12 +26,13 @@ def main(): - parser = argparse.ArgumentParser(prog="TRL CLI", usage="trl", allow_abbrev=False) + parser = TrlParser(prog="TRL CLI", usage="trl", allow_abbrev=False) # Add the subparsers subparsers = parser.add_subparsers(help="available commands", dest="command", parser_class=TrlParser) # Add the subparsers for every script + make_chat_parser(subparsers) make_dpo_parser(subparsers) subparsers.add_parser("env", help="Print the environment information") make_sft_parser(subparsers) @@ -38,6 +40,10 @@ def main(): # Parse the arguments args = parser.parse_args() + if args.command == "chat": + (chat_args,) = parser.parse_args_and_config() + chat_main(chat_args) + if args.command == "dpo": # Get the default args for the launch command dpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "dpo.py") diff --git a/trl/scripts/chat.py b/trl/scripts/chat.py index f6118d8409..d1746441ea 100644 --- a/trl/scripts/chat.py +++ b/trl/scripts/chat.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,37 +13,442 @@ # limitations under the License. +import argparse +import copy +import json +import os +import pwd +import re import time +from dataclasses import dataclass, field +from threading import Thread +import torch +import yaml from rich.console import Console +from rich.live import Live +from rich.markdown import Markdown +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer +from trl import TrlParser, init_zero_verbose +from trl.trainer.utils import get_quantization_config -def chat(): - console = Console() - # Make sure to import things locally to avoid verbose from third party libs. - with console.status("[bold purple]Welcome! Initializing the TRL CLI..."): - time.sleep(1) - ... - # from .utils import init_zero_verbose - # init_zero_verbose() - # trl_examples_dir = os.path.dirname(__file__) +init_zero_verbose() - # command = f"python {trl_examples_dir}/scripts/chat.py {' '.join(sys.argv[2:])}" +HELP_STRING = """\ - # try: - # subprocess.run( - # command.split(), - # text=True, - # check=True, - # encoding="utf-8", - # cwd=os.getcwd(), - # env=os.environ.copy(), - # ) - # except (CalledProcessError, ChildProcessError) as exc: - # console.log("TRL - CHAT failed! See the logs above for further details.") - # raise ValueError("TRL CLI failed! Check the traceback above..") from exc +**TRL CHAT INTERFACE** + +The chat interface is a simple tool to try out a chat model. + +Besides talking to the model there are several commands: +- **clear**: clears the current conversation and start a new one +- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input +- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';'). +- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set** +- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided +- **exit**: closes the interface +""" + +SUPPORTED_GENERATION_KWARGS = [ + "max_new_tokens", + "do_sample", + "num_beams", + "temperature", + "top_p", + "top_k", + "repetition_penalty", +] + +SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$" + + +DEFAULT_EXAMPLES = { + "llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"}, + "code": { + "text": "Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end]." + }, + "helicopter": {"text": "How many helicopters can a human eat in one sitting?"}, + "numbers": {"text": "Count to 10 but skip every number ending with an 'e'"}, + "birds": {"text": "Why aren't birds real?"}, + "socks": {"text": "Why is it important to eat socks after meditating?"}, +} + + +@dataclass +class ChatArguments: + # general settings + model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model"}) + user: str = field(default=None, metadata={"help": "Username to display in chat interface"}) + system_prompt: str = field(default=None, metadata={"help": "System prompt"}) + save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history"}) + device: str = field( + default="cpu", + metadata={"help": "device to use for inference."}, + ) + examples_path: str = field(default=None, metadata={"help": "Path to a yaml file with examples"}) + # generation settings + max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate"}) + do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation"}) + num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search"}) + temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation"}) + top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling"}) + top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling"}) + repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty"}) + eos_tokens: str = field( + default=None, + metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated"}, + ) + eos_token_ids: str = field( + default=None, + metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated"}, + ) + # model loading + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + torch_dtype: str = field( + default=None, + metadata={ + "help": ( + "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " + "dtype will be automatically derived from the model's weights." + ), + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) + trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) + attn_implementation: str = field( + default=None, + metadata={ + "help": ( + "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`" + ) + }, + ) + load_in_8bit: bool = field( + default=False, + metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}, + ) + load_in_4bit: bool = field( + default=False, + metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}, + ) + + bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}) + use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) + + +class RichInterface: + def __init__(self, model_name=None, user_name=None): + self._console = Console() + if model_name is None: + self.model_name = "assistant" + else: + self.model_name = model_name + if user_name is None: + self.user_name = "user" + else: + self.user_name = user_name + + def stream_output(self, output_stream): + """Stream output from a role.""" + # This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py + # Create a Live context for updating the console output + text = "" + self._console.print(f"[bold blue]<{self.model_name}>:") + with Live(console=self._console, refresh_per_second=4) as live: + # Read lines from the stream + for i, outputs in enumerate(output_stream): + if not outputs or i == 0: + continue + text += outputs + # Render the accumulated text as Markdown + # NOTE: this is a workaround for the rendering "unstandard markdown" + # in rich. The chatbots output treat "\n" as a new line for + # better compatibility with real-world text. However, rendering + # in markdown would break the format. It is because standard markdown + # treat a single "\n" in normal text as a space. + # Our workaround is adding two spaces at the end of each line. + # This is not a perfect solution, as it would + # introduce trailing spaces (only) in code block, but it works well + # especially for console output, because in general the console does not + # care about trailing spaces. + lines = [] + for line in text.splitlines(): + lines.append(line) + if line.startswith("```"): + # Code block marker - do not add trailing spaces, as it would + # break the syntax highlighting + lines.append("\n") + else: + lines.append(" \n") + markdown = Markdown("".join(lines).strip(), code_theme="github-dark") + # Update the Live console output + live.update(markdown) + self._console.print() + return text + + def input(self): + input = self._console.input(f"[bold red]<{self.user_name}>:\n") + self._console.print() + return input + + def clear(self): + self._console.clear() + + def print_user_message(self, text): + self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}") + self._console.print() + + def print_green(self, text): + self._console.print(f"[bold green]{text}") + self._console.print() + + def print_red(self, text): + self._console.print(f"[bold red]{text}") + self._console.print() + + def print_help(self): + self._console.print(Markdown(HELP_STRING)) + self._console.print() + + +def get_username(): + return pwd.getpwuid(os.getuid())[0] + + +def create_default_filename(model_name): + time_str = time.strftime("%Y-%m-%d_%H-%M-%S") + return f"{model_name}/chat_{time_str}.json" + + +def save_chat(chat, args, filename): + output_dict = {} + output_dict["settings"] = vars(args) + output_dict["chat_history"] = chat + + folder = args.save_folder + + if filename is None: + filename = create_default_filename(args.model_name_or_path) + filename = os.path.join(folder, filename) + os.makedirs(os.path.dirname(filename), exist_ok=True) + + with open(filename, "w") as f: + json.dump(output_dict, f, indent=4) + return os.path.abspath(filename) + + +def clear_chat_history(system_prompt): + if system_prompt is None: + chat = [] + else: + chat = [{"role": "system", "content": system_prompt}] + return chat + + +def parse_settings(user_input, current_args, interface): + settings = user_input[4:].strip().split(";") + settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings] + settings = dict(settings) + error = False + + for name in settings: + if hasattr(current_args, name): + try: + if isinstance(getattr(current_args, name), bool): + if settings[name] == "True": + settings[name] = True + elif settings[name] == "False": + settings[name] = False + else: + raise ValueError + else: + settings[name] = type(getattr(current_args, name))(settings[name]) + except ValueError: + interface.print_red( + f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}." + ) + else: + interface.print_red(f"There is no '{name}' setting.") + + if error: + interface.print_red("There was an issue parsing the settings. No settings have been changed.") + return current_args, False + else: + for name in settings: + setattr(current_args, name, settings[name]) + interface.print_green(f"Set {name} to {settings[name]}.") + + time.sleep(1.5) # so the user has time to read the changes + return current_args, True + + +def load_model_and_tokenizer(args): + tokenizer = AutoTokenizer.from_pretrained( + args.model_name_or_path, + revision=args.model_revision, + trust_remote_code=args.trust_remote_code, + ) + + torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype) + quantization_config = get_quantization_config(args) + model_kwargs = dict( + revision=args.model_revision, + attn_implementation=args.attn_implementation, + torch_dtype=torch_dtype, + device_map="auto", + quantization_config=quantization_config, + ) + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs + ) + + if getattr(model, "hf_device_map", None) is None: + model = model.to(args.device) + + return model, tokenizer + + +def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids): + if tokenizer.pad_token_id is None: + pad_token_id = tokenizer.eos_token_id + else: + pad_token_id = tokenizer.pad_token_id + + all_eos_token_ids = [] + + if eos_tokens is not None: + all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(","))) + + if eos_token_ids is not None: + all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")]) + + if len(all_eos_token_ids) == 0: + all_eos_token_ids.append(tokenizer.eos_token_id) + + return pad_token_id, all_eos_token_ids + + +def main(args: ChatArguments): + if args.examples_path is None: + examples = DEFAULT_EXAMPLES + else: + with open(args.examples_path) as f: + examples = yaml.safe_load(f) + + current_args = copy.deepcopy(args) + + if args.user is None: + user = get_username() + else: + user = args.user + + model, tokenizer = load_model_and_tokenizer(args) + generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) + + pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids) + + interface = RichInterface(model_name=args.model_name_or_path, user_name=user) + interface.clear() + chat = clear_chat_history(current_args.system_prompt) + while True: + try: + user_input = interface.input() + + if user_input == "clear": + chat = clear_chat_history(current_args.system_prompt) + interface.clear() + continue + + if user_input == "help": + interface.print_help() + continue + + if user_input == "exit": + break + + if user_input == "reset": + interface.clear() + current_args = copy.deepcopy(args) + chat = clear_chat_history(current_args.system_prompt) + continue + + if user_input.startswith("save") and len(user_input.split()) < 2: + split_input = user_input.split() + + if len(split_input) == 2: + filename = split_input[1] + else: + filename = None + filename = save_chat(chat, current_args, filename) + interface.print_green(f"Chat saved in {filename}!") + continue + + if re.match(SETTING_RE, user_input): + current_args, success = parse_settings(user_input, current_args, interface) + if success: + chat = [] + interface.clear() + continue + + if user_input.startswith("example") and len(user_input.split()) == 2: + example_name = user_input.split()[1] + if example_name in examples: + interface.clear() + chat = [] + interface.print_user_message(examples[example_name]["text"]) + user_input = examples[example_name]["text"] + else: + interface.print_red( + f"Example {example_name} not found in list of available examples: {list(examples.keys())}." + ) + continue + + chat.append({"role": "user", "content": user_input}) + + inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( + model.device + ) + attention_mask = torch.ones_like(inputs) + generation_kwargs = dict( + inputs=inputs, + attention_mask=attention_mask, + streamer=generation_streamer, + max_new_tokens=current_args.max_new_tokens, + do_sample=current_args.do_sample, + num_beams=current_args.num_beams, + temperature=current_args.temperature, + top_k=current_args.top_k, + top_p=current_args.top_p, + repetition_penalty=current_args.repetition_penalty, + pad_token_id=pad_token_id, + eos_token_id=eos_token_ids, + ) + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + model_output = interface.stream_output(generation_streamer) + thread.join() + chat.append({"role": "assistant", "content": model_output}) + + except KeyboardInterrupt: + break + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (ChatArguments,) + if subparsers is not None: + parser = subparsers.add_parser("chat", help=HELP_STRING, dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser if __name__ == "__main__": - chat() + parser = make_parser() + (chat_args,) = parser.parse_args_and_config() + main(chat_args) diff --git a/trl/scripts/utils.py b/trl/scripts/utils.py index 82a01867a2..e7de20ef12 100644 --- a/trl/scripts/utils.py +++ b/trl/scripts/utils.py @@ -20,7 +20,7 @@ import subprocess import sys import warnings -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Iterable, Optional, Union import yaml @@ -120,71 +120,6 @@ def warning_handler(message, category, filename, lineno, file=None, line=None): warnings.showwarning = warning_handler -@dataclass -class ChatArguments: - # general settings - model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model"}) - user: str = field(default=None, metadata={"help": "Username to display in chat interface"}) - system_prompt: str = field(default=None, metadata={"help": "System prompt"}) - save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history"}) - device: str = field( - default="cpu", - metadata={"help": "device to use for inference."}, - ) - examples: str = field(default=None, metadata={"help": "Empty placeholder needs to be set via config."}) - # generation settings - max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate"}) - do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation"}) - num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search"}) - temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation"}) - top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling"}) - top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling"}) - repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty"}) - eos_tokens: str = field( - default=None, - metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated"}, - ) - eos_token_ids: str = field( - default=None, - metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated"}, - ) - # model loading - model_revision: str = field( - default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - torch_dtype: str = field( - default=None, - metadata={ - "help": ( - "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " - "dtype will be automatically derived from the model's weights." - ), - "choices": ["auto", "bfloat16", "float16", "float32"], - }, - ) - trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) - attn_implementation: str = field( - default=None, - metadata={ - "help": ( - "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`" - ) - }, - ) - load_in_8bit: bool = field( - default=False, - metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}, - ) - load_in_4bit: bool = field( - default=False, - metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}, - ) - - bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}) - use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) - - class TrlParser(HfArgumentParser): """ A subclass of [`transformers.HfArgumentParser`] designed for parsing command-line arguments with dataclass-backed From d4685451811644a7de352ab794a5ee861837e4fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 18:17:02 +0000 Subject: [PATCH 27/42] style [ci skip] --- examples/scripts/chat.py | 2 +- examples/scripts/sft.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/scripts/chat.py b/examples/scripts/chat.py index 4340425038..db7efd1753 100644 --- a/examples/scripts/chat.py +++ b/examples/scripts/chat.py @@ -15,4 +15,4 @@ ################################################################################################ # This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/chat.py # -################################################################################################ \ No newline at end of file +################################################################################################ diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index 61019a0c19..54fc8632ab 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -14,4 +14,4 @@ ############################################################################################### # This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py # -############################################################################################### \ No newline at end of file +############################################################################################### From 93d423c097fe9e5bd8b4e7d286cacb742b8f5641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 18:29:09 +0000 Subject: [PATCH 28/42] kto --- tests/test_cli.py | 15 ++++-- trl/cli.py | 11 ++++ trl/scripts/kto.py | 128 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 4 deletions(-) create mode 100644 trl/scripts/kto.py diff --git a/tests/test_cli.py b/tests/test_cli.py index 1646348357..d2bd5f8e7a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import tempfile import unittest from io import StringIO @@ -20,6 +21,12 @@ class TestCLI(unittest.TestCase): + def test_dpo(self): + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + command = f"trl dpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config_name standard_preference --report_to none" + with patch("sys.argv", command.split(" ")): + main() + @patch("sys.stdout", new_callable=StringIO) def test_env(self, mock_stdout): command = "trl env" @@ -27,15 +34,15 @@ def test_env(self, mock_stdout): main() self.assertIn("TRL version: ", mock_stdout.getvalue().strip()) - def test_sft(self): + def test_kto(self): with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory - command = f"trl sft --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config_name standard_language_modeling --report_to none" + command = f"trl kto --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config_name standard_unpaired_preference --report_to none" with patch("sys.argv", command.split(" ")): main() - def test_dpo(self): + def test_sft(self): with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory - command = f"trl dpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config_name standard_preference --report_to none" + command = f"trl sft --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config_name standard_language_modeling --report_to none" with patch("sys.argv", command.split(" ")): main() diff --git a/trl/cli.py b/trl/cli.py index a2283d3e88..1e02497927 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -21,6 +21,7 @@ from .scripts.chat import make_parser as make_chat_parser from .scripts.dpo import make_parser as make_dpo_parser from .scripts.env import print_env +from .scripts.kto import make_parser as make_kto_parser from .scripts.sft import make_parser as make_sft_parser from .scripts.utils import TrlParser @@ -35,6 +36,7 @@ def main(): make_chat_parser(subparsers) make_dpo_parser(subparsers) subparsers.add_parser("env", help="Print the environment information") + make_kto_parser(subparsers) make_sft_parser(subparsers) # Parse the arguments @@ -56,6 +58,15 @@ def main(): elif args.command == "env": print_env() + elif args.command == "kto": + # Get the default args for the launch command + kto_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "kto.py") + args = launch_command_parser().parse_args([kto_training_script]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "kto" + launch_command(args) # launch training + elif args.command == "sft": # Get the default args for the launch command sft_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "sft.py") diff --git a/trl/scripts/kto.py b/trl/scripts/kto.py new file mode 100644 index 0000000000..7a2c60551b --- /dev/null +++ b/trl/scripts/kto.py @@ -0,0 +1,128 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO. + +# Full training: +python examples/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ + --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ + --per_device_train_batch_size 16 \ + --num_train_epochs 1 \ + --learning_rate 5e-7 \ + --lr_scheduler_type=cosine \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir=kto-aligned-model \ + --warmup_ratio 0.1 \ + --report_to wandb \ + --bf16 \ + --logging_first_step + +# QLoRA: +python examples/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ + --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --learning_rate 5e-7 \ + --lr_scheduler_type=cosine \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir=kto-aligned-model-lora \ + --warmup_ratio 0.1 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --use_peft \ + --load_in_4bit \ + --lora_target_modules=all-linear \ + --lora_r=16 \ + --lora_alpha=16 +""" + +import argparse + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl import ( + KTOConfig, + KTOTrainer, + ModelConfig, + ScriptArguments, + TrlParser, + get_peft_config, + setup_chat_format, +) + + +def main(script_args, training_args, model_args): + # Load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # If we are aligning a base model, we use ChatML as the default template + if tokenizer.chat_template is None: + model, tokenizer = setup_chat_format(model, tokenizer) + + # Load the dataset + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + + # Initialize the KTO trainer + trainer = KTOTrainer( + model, + ref_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + # Train and push the model to the Hub + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (ScriptArguments, KTOConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + script_args, training_args, model_config = parser.parse_args_and_config() + main(script_args, training_args, model_config) From 9ee485a9bc3d1016eed5a5ab513bacb3f7486e5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 18:29:19 +0000 Subject: [PATCH 29/42] rm example config --- examples/scripts/config/default_chat_config.yaml | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 examples/scripts/config/default_chat_config.yaml diff --git a/examples/scripts/config/default_chat_config.yaml b/examples/scripts/config/default_chat_config.yaml deleted file mode 100644 index 93195f9d7d..0000000000 --- a/examples/scripts/config/default_chat_config.yaml +++ /dev/null @@ -1,13 +0,0 @@ -examples: - llama: - text: There is a Llama in my lawn, how can I get rid of it? - code: - text: Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end]. - helicopter: - text: How many helicopters can a human eat in one sitting? - numbers: - text: Count to 10 but skip every number ending with an 'e' - birds: - text: Why aren't birds real? - socks: - text: Why is it important to eat socks after meditating? \ No newline at end of file From 5f86e61b6e0de88100b778ad58289ab7d948d227 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 18:31:35 +0000 Subject: [PATCH 30/42] first step on doc --- docs/source/clis.mdx | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/source/clis.mdx b/docs/source/clis.mdx index 0e600e5d99..016b7bc647 100644 --- a/docs/source/clis.mdx +++ b/docs/source/clis.mdx @@ -4,8 +4,14 @@ You can use TRL to fine-tune your Language Model with Supervised Fine-Tuning (SF Currently supported CLIs are: -- `trl sft`: fine-tune a LLM on a text/instruction dataset -- `trl dpo`: fine-tune a LLM with DPO on a preference dataset +#### Training commands + +- `trl dpo`: fine-tune a LLM with DPO +- `trl kto`: fine-tune a LLM with KTO +- `trl sft`: fine-tune a LLM with SFT + +#### Other commands + - `trl chat`: quickly spin up a LLM fine-tuned for chatting - `trl env`: get the system information From 779062b8ebc6f593144e7b1c0a8b2a81c3514c4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 18:59:24 +0000 Subject: [PATCH 31/42] see #2442 --- docs/source/sft_trainer.mdx | 22 ++++++++--------- examples/scripts/cpo.py | 8 +++---- examples/scripts/dpo_online.py | 24 +++++++++---------- examples/scripts/dpo_vlm.py | 26 +++++++++------------ examples/scripts/gkd.py | 30 ++++++++++++------------ examples/scripts/nash_md.py | 22 +++++++---------- examples/scripts/orpo.py | 8 +++---- examples/scripts/ppo/ppo.py | 26 +++++++++------------ examples/scripts/ppo/ppo_tldr.py | 26 +++++++++------------ examples/scripts/reward_modeling.py | 18 +++++++------- examples/scripts/rloo/rloo.py | 12 ++++------ examples/scripts/rloo/rloo_tldr.py | 12 ++++------ examples/scripts/sft_video_llm.py | 14 +++++------ examples/scripts/sft_vlm.py | 18 +++++++------- examples/scripts/sft_vlm_smol_vlm.py | 18 +++++++------- examples/scripts/xpo.py | 22 +++++++---------- tests/test_utils.py | 8 +++---- trl/scripts/dpo.py | 24 +++++++++---------- trl/scripts/kto.py | 4 ++-- trl/scripts/sft.py | 22 ++++++++--------- trl/trainer/utils.py | 35 +++++++++++++++------------- 21 files changed, 183 insertions(+), 216 deletions(-) diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx index 5b7827fe26..c45069d18c 100644 --- a/docs/source/sft_trainer.mdx +++ b/docs/source/sft_trainer.mdx @@ -468,30 +468,30 @@ We included a utility function to create your model. ```python from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config -model_config = ModelConfig( +model_args = ModelConfig( model_name_or_path="facebook/opt-350m" attn_implementation=None, # or "flash_attention_2" ) torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) ) -quantization_config = get_quantization_config(model_config) +quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) -model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs) +model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) trainer = SFTTrainer( ..., - model=model_config.model_name_or_path, - peft_config=get_peft_config(model_config), + model=model_args.model_name_or_path, + peft_config=get_peft_config(model_args), ) ``` diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index 4460239d99..1eb9c1cc65 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -63,16 +63,16 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() ################ # Model & Tokenizer ################ model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token @@ -93,7 +93,7 @@ train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) # train and save the model diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index e280148521..729598aeb6 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -64,18 +64,16 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, @@ -83,19 +81,19 @@ ) model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) if training_args.reward_model_path is not None: reward_model = AutoModelForSequenceClassification.from_pretrained( training_args.reward_model_path, num_labels=1, - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) reward_tokenizer = AutoTokenizer.from_pretrained( training_args.reward_model_path, - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, truncation=True, truncation_side="left", # since we judge the completion, truncating left is more appropriate ) @@ -110,9 +108,9 @@ judge = None tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, + model_args.model_name_or_path, padding_side="left", - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) if tokenizer.chat_template is None: @@ -131,7 +129,7 @@ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, reward_processing_class=reward_tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) if training_args.eval_strategy != "no": diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index 8d9db3527a..395c8bf5d5 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -44,43 +44,39 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() ################ # Model & Tokenizer ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) model = AutoModelForVision2Seq.from_pretrained( - model_config.model_name_or_path, - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) - peft_config = get_peft_config(model_config) + peft_config = get_peft_config(model_args) if peft_config is None: ref_model = AutoModelForVision2Seq.from_pretrained( - model_config.model_name_or_path, - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) else: ref_model = None processor = AutoProcessor.from_pretrained( - model_config.model_name_or_path, - trust_remote_code=model_config.trust_remote_code, - do_image_splitting=False, + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, do_image_splitting=False ) tokenizer = processor.tokenizer diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index b397d33250..84ac02228a 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -63,17 +63,17 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() ################ # Model & Tokenizer ################ - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, - attn_implementation=model_config.attn_implementation, - torch_dtype=model_config.torch_dtype, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, @@ -81,10 +81,10 @@ training_args.model_init_kwargs = model_kwargs teacher_model_kwargs = dict( - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, - attn_implementation=model_config.attn_implementation, - torch_dtype=model_config.torch_dtype, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.torch_dtype, use_cache=True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, @@ -92,9 +92,9 @@ training_args.teacher_model_init_kwargs = teacher_model_kwargs tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, padding_side="left", ) if tokenizer.pad_token is None: @@ -117,13 +117,13 @@ # Training ################ trainer = GKDTrainer( - model=model_config.model_name_or_path, + model=model_args.model_name_or_path, teacher_model=training_args.teacher_model_name_or_path, args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) if training_args.eval_strategy != "no": diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index 680fafce6c..4d0c0f29f0 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -69,18 +69,16 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, NashMDConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, @@ -88,17 +86,17 @@ ) model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) ref_model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) if training_args.reward_model_path is not None: reward_model = AutoModelForSequenceClassification.from_pretrained( training_args.reward_model_path, num_labels=1, - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) else: @@ -111,9 +109,7 @@ judge = None tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, - padding_side="left", - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index a984842d42..19dbf12e51 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -63,16 +63,16 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() ################ # Model & Tokenizer ################ model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token @@ -93,7 +93,7 @@ train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) # train and save the model diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 8508bb4909..8addc0f979 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -69,7 +69,7 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -77,41 +77,37 @@ # Model & Tokenizer ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, - padding_side="left", - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE value_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) reward_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) - peft_config = get_peft_config(model_config) + peft_config = get_peft_config(model_args) if peft_config is None: ref_policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) else: ref_policy = None diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 96b08c2f31..d7864d3d01 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -76,7 +76,7 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -84,41 +84,37 @@ # Model & Tokenizer ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, - padding_side="left", - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE value_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) reward_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) - peft_config = get_peft_config(model_config) + peft_config = get_peft_config(model_args) if peft_config is None: ref_policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) else: ref_policy = None diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 34165c323c..81658685fe 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -64,30 +64,28 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, RewardConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) ################ # Model & Tokenizer ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, + revision=model_args.model_revision, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, use_cache=False if training_args.gradient_checkpointing else True, torch_dtype=torch_dtype, ) tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True ) model = AutoModelForSequenceClassification.from_pretrained( - model_config.model_name_or_path, num_labels=1, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, num_labels=1, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) # Align padding tokens between tokenizer and model model.config.pad_token_id = tokenizer.pad_token_id @@ -96,7 +94,7 @@ if tokenizer.chat_template is None: model, tokenizer = setup_chat_format(model, tokenizer) - if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS": + if model_args.use_peft and model_args.lora_task_type != "SEQ_CLS": warnings.warn( "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs" " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.", @@ -117,7 +115,7 @@ args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) trainer.train() diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index 88491006d0..5eafb0de50 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -63,7 +63,7 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, RLOOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -71,21 +71,19 @@ # Model & Tokenizer ################ tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, - padding_side="left", - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE reward_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) ref_policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) ################ # Dataset diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 4a7c17ce86..4bda783690 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -65,7 +65,7 @@ if __name__ == "__main__": parser = HfArgumentParser((ScriptArguments, RLOOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_into_dataclasses() + script_args, training_args, model_args = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -73,21 +73,19 @@ # Model & Tokenizer ################ tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, - padding_side="left", - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE reward_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 ) ref_policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) policy = AutoModelForCausalLM.from_pretrained( - training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code ) ################ # Dataset diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py index be2706f421..a3381f94fc 100644 --- a/examples/scripts/sft_video_llm.py +++ b/examples/scripts/sft_video_llm.py @@ -168,7 +168,7 @@ class CustomScriptArguments(SFTScriptArguments): if __name__ == "__main__": # Parse arguments parser = TrlParser((CustomScriptArguments, SFTConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() # Configure training args training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) @@ -180,9 +180,7 @@ class CustomScriptArguments(SFTScriptArguments): # Setup model torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) # Quantization configuration for 4-bit training @@ -195,14 +193,14 @@ class CustomScriptArguments(SFTScriptArguments): # Model initialization model_kwargs = dict( - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, torch_dtype=torch_dtype, device_map=get_kbit_device_map(), quantization_config=bnb_config, ) - model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs) + model = AutoModelForVision2Seq.from_pretrained(model_args.model_name_or_path, **model_kwargs) peft_config = LoraConfig( task_type="CAUSAL_LM", @@ -220,7 +218,7 @@ class CustomScriptArguments(SFTScriptArguments): model.enable_input_require_grads() processor = AutoProcessor.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) # Prepare dataset diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index dd2934ca32..f7b9befc9e 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -52,7 +52,7 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) training_args.remove_unused_columns = False training_args.dataset_kwargs = {"skip_prepare_dataset": True} @@ -61,24 +61,22 @@ # Model, Tokenizer & Processor ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) processor = AutoProcessor.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) model = AutoModelForVision2Seq.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) ################ @@ -120,7 +118,7 @@ def collate_fn(examples): train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=processor.tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) trainer.train() diff --git a/examples/scripts/sft_vlm_smol_vlm.py b/examples/scripts/sft_vlm_smol_vlm.py index 7d08efea24..a88e7f552d 100644 --- a/examples/scripts/sft_vlm_smol_vlm.py +++ b/examples/scripts/sft_vlm_smol_vlm.py @@ -59,7 +59,7 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) training_args.remove_unused_columns = False training_args.dataset_kwargs = {"skip_prepare_dataset": True} @@ -68,24 +68,22 @@ # Model, Tokenizer & Processor ################ torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) processor = AutoProcessor.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) model = AutoModelForVision2Seq.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) ################ @@ -132,7 +130,7 @@ def collate_fn(examples): train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=processor.tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) trainer.train() diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index 6d59e01027..c3d8816546 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -54,18 +54,16 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, XPOConfig, ModelConfig)) - script_args, training_args, model_config = parser.parse_args_and_config() + script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, @@ -73,17 +71,17 @@ ) model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) ref_model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) if training_args.reward_model_path is not None: reward_model = AutoModelForSequenceClassification.from_pretrained( training_args.reward_model_path, num_labels=1, - trust_remote_code=model_config.trust_remote_code, + trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) else: @@ -96,9 +94,7 @@ judge = None tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, - padding_side="left", - trust_remote_code=model_config.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/tests/test_utils.py b/tests/test_utils.py index 4d26819058..d95cc7f2d6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -91,8 +91,8 @@ def test_pad_2_dim_right_multidim(self): class TestGetPEFTConfig(unittest.TestCase): def test_create_peft_config_use_peft_false(self): """Test that when use_peft is False, the function returns None.""" - model_config = ModelConfig(use_peft=False) - peft_config = get_peft_config(model_config) + model_args = ModelConfig(use_peft=False) + peft_config = get_peft_config(model_args) self.assertIsNone(peft_config) def test_create_peft_config_use_peft_true(self): @@ -107,8 +107,8 @@ def test_create_peft_config_use_peft_true(self): "lora_target_modules": ["up_proj", "down_proj"], "lora_modules_to_save": ["up_proj"], } - model_config = ModelConfig(use_peft=True, **peft_kwargs) - peft_config = get_peft_config(model_config) + model_args = ModelConfig(use_peft=True, **peft_kwargs) + peft_config = get_peft_config(model_args) self.assertTrue(isinstance(peft_config, LoraConfig)) for arg, value in peft_kwargs.items(): # Test that lists of modules are converted to sets diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index 20a6ef3f98..382492b5b6 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -65,36 +65,34 @@ from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -def main(script_args, training_args, model_config): +def main(script_args, training_args, model_args): ################ # Model & Tokenizer ################### torch_dtype = ( - model_config.torch_dtype - if model_config.torch_dtype in ["auto", None] - else getattr(torch, model_config.torch_dtype) + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - attn_implementation=model_config.attn_implementation, + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) - peft_config = get_peft_config(model_config) + peft_config = get_peft_config(model_args) if peft_config is None: ref_model = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) else: ref_model = None tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token @@ -148,5 +146,5 @@ def make_parser(subparsers: argparse._SubParsersAction = None): if __name__ == "__main__": parser = make_parser() - script_args, training_args, model_config = parser.parse_args_and_config() - main(script_args, training_args, model_config) + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/trl/scripts/kto.py b/trl/scripts/kto.py index 7a2c60551b..a8aaf3c848 100644 --- a/trl/scripts/kto.py +++ b/trl/scripts/kto.py @@ -124,5 +124,5 @@ def make_parser(subparsers: argparse._SubParsersAction = None): if __name__ == "__main__": parser = make_parser() - script_args, training_args, model_config = parser.parse_args_and_config() - main(script_args, training_args, model_config) + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index 3533121fe1..2ba8958081 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -65,23 +65,23 @@ ) -def main(script_args, training_args, model_config): +def main(script_args, training_args, model_args): ################ # Model init kwargs & Tokenizer ################ - quantization_config = get_quantization_config(model_config) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_config.model_revision, - trust_remote_code=model_config.trust_remote_code, - attn_implementation=model_config.attn_implementation, - torch_dtype=model_config.torch_dtype, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) training_args.model_init_kwargs = model_kwargs tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True ) tokenizer.pad_token = tokenizer.eos_token @@ -94,12 +94,12 @@ def main(script_args, training_args, model_config): # Training ################ trainer = SFTTrainer( - model=model_config.model_name_or_path, + model=model_args.model_name_or_path, args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, processing_class=tokenizer, - peft_config=get_peft_config(model_config), + peft_config=get_peft_config(model_args), ) trainer.train() @@ -121,5 +121,5 @@ def make_parser(subparsers: argparse._SubParsersAction = None): if __name__ == "__main__": parser = make_parser() - script_args, training_args, model_config = parser.parse_args_and_config() - main(script_args, training_args, model_config) + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index d1cc3a0e9d..02c8c2bee1 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -47,6 +47,7 @@ is_torch_npu_available, is_torch_xpu_available, ) +from transformers.utils.deprecation import deprecate_kwarg from ..import_utils import is_unsloth_available from ..trainer.model_config import ModelConfig @@ -870,16 +871,17 @@ def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None): return kwargs -def get_quantization_config(model_config: ModelConfig) -> Optional[BitsAndBytesConfig]: - if model_config.load_in_4bit: +@deprecate_kwarg("model_config", "0.14.0", "model_args", warn_if_greater_or_equal_version=True) +def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]: + if model_args.load_in_4bit: quantization_config = BitsAndBytesConfig( load_in_4bit=True, - bnb_4bit_compute_dtype=model_config.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype` - bnb_4bit_quant_type=model_config.bnb_4bit_quant_type, - bnb_4bit_use_double_quant=model_config.use_bnb_nested_quant, - bnb_4bit_quant_storage=model_config.torch_dtype, + bnb_4bit_compute_dtype=model_args.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype` + bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, + bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, + bnb_4bit_quant_storage=model_args.torch_dtype, ) - elif model_config.load_in_8bit: + elif model_args.load_in_8bit: quantization_config = BitsAndBytesConfig( load_in_8bit=True, ) @@ -898,8 +900,9 @@ def get_kbit_device_map() -> Optional[dict[str, int]]: return None -def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]": - if model_config.use_peft is False: +@deprecate_kwarg("model_config", "0.14.0", "model_args", warn_if_greater_or_equal_version=True) +def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]": + if model_args.use_peft is False: return None if not is_peft_available(): @@ -909,14 +912,14 @@ def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]": ) peft_config = LoraConfig( - task_type=model_config.lora_task_type, - r=model_config.lora_r, - target_modules=model_config.lora_target_modules, - lora_alpha=model_config.lora_alpha, - lora_dropout=model_config.lora_dropout, + task_type=model_args.lora_task_type, + r=model_args.lora_r, + target_modules=model_args.lora_target_modules, + lora_alpha=model_args.lora_alpha, + lora_dropout=model_args.lora_dropout, bias="none", - use_rslora=model_config.use_rslora, - modules_to_save=model_config.lora_modules_to_save, + use_rslora=model_args.use_rslora, + modules_to_save=model_args.lora_modules_to_save, ) return peft_config From 0892264f28d71d781f21a5be9bb1c698c56cc63c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 19:12:18 +0000 Subject: [PATCH 32/42] see #2443 --- trl/scripts/chat.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/trl/scripts/chat.py b/trl/scripts/chat.py index d1746441ea..2ede78e6cb 100644 --- a/trl/scripts/chat.py +++ b/trl/scripts/chat.py @@ -17,6 +17,7 @@ import copy import json import os +import platform import pwd import re import time @@ -34,6 +35,9 @@ from trl.trainer.utils import get_quantization_config +if platform.system() != "Windows": + import pwd + init_zero_verbose() HELP_STRING = """\ @@ -217,7 +221,10 @@ def print_help(self): def get_username(): - return pwd.getpwuid(os.getuid())[0] + if platform.system() == "Windows": + return os.getlogin() + else: + return pwd.getpwuid(os.getuid()).pw_name def create_default_filename(model_name): From 746baec75931a0ed4e24d62ed7ef06a129e83101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Dec 2024 19:50:49 +0000 Subject: [PATCH 33/42] fix chat windows --- trl/scripts/chat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/scripts/chat.py b/trl/scripts/chat.py index 2ede78e6cb..fa9eebc44e 100644 --- a/trl/scripts/chat.py +++ b/trl/scripts/chat.py @@ -18,7 +18,6 @@ import json import os import platform -import pwd import re import time from dataclasses import dataclass, field From 2fc0b6f5e3da0c68c57f6dd618f40f4394701938 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:40:00 +0100 Subject: [PATCH 34/42] =?UTF-8?q?=C2=A9=EF=B8=8F=20Copyrights=20update=20(?= =?UTF-8?q?#2454)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * First changes * Other files * Finally * rm comment * fix nashmd * Fix example * Fix example [ci skip] --- examples/datasets/hh-rlhf-helpful-base.py | 2 +- .../lm-human-preferences-descriptiveness.py | 2 +- .../lm-human-preferences-sentiment.py | 2 +- examples/datasets/prm800k.py | 2 +- examples/datasets/rlaif-v.py | 2 +- examples/datasets/tldr.py | 2 +- examples/datasets/tldr_preference.py | 2 +- examples/datasets/tokenize_ds.py | 2 +- examples/datasets/ultrafeedback-prompt.py | 2 +- examples/datasets/ultrafeedback.py | 2 +- .../stack_llama/scripts/merge_peft_adapter.py | 2 +- .../stack_llama/scripts/reward_modeling.py | 2 +- .../stack_llama/scripts/rl_training.py | 3 +- .../scripts/supervised_finetuning.py | 2 +- .../stack_llama_2/scripts/dpo_llama2.py | 2 +- .../stack_llama_2/scripts/sft_llama2.py | 2 +- .../research_projects/tools/calculator.py | 2 +- .../tools/python_interpreter.py | 2 +- examples/research_projects/tools/triviaqa.py | 2 +- .../toxicity/scripts/evaluate-toxicity.py | 2 +- .../toxicity/scripts/gpt-j-6b-toxicity.py | 3 +- examples/scripts/alignprop.py | 3 +- examples/scripts/bco.py | 2 +- examples/scripts/chat.py | 1 - examples/scripts/cpo.py | 3 +- examples/scripts/ddpo.py | 3 +- examples/scripts/dpo.py | 3 +- examples/scripts/dpo_online.py | 3 +- examples/scripts/dpo_vlm.py | 3 +- examples/scripts/evals/judge_tldr.py | 2 +- examples/scripts/gkd.py | 3 +- examples/scripts/kto.py | 2 +- examples/scripts/nash_md.py | 3 +- examples/scripts/orpo.py | 3 +- examples/scripts/ppo/ppo.py | 2 +- examples/scripts/ppo/ppo_tldr.py | 2 +- examples/scripts/reward_modeling.py | 3 +- examples/scripts/rloo/rloo.py | 2 +- examples/scripts/rloo/rloo_tldr.py | 2 +- examples/scripts/sft.py | 2 +- examples/scripts/sft_video_llm.py | 22 ++++-------- examples/scripts/sft_vlm.py | 3 +- examples/scripts/sft_vlm_smol_vlm.py | 3 +- examples/scripts/xpo.py | 3 +- scripts/add_copyrights.py | 13 +++---- scripts/generate_tiny_models.py | 2 +- scripts/generate_zen_dataset.py | 2 +- scripts/log_example_reports.py | 1 + scripts/log_reports.py | 1 + setup.py | 2 +- tests/__init__.py | 3 +- tests/slow/__init__.py | 3 +- tests/slow/test_dpo_slow.py | 1 + tests/slow/test_sft_slow.py | 1 + tests/slow/testing_constants.py | 1 - tests/test_alignprop_trainer.py | 3 +- tests/test_bco_trainer.py | 1 + tests/test_best_of_n_sampler.py | 2 +- tests/test_callbacks.py | 2 +- tests/test_cli.py | 3 +- tests/test_cli_utils.py | 2 +- tests/test_core.py | 3 +- tests/test_cpo_trainer.py | 1 + tests/test_data_collator_completion_only.py | 3 +- tests/test_data_utils.py | 1 + tests/test_dataset_formatting.py | 2 +- tests/test_ddpo_trainer.py | 3 +- tests/test_dpo_trainer.py | 2 +- tests/test_environments.py | 2 +- tests/test_gkd_trainer.py | 1 + tests/test_iterative_sft_trainer.py | 3 +- tests/test_judges.py | 2 +- tests/test_kto_trainer.py | 1 + ...test_modeling_geometric_mixture_wrapper.py | 1 + tests/test_modeling_value_head.py | 3 +- tests/test_nash_md_trainer.py | 1 + tests/test_online_dpo_trainer.py | 1 + tests/test_orpo_trainer.py | 1 + tests/test_peft_models.py | 3 +- tests/test_ppo_trainer.py | 3 +- tests/test_reward_trainer.py | 3 +- tests/test_rich_progress_callback.py | 2 +- tests/test_rloo_trainer.py | 3 +- tests/test_sft_trainer.py | 3 +- tests/test_trainers_args.py | 2 +- tests/test_utils.py | 2 +- tests/test_xpo_trainer.py | 1 + tests/testing_constants.py | 2 +- tests/testing_utils.py | 3 +- trl/__init__.py | 2 +- trl/cli.py | 2 +- trl/core.py | 3 +- trl/data_utils.py | 1 + trl/env_utils.py | 34 ------------------- trl/environment/__init__.py | 2 +- trl/environment/base_environment.py | 2 +- trl/extras/__init__.py | 2 +- trl/extras/best_of_n_sampler.py | 2 +- trl/extras/dataset_formatting.py | 2 +- trl/import_utils.py | 3 +- trl/models/__init__.py | 2 +- trl/models/auxiliary_modules.py | 5 +-- trl/models/modeling_base.py | 3 +- trl/models/modeling_sd_base.py | 2 +- trl/models/modeling_value_head.py | 3 +- trl/models/sd_utils.py | 3 +- trl/models/utils.py | 2 +- trl/scripts/__init__.py | 2 +- trl/scripts/dpo.py | 3 +- trl/scripts/env.py | 2 +- trl/scripts/kto.py | 2 +- trl/scripts/sft.py | 3 +- trl/scripts/utils.py | 5 ++- trl/trainer/__init__.py | 2 +- trl/trainer/alignprop_config.py | 2 +- trl/trainer/alignprop_trainer.py | 3 +- trl/trainer/base.py | 2 +- trl/trainer/bco_config.py | 1 + trl/trainer/bco_trainer.py | 2 +- trl/trainer/callbacks.py | 3 +- trl/trainer/cpo_config.py | 1 + trl/trainer/cpo_trainer.py | 1 - trl/trainer/ddpo_config.py | 2 +- trl/trainer/ddpo_trainer.py | 2 +- trl/trainer/dpo_config.py | 1 + trl/trainer/dpo_trainer.py | 3 +- trl/trainer/gkd_config.py | 1 + trl/trainer/gkd_trainer.py | 1 + trl/trainer/iterative_sft_trainer.py | 3 +- trl/trainer/judges.py | 2 +- trl/trainer/kto_config.py | 1 + trl/trainer/kto_trainer.py | 2 +- trl/trainer/model_config.py | 2 +- trl/trainer/online_dpo_config.py | 2 +- trl/trainer/online_dpo_trainer.py | 2 +- trl/trainer/orpo_config.py | 1 + trl/trainer/orpo_trainer.py | 2 -- trl/trainer/ppo_config.py | 2 +- trl/trainer/ppo_trainer.py | 2 +- trl/trainer/reward_trainer.py | 3 +- trl/trainer/rloo_config.py | 2 +- trl/trainer/rloo_trainer.py | 2 +- trl/trainer/sft_config.py | 3 +- trl/trainer/sft_trainer.py | 3 +- trl/trainer/utils.py | 3 +- 145 files changed, 193 insertions(+), 186 deletions(-) delete mode 100644 trl/env_utils.py diff --git a/examples/datasets/hh-rlhf-helpful-base.py b/examples/datasets/hh-rlhf-helpful-base.py index 84d8010169..e089ed108e 100644 --- a/examples/datasets/hh-rlhf-helpful-base.py +++ b/examples/datasets/hh-rlhf-helpful-base.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/datasets/lm-human-preferences-descriptiveness.py b/examples/datasets/lm-human-preferences-descriptiveness.py index 2620b3101b..621757770c 100644 --- a/examples/datasets/lm-human-preferences-descriptiveness.py +++ b/examples/datasets/lm-human-preferences-descriptiveness.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/datasets/lm-human-preferences-sentiment.py b/examples/datasets/lm-human-preferences-sentiment.py index af0359ac38..a3eaa4d06e 100644 --- a/examples/datasets/lm-human-preferences-sentiment.py +++ b/examples/datasets/lm-human-preferences-sentiment.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/datasets/prm800k.py b/examples/datasets/prm800k.py index 244257912c..b5f95742be 100644 --- a/examples/datasets/prm800k.py +++ b/examples/datasets/prm800k.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/datasets/rlaif-v.py b/examples/datasets/rlaif-v.py index ec2501d4c7..84ae292f87 100644 --- a/examples/datasets/rlaif-v.py +++ b/examples/datasets/rlaif-v.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/datasets/tldr.py b/examples/datasets/tldr.py index e386095f88..0ae29481e3 100644 --- a/examples/datasets/tldr.py +++ b/examples/datasets/tldr.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/datasets/tldr_preference.py b/examples/datasets/tldr_preference.py index 0ac6af6646..1c4ff5bcbd 100644 --- a/examples/datasets/tldr_preference.py +++ b/examples/datasets/tldr_preference.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/datasets/tokenize_ds.py b/examples/datasets/tokenize_ds.py index ae52e0b22c..cd96a685a9 100644 --- a/examples/datasets/tokenize_ds.py +++ b/examples/datasets/tokenize_ds.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/datasets/ultrafeedback-prompt.py b/examples/datasets/ultrafeedback-prompt.py index 9753aa620a..3cb92467d5 100644 --- a/examples/datasets/ultrafeedback-prompt.py +++ b/examples/datasets/ultrafeedback-prompt.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/datasets/ultrafeedback.py b/examples/datasets/ultrafeedback.py index 5ca687760a..cb6c556d0c 100644 --- a/examples/datasets/ultrafeedback.py +++ b/examples/datasets/ultrafeedback.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py b/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py index 3d21a95257..f9f8018df7 100644 --- a/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py +++ b/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/research_projects/stack_llama/scripts/reward_modeling.py b/examples/research_projects/stack_llama/scripts/reward_modeling.py index db38f62d4c..d3fc3c8505 100644 --- a/examples/research_projects/stack_llama/scripts/reward_modeling.py +++ b/examples/research_projects/stack_llama/scripts/reward_modeling.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/research_projects/stack_llama/scripts/rl_training.py b/examples/research_projects/stack_llama/scripts/rl_training.py index a37cf63ab1..011c00554f 100644 --- a/examples/research_projects/stack_llama/scripts/rl_training.py +++ b/examples/research_projects/stack_llama/scripts/rl_training.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass, field from typing import Optional diff --git a/examples/research_projects/stack_llama/scripts/supervised_finetuning.py b/examples/research_projects/stack_llama/scripts/supervised_finetuning.py index c2d860468a..48488c8ef3 100644 --- a/examples/research_projects/stack_llama/scripts/supervised_finetuning.py +++ b/examples/research_projects/stack_llama/scripts/supervised_finetuning.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py index b3d287b144..31cc6bfcbc 100644 --- a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py +++ b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py index ef34a67f0c..56170b4818 100644 --- a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py +++ b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/research_projects/tools/calculator.py b/examples/research_projects/tools/calculator.py index bde5692f62..5e76ee3416 100644 --- a/examples/research_projects/tools/calculator.py +++ b/examples/research_projects/tools/calculator.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/research_projects/tools/python_interpreter.py b/examples/research_projects/tools/python_interpreter.py index 8f319b2d68..1870dbfb81 100644 --- a/examples/research_projects/tools/python_interpreter.py +++ b/examples/research_projects/tools/python_interpreter.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/research_projects/tools/triviaqa.py b/examples/research_projects/tools/triviaqa.py index def5013582..5dadae6beb 100644 --- a/examples/research_projects/tools/triviaqa.py +++ b/examples/research_projects/tools/triviaqa.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/research_projects/toxicity/scripts/evaluate-toxicity.py b/examples/research_projects/toxicity/scripts/evaluate-toxicity.py index fca714706f..17e25d55ac 100644 --- a/examples/research_projects/toxicity/scripts/evaluate-toxicity.py +++ b/examples/research_projects/toxicity/scripts/evaluate-toxicity.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py index 51f6d284c4..d3998c7882 100644 --- a/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py +++ b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass, field from typing import Optional diff --git a/examples/scripts/alignprop.py b/examples/scripts/alignprop.py index 1948080f4b..918619c1e8 100644 --- a/examples/scripts/alignprop.py +++ b/examples/scripts/alignprop.py @@ -1,4 +1,4 @@ -# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ Total Batch size = 128 = 4 (num_gpus) * 8 (per_device_batch) * 4 (accumulation steps) Feel free to reduce batch size or increasing truncated_rand_backprop_min to a higher value to reduce memory usage. diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py index d1ef3ed465..94d83503fd 100644 --- a/examples/scripts/bco.py +++ b/examples/scripts/bco.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/scripts/chat.py b/examples/scripts/chat.py index db7efd1753..b81f3a3339 100644 --- a/examples/scripts/chat.py +++ b/examples/scripts/chat.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - ################################################################################################ # This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/chat.py # ################################################################################################ diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index 1eb9c1cc65..ab95a46cc0 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ Run the CPO training script with the following command with some example arguments. In general, the optimal configuration for CPO will be similar to that of DPO: diff --git a/examples/scripts/ddpo.py b/examples/scripts/ddpo.py index 92924c51e4..7919d5244a 100644 --- a/examples/scripts/ddpo.py +++ b/examples/scripts/ddpo.py @@ -1,4 +1,4 @@ -# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ python examples/scripts/ddpo.py \ --num_epochs=200 \ diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 4352ec603c..97425d3ef0 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - ############################################################################################### # This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py # ############################################################################################### diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index 729598aeb6..e885337a53 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ Usage: diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index 395c8bf5d5..b01608c125 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ accelerate launch examples/scripts/dpo_vlm.py \ --dataset_name HuggingFaceH4/rlaif-v_formatted \ diff --git a/examples/scripts/evals/judge_tldr.py b/examples/scripts/evals/judge_tldr.py index 5f6a8ee662..f9e51df729 100644 --- a/examples/scripts/evals/judge_tldr.py +++ b/examples/scripts/evals/judge_tldr.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index 84ac02228a..005e5c6257 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ # Full training: python examples/scripts/gkd.py \ diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index 7dcd769a81..363e55e470 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index 4d0c0f29f0..9eff14416f 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ Usage: diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index 19dbf12e51..67e086ea84 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ Run the ORPO training script with the following command with some example arguments. In general, the optimal configuration for ORPO will be similar to that of DPO without the need for a reference model: diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 8addc0f979..7873ac364e 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index d7864d3d01..22bb68586d 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 81658685fe..12888ef44d 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ Full training: python examples/scripts/reward_modeling.py \ diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index 5eafb0de50..e0193bfebd 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 4bda783690..dd6ac51d6b 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index 54fc8632ab..4b43634d47 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py index a3381f94fc..9c9809750b 100644 --- a/examples/scripts/sft_video_llm.py +++ b/examples/scripts/sft_video_llm.py @@ -1,4 +1,4 @@ -# Copyright 2024. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ Example usage: accelerate launch \ @@ -53,20 +54,9 @@ from datasets import load_dataset from peft import LoraConfig from qwen_vl_utils import process_vision_info -from transformers import ( - AutoModelForVision2Seq, - AutoProcessor, - BitsAndBytesConfig, - Qwen2VLProcessor, -) - -from trl import ( - SFTConfig, - SFTTrainer, - get_kbit_device_map, -) -from trl.commands.cli_utils import SFTScriptArguments, TrlParser -from trl.trainer import ModelConfig +from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor + +from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map def download_video(url: str, cache_dir: str) -> str: @@ -161,7 +151,7 @@ def collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: @dataclass -class CustomScriptArguments(SFTScriptArguments): +class CustomScriptArguments(ScriptArguments): video_cache_dir: str = "/tmp/videos/" diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index f7b9befc9e..dc849ee7cc 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ pip install pillow diff --git a/examples/scripts/sft_vlm_smol_vlm.py b/examples/scripts/sft_vlm_smol_vlm.py index a88e7f552d..7710765d55 100644 --- a/examples/scripts/sft_vlm_smol_vlm.py +++ b/examples/scripts/sft_vlm_smol_vlm.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ pip install pillow diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index c3d8816546..3adb6862d4 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ Usage: diff --git a/scripts/add_copyrights.py b/scripts/add_copyrights.py index 9acd98bea4..e0c39d5988 100644 --- a/scripts/add_copyrights.py +++ b/scripts/add_copyrights.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ from datetime import datetime -COPYRIGHT_HEADER = f"""# Copyright {datetime.now().year} The HuggingFace Inc. team. All rights reserved. +COPYRIGHT_HEADER = f"""# Copyright {datetime.now().year} The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,8 +33,6 @@ # limitations under the License. """ -COPYRIGHT_KEYWORD = "# Copyright 20" - def get_tracked_python_files(): """Get a list of all tracked Python files using git.""" @@ -60,10 +58,9 @@ def check_and_add_copyright(file_path): with open(file_path, encoding="utf-8") as f: content = f.readlines() - # Check if the copyright header exists in the first 10 lines - for line in content[:10]: - if COPYRIGHT_KEYWORD in line: - return True + # Check if the exact copyright header exists + if "".join(content).startswith(COPYRIGHT_HEADER): + return True # If no copyright notice was found, prepend the header print(f"[MODIFY] Adding copyright to {file_path}.") diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index b15e9f6f65..d70bd51d4c 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/scripts/generate_zen_dataset.py b/scripts/generate_zen_dataset.py index e599e6c1ce..73c7c16f82 100644 --- a/scripts/generate_zen_dataset.py +++ b/scripts/generate_zen_dataset.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/scripts/log_example_reports.py b/scripts/log_example_reports.py index 10f6c9a7ad..feed6467ca 100644 --- a/scripts/log_example_reports.py +++ b/scripts/log_example_reports.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import argparse import logging import os diff --git a/scripts/log_reports.py b/scripts/log_reports.py index 0cdac4f756..762d057954 100644 --- a/scripts/log_reports.py +++ b/scripts/log_reports.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import argparse import json import logging diff --git a/setup.py b/setup.py index fd8ea87950..62b164a776 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/__init__.py b/tests/__init__.py index adfc257a24..196860c9f1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/slow/__init__.py b/tests/slow/__init__.py index adfc257a24..196860c9f1 100644 --- a/tests/slow/__init__.py +++ b/tests/slow/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py index 81e3a8c0d2..b586b90de9 100644 --- a/tests/slow/test_dpo_slow.py +++ b/tests/slow/test_dpo_slow.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gc import itertools import tempfile diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py index ca8f999d57..8759fdb1af 100644 --- a/tests/slow/test_sft_slow.py +++ b/tests/slow/test_sft_slow.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gc import itertools import tempfile diff --git a/tests/slow/testing_constants.py b/tests/slow/testing_constants.py index cb6dad681c..f1c2ce6b7c 100644 --- a/tests/slow/testing_constants.py +++ b/tests/slow/testing_constants.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# TODO: push them under trl-org MODELS_TO_TEST = [ "trl-internal-testing/tiny-LlamaForCausalLM-3.2", "trl-internal-testing/tiny-MistralForCausalLM-0.2", diff --git a/tests/test_alignprop_trainer.py b/tests/test_alignprop_trainer.py index a7446639d1..092095f919 100644 --- a/tests/test_alignprop_trainer.py +++ b/tests/test_alignprop_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gc import unittest diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py index 09f274ef65..0442a8f685 100644 --- a/tests/test_bco_trainer.py +++ b/tests/test_bco_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import tempfile import unittest from functools import partial diff --git a/tests/test_best_of_n_sampler.py b/tests/test_best_of_n_sampler.py index 02d0cb0de7..ab817ab723 100644 --- a/tests/test_best_of_n_sampler.py +++ b/tests/test_best_of_n_sampler.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index b0319179b6..d04a7804f8 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_cli.py b/tests/test_cli.py index d2bd5f8e7a..dc71657d66 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import tempfile import unittest from io import StringIO diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index a2343a2930..ac2f98d519 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_core.py b/tests/test_core.py index 2d8531d591..88ecf38fad 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import unittest import torch diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py index ee110c8dab..744ee07aa6 100644 --- a/tests/test_cpo_trainer.py +++ b/tests/test_cpo_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import tempfile import unittest diff --git a/tests/test_data_collator_completion_only.py b/tests/test_data_collator_completion_only.py index 7e8c6e224e..10c4a1036c 100644 --- a/tests/test_data_collator_completion_only.py +++ b/tests/test_data_collator_completion_only.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import unittest import torch diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 916417d0f4..a4eb13683c 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import itertools import unittest diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py index b5193bfd92..e2790801b2 100644 --- a/tests/test_dataset_formatting.py +++ b/tests/test_dataset_formatting.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_ddpo_trainer.py b/tests/test_ddpo_trainer.py index 65a626589a..ad4cc60b5e 100644 --- a/tests/test_ddpo_trainer.py +++ b/tests/test_ddpo_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gc import unittest diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index ea9c916d76..fe2c732f64 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_environments.py b/tests/test_environments.py index 2555f3aeca..0b9fc56867 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py index 8f9e87456b..f6d15aa1ad 100644 --- a/tests/test_gkd_trainer.py +++ b/tests/test_gkd_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import tempfile import unittest diff --git a/tests/test_iterative_sft_trainer.py b/tests/test_iterative_sft_trainer.py index 099e058492..f0ee29dbc2 100644 --- a/tests/test_iterative_sft_trainer.py +++ b/tests/test_iterative_sft_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import tempfile import unittest from functools import partial diff --git a/tests/test_judges.py b/tests/test_judges.py index 13d2164ffe..0f8b83d881 100644 --- a/tests/test_judges.py +++ b/tests/test_judges.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index b705d0bc5a..5cb311d272 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import tempfile import unittest diff --git a/tests/test_modeling_geometric_mixture_wrapper.py b/tests/test_modeling_geometric_mixture_wrapper.py index c3d23b2c9d..97b27885a8 100644 --- a/tests/test_modeling_geometric_mixture_wrapper.py +++ b/tests/test_modeling_geometric_mixture_wrapper.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import unittest import torch diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index be4932e62f..545f46f95a 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gc import sys import tempfile diff --git a/tests/test_nash_md_trainer.py b/tests/test_nash_md_trainer.py index c9e100cbb7..c0c0e0664c 100644 --- a/tests/test_nash_md_trainer.py +++ b/tests/test_nash_md_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import tempfile import unittest diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index f29d491769..d5ea8271c5 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import tempfile import unittest diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py index 6592028ff2..d2eaee3947 100644 --- a/tests/test_orpo_trainer.py +++ b/tests/test_orpo_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import tempfile import unittest diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py index db29334eda..3149f36bb3 100644 --- a/tests/test_peft_models.py +++ b/tests/test_peft_models.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import tempfile import unittest diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index 4667b6eb09..ad9b19a7b9 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import platform import subprocess diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index 1cee85b270..d4466a0404 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import tempfile import unittest diff --git a/tests/test_rich_progress_callback.py b/tests/test_rich_progress_callback.py index 84c287a86d..67deb8d0f1 100644 --- a/tests/test_rich_progress_callback.py +++ b/tests/test_rich_progress_callback.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 05ec037c9e..6861b6e83e 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import platform import subprocess import tempfile diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index d8a3706050..bb0a563d0d 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import copy import os import tempfile diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index b89194d1c9..704a20efad 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_utils.py b/tests/test_utils.py index d95cc7f2d6..f27240748f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py index b611d93566..0c55025d5e 100644 --- a/tests/test_xpo_trainer.py +++ b/tests/test_xpo_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import tempfile import unittest diff --git a/tests/testing_constants.py b/tests/testing_constants.py index c66f60e569..8b678dadf2 100644 --- a/tests/testing_constants.py +++ b/tests/testing_constants.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/testing_utils.py b/tests/testing_utils.py index c5c8b2e3ce..6c2041b7b3 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import random import unittest diff --git a/trl/__init__.py b/trl/__init__.py index 374384e834..7cda70d3c0 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/cli.py b/trl/cli.py index 1e02497927..d5a1421c51 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/core.py b/trl/core.py index d4e77f5fdc..f62c8bc414 100644 --- a/trl/core.py +++ b/trl/core.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gc import random import warnings diff --git a/trl/data_utils.py b/trl/data_utils.py index 88319626b8..9bc68f1d5c 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Any, Optional, Sequence, TypeVar from datasets import Dataset, DatasetDict diff --git a/trl/env_utils.py b/trl/env_utils.py deleted file mode 100644 index 64e98199e0..0000000000 --- a/trl/env_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Function `strtobool` copied and adapted from `distutils` (as deprected -# in Python 3.10). -# Reference: https://github.com/python/cpython/blob/48f9d3e3faec5faaa4f7c9849fecd27eae4da213/Lib/distutils/util.py#L308-L321 - - -def strtobool(val: str) -> bool: - """Convert a string representation of truth to True or False booleans. - - True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values - are 'n', 'no', 'f', 'false', 'off', and '0'. - - Raises: - ValueError: if 'val' is anything else. - """ - val = val.lower() - if val in ("y", "yes", "t", "true", "on", "1"): - return True - if val in ("n", "no", "f", "false", "off", "0"): - return False - raise ValueError(f"Invalid truth value, it should be a string but {val} was provided instead.") diff --git a/trl/environment/__init__.py b/trl/environment/__init__.py index a2a39d747e..673c15fb65 100644 --- a/trl/environment/__init__.py +++ b/trl/environment/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py index fa7e21f91b..e230dd6ab6 100644 --- a/trl/environment/base_environment.py +++ b/trl/environment/base_environment.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/extras/__init__.py b/trl/extras/__init__.py index beedbd53f0..2bb2747d2b 100644 --- a/trl/extras/__init__.py +++ b/trl/extras/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/extras/best_of_n_sampler.py b/trl/extras/best_of_n_sampler.py index 646cee1318..c0f02152c2 100644 --- a/trl/extras/best_of_n_sampler.py +++ b/trl/extras/best_of_n_sampler.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/extras/dataset_formatting.py b/trl/extras/dataset_formatting.py index d84db4b49e..1be86580aa 100644 --- a/trl/extras/dataset_formatting.py +++ b/trl/extras/dataset_formatting.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/import_utils.py b/trl/import_utils.py index 31b78899f7..6ec64260e3 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import importlib import os from itertools import chain diff --git a/trl/models/__init__.py b/trl/models/__init__.py index 4cbfc0a511..1c6d01b7d5 100644 --- a/trl/models/__init__.py +++ b/trl/models/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/models/auxiliary_modules.py b/trl/models/auxiliary_modules.py index 98c400c7cc..40b4ca407a 100644 --- a/trl/models/auxiliary_modules.py +++ b/trl/models/auxiliary_modules.py @@ -1,5 +1,5 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. - +# Copyright 2024 The HuggingFace Team. All rights reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import torch diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py index c17b8650de..254fe29ac7 100644 --- a/trl/models/modeling_base.py +++ b/trl/models/modeling_base.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import json import logging import os diff --git a/trl/models/modeling_sd_base.py b/trl/models/modeling_sd_base.py index 0b729a0009..131d8d8016 100644 --- a/trl/models/modeling_sd_base.py +++ b/trl/models/modeling_sd_base.py @@ -1,4 +1,4 @@ -# Copyright 2023 DDPO-pytorch authors (Kevin Black), The HuggingFace Team, metric-space. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py index 592879ae3e..deb1f16f29 100644 --- a/trl/models/modeling_value_head.py +++ b/trl/models/modeling_value_head.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import torch import torch.nn as nn from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, is_torch_npu_available, is_torch_xpu_available diff --git a/trl/models/sd_utils.py b/trl/models/sd_utils.py index a6bbbebf40..8e5c286d54 100644 --- a/trl/models/sd_utils.py +++ b/trl/models/sd_utils.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ State dict utilities: utility methods for converting state dicts easily File copied from diffusers to avoid import issues and make TRL compatible diff --git a/trl/models/utils.py b/trl/models/utils.py index 0ff52461e0..53cf481f1f 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/scripts/__init__.py b/trl/scripts/__init__.py index e79624121d..2994ed5275 100644 --- a/trl/scripts/__init__.py +++ b/trl/scripts/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index 382492b5b6..0ca30c7eed 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ # Full training python trl/scripts/dpo.py \ diff --git a/trl/scripts/env.py b/trl/scripts/env.py index c2000d1db4..57432263c4 100644 --- a/trl/scripts/env.py +++ b/trl/scripts/env.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/scripts/kto.py b/trl/scripts/kto.py index a8aaf3c848..64fb71f02d 100644 --- a/trl/scripts/kto.py +++ b/trl/scripts/kto.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index 2ba8958081..6d12f435b5 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ # Full training python examples/scripts/sft.py \ diff --git a/trl/scripts/utils.py b/trl/scripts/utils.py index e7de20ef12..94ceb7cd39 100644 --- a/trl/scripts/utils.py +++ b/trl/scripts/utils.py @@ -1,6 +1,4 @@ -# This file is a copy of trl/examples/scripts/sft.py so that we could -# use it together with rich and the TRL CLI in a more customizable manner. -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import importlib import inspect import logging diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index e555b107e6..e5599756f7 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/alignprop_config.py b/trl/trainer/alignprop_config.py index 0efdcc74ce..1c4faa963e 100644 --- a/trl/trainer/alignprop_config.py +++ b/trl/trainer/alignprop_config.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py index 0fccba6a53..d03696c091 100644 --- a/trl/trainer/alignprop_trainer.py +++ b/trl/trainer/alignprop_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 AlignProp-pytorch authors (Mihir Prabhudesai), metric-space, The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import textwrap from collections import defaultdict diff --git a/trl/trainer/base.py b/trl/trainer/base.py index f0314cb987..7730e6af9a 100644 --- a/trl/trainer/base.py +++ b/trl/trainer/base.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/bco_config.py b/trl/trainer/bco_config.py index 61729576ec..10cd82b9f5 100644 --- a/trl/trainer/bco_config.py +++ b/trl/trainer/bco_config.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass from typing import Any, Optional diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 89c2357a19..dc267de882 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -1,4 +1,3 @@ -# BCO Authors: Seungjae Jung, Gunsoo Han, Daniel Wontae Nam and Kyoung-Woon On # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import inspect import os import random diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index dbee1f60a2..7e25366431 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os from typing import Optional, Union diff --git a/trl/trainer/cpo_config.py b/trl/trainer/cpo_config.py index 91d3008533..a288a8c75f 100644 --- a/trl/trainer/cpo_config.py +++ b/trl/trainer/cpo_config.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass from typing import Any, Literal, Optional diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 9a068b21ac..49b2849766 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -1,4 +1,3 @@ -# CPO Authors: Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, Young Jin Kim # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/trl/trainer/ddpo_config.py b/trl/trainer/ddpo_config.py index 442689be8f..ca703eb806 100644 --- a/trl/trainer/ddpo_config.py +++ b/trl/trainer/ddpo_config.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py index 4ef0c82c82..a3b6813b4c 100644 --- a/trl/trainer/ddpo_trainer.py +++ b/trl/trainer/ddpo_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 DDPO-pytorch authors (Kevin Black), metric-space, The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 8a6e507dc1..88abdd4a5c 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import warnings from dataclasses import dataclass from enum import Enum diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index fe43b133d9..7ed0ac387f 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1,5 +1,4 @@ -# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023 -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/gkd_config.py b/trl/trainer/gkd_config.py index c826ceade2..e9b9d76363 100644 --- a/trl/trainer/gkd_config.py +++ b/trl/trainer/gkd_config.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass from typing import Any, Optional diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index bc4f35ef33..173a8d6107 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import random import textwrap diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py index d2b02ab33b..9621822eb8 100644 --- a/trl/trainer/iterative_sft_trainer.py +++ b/trl/trainer/iterative_sft_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import warnings from typing import Callable, Optional, Union diff --git a/trl/trainer/judges.py b/trl/trainer/judges.py index 71f86ef1b3..c29491340d 100644 --- a/trl/trainer/judges.py +++ b/trl/trainer/judges.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/kto_config.py b/trl/trainer/kto_config.py index 351885c362..563d0cdbc9 100644 --- a/trl/trainer/kto_config.py +++ b/trl/trainer/kto_config.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass from typing import Any, Literal, Optional diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 5513cf8d08..3aa3396e39 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -1,4 +1,3 @@ -# KTO Authors: Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import inspect import os import random diff --git a/trl/trainer/model_config.py b/trl/trainer/model_config.py index 7301d10213..3261ff8a6a 100644 --- a/trl/trainer/model_config.py +++ b/trl/trainer/model_config.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 34f8ec5837..0b06c79cb5 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 7830d3fe64..3d2b85536e 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/orpo_config.py b/trl/trainer/orpo_config.py index 16aafbc787..b7e2ef7ad0 100644 --- a/trl/trainer/orpo_config.py +++ b/trl/trainer/orpo_config.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass from typing import Any, Optional diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index e90fec8dfe..e6f148f90f 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -1,5 +1,3 @@ -# ORPO Authors: Jiwoo Hong, Noah Lee, and James Thorne -# Official code: https://github.com/xfactlab/orpo # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index 465048184b..62a3b0a33e 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 2b55e5797c..b51c29c014 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 4000697336..3157900b83 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import inspect import os import warnings diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index 2683fb5eee..e72b0fc02c 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index f2e3eb9674..6ad38d37c4 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py index eefc22b267..310ede7870 100644 --- a/trl/trainer/sft_config.py +++ b/trl/trainer/sft_config.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass from typing import Any, Optional diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 34a95f36ec..809eff5679 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import dataclasses import inspect import os diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 02c8c2bee1..881f3d3873 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import dataclasses import importlib.resources as pkg_resources import json From 6941e0fadda8f5db4941d20885cc682424593366 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:40:23 +0100 Subject: [PATCH 35/42] =?UTF-8?q?=F0=9F=92=AC=20Fix=20chat=20for=20windows?= =?UTF-8?q?=20(#2443)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix chat for windows * add some tests back * Revert "add some tests back" This reverts commit 350aef52f53f8cf34fccd7ad0f78a3dd63867e06. From b202b15171e75d1e044f9c0f64e3199c9ad909fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:09:26 +0100 Subject: [PATCH 36/42] =?UTF-8?q?=F0=9F=86=94=20Add=20`datast=5Fconfig`=20?= =?UTF-8?q?to=20`ScriptArguments`=20(#2440)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * datast_config_name * Update trl/utils.py [ci skip] * sort import * typo [ci skip] * Trigger CI * Rename `dataset_config_name` to `dataset_config` --- examples/scripts/bco.py | 2 +- examples/scripts/cpo.py | 2 +- examples/scripts/dpo_online.py | 2 +- examples/scripts/dpo_vlm.py | 2 +- examples/scripts/gkd.py | 2 +- examples/scripts/kto.py | 2 +- examples/scripts/nash_md.py | 2 +- examples/scripts/orpo.py | 2 +- examples/scripts/ppo/ppo.py | 2 +- examples/scripts/ppo/ppo_tldr.py | 2 +- examples/scripts/reward_modeling.py | 2 +- examples/scripts/rloo/rloo.py | 2 +- examples/scripts/rloo/rloo_tldr.py | 2 +- examples/scripts/sft_video_llm.py | 2 +- examples/scripts/sft_vlm.py | 2 +- examples/scripts/sft_vlm_smol_vlm.py | 2 +- examples/scripts/xpo.py | 2 +- trl/scripts/utils.py | 6 +++--- 18 files changed, 20 insertions(+), 20 deletions(-) diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py index 94d83503fd..38a5d35a38 100644 --- a/examples/scripts/bco.py +++ b/examples/scripts/bco.py @@ -126,7 +126,7 @@ def mean_pooling(model_output, attention_mask): if tokenizer.chat_template is None: model, tokenizer = setup_chat_format(model, tokenizer) - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) accelerator = Accelerator() embedding_model = AutoModel.from_pretrained( diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index ab95a46cc0..1132e9b573 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -81,7 +81,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index e885337a53..185343e611 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -119,7 +119,7 @@ if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) trainer = OnlineDPOTrainer( model=model, diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index b01608c125..38023b1459 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -100,7 +100,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ################ # Training diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index 005e5c6257..4408c2dfee 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -104,7 +104,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) with PartialState().local_main_process_first(): dataset = dataset.map( diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index 363e55e470..7ae26931e9 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -91,7 +91,7 @@ model, tokenizer = setup_chat_format(model, tokenizer) # Load the dataset - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) # Initialize the KTO trainer trainer = KTOTrainer( diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index 9eff14416f..eb17f728ae 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -117,7 +117,7 @@ if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) trainer = NashMDTrainer( model=model, diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index 67e086ea84..82578f99be 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -81,7 +81,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 7873ac364e..2758c950c5 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -116,7 +116,7 @@ # Dataset ################ dataset = load_dataset( - script_args.dataset_name, name=script_args.dataset_config_name, split=script_args.dataset_train_split + script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split ) eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 22bb68586d..353a1493e3 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -122,7 +122,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) train_dataset = dataset[script_args.dataset_train_split] eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 12888ef44d..a3f299266b 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -105,7 +105,7 @@ ############## # Load dataset ############## - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ########## # Training diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index e0193bfebd..85c443b7ae 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -89,7 +89,7 @@ # Dataset ################ dataset = load_dataset( - script_args.dataset_name, name=script_args.dataset_config_name, split=script_args.dataset_train_split + script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split ) eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index dd6ac51d6b..cf4265e921 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -90,7 +90,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) train_dataset = dataset[script_args.dataset_train_split] eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py index 9c9809750b..4a85114d4f 100644 --- a/examples/scripts/sft_video_llm.py +++ b/examples/scripts/sft_video_llm.py @@ -166,7 +166,7 @@ class CustomScriptArguments(ScriptArguments): training_args.dataset_kwargs = {"skip_prepare_dataset": True} # Load dataset - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name, split="train") + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train") # Setup model torch_dtype = ( diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index dc849ee7cc..497bb69b66 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -107,7 +107,7 @@ def collate_fn(examples): ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ################ # Training diff --git a/examples/scripts/sft_vlm_smol_vlm.py b/examples/scripts/sft_vlm_smol_vlm.py index 7710765d55..278a38621f 100644 --- a/examples/scripts/sft_vlm_smol_vlm.py +++ b/examples/scripts/sft_vlm_smol_vlm.py @@ -119,7 +119,7 @@ def collate_fn(examples): ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ################ # Training diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index 3adb6862d4..726b457b2e 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -102,7 +102,7 @@ if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) trainer = XPOTrainer( model=model, diff --git a/trl/scripts/utils.py b/trl/scripts/utils.py index 94ceb7cd39..a6637c02a3 100644 --- a/trl/scripts/utils.py +++ b/trl/scripts/utils.py @@ -39,8 +39,8 @@ class ScriptArguments: Args: dataset_name (`str`): Dataset name. - dataset_config_name (`str` or `None`, *optional*, defaults to `None`): - Dataset configuration name. + dataset_config (`str` or `None`, *optional*, defaults to `None`): + Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function. dataset_train_split (`str`, *optional*, defaults to `"train"`): Dataset split to use for training. dataset_test_split (`str`, *optional*, defaults to `"test"`): @@ -53,7 +53,7 @@ class ScriptArguments: """ dataset_name: str - dataset_config_name: Optional[str] = None + dataset_config: Optional[str] = None dataset_train_split: str = "train" dataset_test_split: str = "test" gradient_checkpointing_use_reentrant: bool = False From 2401463bbf38ed29f5e1744e42297c285ba4e19c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 10 Dec 2024 12:40:13 +0100 Subject: [PATCH 37/42] =?UTF-8?q?=F0=9F=8F=8E=20Fix=20deepspeed=20preparat?= =?UTF-8?q?ion=20of=20`ref=5Fmodel`=20in=20`OnlineDPOTrainer`=20(#2417)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Remove unused deepspeed code * add model prep back * add deepspeed even if it doesn't work * rm old code --- trl/trainer/online_dpo_trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 3d2b85536e..c014ce1e13 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -284,7 +284,10 @@ def __init__( self.reward_model = prepare_deepspeed( self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 ) - self.ref_model = prepare_deepspeed(self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16) + if self.ref_model is not None: + self.ref_model = prepare_deepspeed( + self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) else: if self.ref_model is not None: self.ref_model = self.ref_model.to(self.accelerator.device) From c0209f98f015fc289298b7a4dbb7213d2421c2de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 10 Dec 2024 12:26:30 +0000 Subject: [PATCH 38/42] Fix config name --- tests/test_cli.py | 6 +++--- trl/scripts/dpo.py | 2 +- trl/scripts/kto.py | 2 +- trl/scripts/sft.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index dc71657d66..f7719ecc16 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -24,7 +24,7 @@ class TestCLI(unittest.TestCase): def test_dpo(self): with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory - command = f"trl dpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config_name standard_preference --report_to none" + command = f"trl dpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_preference --report_to none" with patch("sys.argv", command.split(" ")): main() @@ -37,13 +37,13 @@ def test_env(self, mock_stdout): def test_kto(self): with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory - command = f"trl kto --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config_name standard_unpaired_preference --report_to none" + command = f"trl kto --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_unpaired_preference --report_to none" with patch("sys.argv", command.split(" ")): main() def test_sft(self): with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory - command = f"trl sft --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config_name standard_language_modeling --report_to none" + command = f"trl sft --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_language_modeling --report_to none" with patch("sys.argv", command.split(" ")): main() diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index 0ca30c7eed..da45055a1e 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -108,7 +108,7 @@ def main(script_args, training_args, model_args): ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ########## # Training diff --git a/trl/scripts/kto.py b/trl/scripts/kto.py index 64fb71f02d..75aa8c5081 100644 --- a/trl/scripts/kto.py +++ b/trl/scripts/kto.py @@ -91,7 +91,7 @@ def main(script_args, training_args, model_args): model, tokenizer = setup_chat_format(model, tokenizer) # Load the dataset - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) # Initialize the KTO trainer trainer = KTOTrainer( diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index 6d12f435b5..ad333b364c 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -89,7 +89,7 @@ def main(script_args, training_args, model_args): ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ################ # Training From 0263435f8c3b8b7b0cc2dce2ef6e980e29be7628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 13 Dec 2024 15:10:45 +0000 Subject: [PATCH 39/42] Remove `make dev` in favor of `pip install -e .[dev]` --- CONTRIBUTING.md | 4 ++-- Makefile | 7 ------- README.md | 2 +- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d718736307..13983328ea 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -33,7 +33,7 @@ For something slightly more challenging, you can also take a look at the [Good S Before you start contributing make sure you have installed all the dev tools: ```bash -make dev +pip install -e .[dev] ``` ## Fixing outstanding issues @@ -152,7 +152,7 @@ Follow these steps to start contributing: 4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library: ```bash - $ make dev + $ pip install -e .[dev] ``` (If TRL was already installed in the virtual environment, remove diff --git a/Makefile b/Makefile index 704cacbff2..cb913374c0 100644 --- a/Makefile +++ b/Makefile @@ -5,13 +5,6 @@ check_dirs := examples tests trl ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs COMMAND_FILES_PATH = `pwd`/commands - -dev: - @if [ -L "$(pwd)/trl/commands/scripts" ]; then unlink "$(pwd)/trl/commands/scripts"; fi - @if [ -e "$(pwd)/trl/commands/scripts" ] && [ ! -L "$(pwd)/trl/commands/scripts" ]; then rm -rf "$(pwd)/trl/commands/scripts"; fi - pip install -e ".[dev]" - ln -s `pwd`/examples/scripts/ `pwd`/trl/commands - test: python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' ./tests/ diff --git a/README.md b/README.md index f9895fe2eb..d28607fae3 100644 --- a/README.md +++ b/README.md @@ -198,7 +198,7 @@ If you want to contribute to `trl` or customize it to your needs make sure to re ```bash git clone https://github.com/huggingface/trl.git cd trl/ -make dev +pip install -e .[dev] ``` ## Citation From 98458a0188de4e24d017aed5553a329b0bc73e84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 13 Dec 2024 15:45:58 +0000 Subject: [PATCH 40/42] Update script paths and remove old symlink related things --- .gitignore | 3 --- commands/run_dpo.sh | 2 +- commands/run_sft.sh | 2 +- docs/source/clis.mdx | 4 ++-- docs/source/dpo_trainer.mdx | 4 ++-- docs/source/example_overview.md | 5 +---- docs/source/kto_trainer.mdx | 4 ++-- docs/source/lora_tuning_peft.mdx | 2 +- docs/source/sft_trainer.mdx | 2 +- examples/scripts/kto.py | 4 ++-- setup.py | 4 +--- trl/scripts/dpo.py | 2 +- trl/scripts/kto.py | 4 ++-- trl/scripts/sft.py | 4 ++-- 14 files changed, 19 insertions(+), 27 deletions(-) diff --git a/.gitignore b/.gitignore index 30ae97a160..19b9bb4284 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,3 @@ checklink/cookies.txt nbs/wandb/ examples/notebooks/wandb/ wandb/ - -# cli scripts that are symlinked from `examples/scripts` -trl/commands/scripts/ \ No newline at end of file diff --git a/commands/run_dpo.sh b/commands/run_dpo.sh index b394df5b65..f34b12cbb1 100644 --- a/commands/run_dpo.sh +++ b/commands/run_dpo.sh @@ -35,7 +35,7 @@ CMD=""" accelerate launch $EXTRA_ACCELERATE_ARGS \ --num_processes $NUM_GPUS \ --mixed_precision 'fp16' \ - `pwd`/examples/scripts/dpo.py \ + `pwd`/trl/scripts/dpo.py \ --model_name_or_path $MODEL_NAME \ --dataset_name $DATASET_NAME \ --output_dir $OUTPUT_DIR \ diff --git a/commands/run_sft.sh b/commands/run_sft.sh index f564370ab4..bdea77fcb6 100644 --- a/commands/run_sft.sh +++ b/commands/run_sft.sh @@ -36,7 +36,7 @@ CMD=""" accelerate launch $EXTRA_ACCELERATE_ARGS \ --num_processes $NUM_GPUS \ --mixed_precision 'fp16' \ - `pwd`/examples/scripts/sft.py \ + `pwd`/trl/scripts/sft.py \ --model_name $MODEL_NAME \ --dataset_name $DATASET_NAME \ --output_dir $OUTPUT_DIR \ diff --git a/docs/source/clis.mdx b/docs/source/clis.mdx index 016b7bc647..56220d612d 100644 --- a/docs/source/clis.mdx +++ b/docs/source/clis.mdx @@ -64,7 +64,7 @@ Follow the basic instructions above and run `trl sft --output_dir < trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb ``` -The SFT CLI is based on the `examples/scripts/sft.py` script. +The SFT CLI is based on the `trl/scripts/sft.py` script. ### Direct Policy Optimization (DPO) @@ -87,7 +87,7 @@ trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --datase ``` -The DPO CLI is based on the `examples/scripts/dpo.py` script. +The DPO CLI is based on the `trl/scripts/dpo.py` script. #### Custom preference dataset diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index 068f18b312..78fb391240 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -112,12 +112,12 @@ For a complete example of fine-tuning a vision-language model, refer to the scri ## Example script -We provide an example script to train a model using the DPO method. The script is available in [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) +We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command: ```bash -accelerate launch examples/scripts/dpo.py \ +accelerate launch trl/scripts/dpo.py \ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ --dataset_name trl-lib/ultrafeedback_binarized \ --num_train_epochs 1 \ diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index d239199810..f4260edb75 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -31,7 +31,7 @@ Then, it is encouraged to launch jobs with `accelerate launch`! # Maintained Examples - +Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly. | File | Description | | ----------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | @@ -41,13 +41,10 @@ Then, it is encouraged to launch jobs with `accelerate launch`! | [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | | [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. | | [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. | -| [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a stable to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | -| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. | | [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | | [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language | | [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. | | [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. | -| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model or adapters into a target dataset. | | [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. | Here are also some easier-to-run colab notebooks that you can use to get started with TRL: diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx index 1ed6a33613..7b79268410 100644 --- a/docs/source/kto_trainer.mdx +++ b/docs/source/kto_trainer.mdx @@ -80,12 +80,12 @@ In theory, the dataset should contain at least one chosen and one rejected compl ## Example script -We provide an example script to train a model using the KTO method. The script is available in [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) +We provide an example script to train a model using the KTO method. The script is available in [`trl/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/kto.py) To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command: ```bash -accelerate launch examples/scripts/kto.py \ +accelerate launch trl/scripts/kto.py \ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ --dataset_name trl-lib/kto-mix-14k \ --num_train_epochs 1 \ diff --git a/docs/source/lora_tuning_peft.mdx b/docs/source/lora_tuning_peft.mdx index 531ee0fcd7..8906107c8e 100644 --- a/docs/source/lora_tuning_peft.mdx +++ b/docs/source/lora_tuning_peft.mdx @@ -140,5 +140,5 @@ python PATH_TO_SCRIPT You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB): ```bash -python examples/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2 +python trl/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2 ``` diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx index c45069d18c..4f33eff8aa 100644 --- a/docs/source/sft_trainer.mdx +++ b/docs/source/sft_trainer.mdx @@ -4,7 +4,7 @@ Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset. -Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py). +Check out a complete flexible example at [`trl/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/trl/scripts/sft.py). Experimental support for Vision Language Models is also included in the example [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py). ## Quickstart diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index 7ae26931e9..d68c0358dd 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -16,7 +16,7 @@ Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO. # Full training: -python examples/scripts/kto.py \ +python trl/scripts/kto.py \ --dataset_name trl-lib/kto-mix-14k \ --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ --per_device_train_batch_size 16 \ @@ -33,7 +33,7 @@ --logging_first_step # QLoRA: -python examples/scripts/kto.py \ +python trl/scripts/kto.py \ --dataset_name trl-lib/kto-mix-14k \ --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ --per_device_train_batch_size 8 \ diff --git a/setup.py b/setup.py index 62b164a776..99da9c2ac0 100644 --- a/setup.py +++ b/setup.py @@ -119,9 +119,7 @@ "console_scripts": ["trl=trl.cli:main"], }, include_package_data=True, - package_data={ - "trl": ["commands/scripts/config/*", "commands/scripts/*", "templates/*.md"], - }, + package_data={"trl": ["templates/*.md"],}, packages=find_packages(exclude={"tests", "tests.slow"}), install_requires=REQUIRED_PKGS, extras_require=EXTRAS, diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index da45055a1e..69b779e391 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -29,7 +29,7 @@ --no_remove_unused_columns # LoRA: -python examples/scripts/dpo.py \ +python trl/scripts/dpo.py \ --dataset_name trl-lib/ultrafeedback_binarized \ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ --learning_rate 5.0e-6 \ diff --git a/trl/scripts/kto.py b/trl/scripts/kto.py index 75aa8c5081..9eb44ba09f 100644 --- a/trl/scripts/kto.py +++ b/trl/scripts/kto.py @@ -16,7 +16,7 @@ Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO. # Full training: -python examples/scripts/kto.py \ +python trl/scripts/kto.py \ --dataset_name trl-lib/kto-mix-14k \ --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ --per_device_train_batch_size 16 \ @@ -33,7 +33,7 @@ --logging_first_step # QLoRA: -python examples/scripts/kto.py \ +python trl/scripts/kto.py \ --dataset_name trl-lib/kto-mix-14k \ --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ --per_device_train_batch_size 8 \ diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index ad333b364c..f457b1dc68 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -14,7 +14,7 @@ """ # Full training -python examples/scripts/sft.py \ +python trl/scripts/sft.py \ --model_name_or_path Qwen/Qwen2-0.5B \ --dataset_name trl-lib/Capybara \ --learning_rate 2.0e-5 \ @@ -30,7 +30,7 @@ --push_to_hub # LoRA -python examples/scripts/sft.py \ +python trl/scripts/sft.py \ --model_name_or_path Qwen/Qwen2-0.5B \ --dataset_name trl-lib/Capybara \ --learning_rate 2.0e-4 \ From 65f31f6289d596b44bf19ff39149a9d1dadb25a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 13 Dec 2024 16:05:42 +0000 Subject: [PATCH 41/42] Fix chat script path [ci skip] --- docs/source/clis.mdx | 2 -- docs/source/example_overview.md | 1 - 2 files changed, 3 deletions(-) diff --git a/docs/source/clis.mdx b/docs/source/clis.mdx index 56220d612d..9c7a2dfca8 100644 --- a/docs/source/clis.mdx +++ b/docs/source/clis.mdx @@ -123,8 +123,6 @@ Besides talking to the model there are a few commands you can use: - `save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided - `exit`: closes the interface -The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONFIG_FILE` where you can also specify the default generation parameters. - ## Getting the system information You can get the system information by running the following command: diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index f4260edb75..e7e3575762 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -37,7 +37,6 @@ Scripts can be used as examples of how to use TRL trainers. They are located in | ----------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. | | [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. | -| [`examples/scripts/chat.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/chat.py) | This script allows you to load and use a model as a chatbot. | | [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | | [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. | | [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. | From 3a3be533af62d654328f48762e2b5b4ac1b359c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 13 Dec 2024 16:06:52 +0000 Subject: [PATCH 42/42] style --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 99da9c2ac0..28d483e0c8 100644 --- a/setup.py +++ b/setup.py @@ -119,7 +119,9 @@ "console_scripts": ["trl=trl.cli:main"], }, include_package_data=True, - package_data={"trl": ["templates/*.md"],}, + package_data={ + "trl": ["templates/*.md"], + }, packages=find_packages(exclude={"tests", "tests.slow"}), install_requires=REQUIRED_PKGS, extras_require=EXTRAS,