diff --git a/.gitignore b/.gitignore index 30ae97a160..19b9bb4284 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,3 @@ checklink/cookies.txt nbs/wandb/ examples/notebooks/wandb/ wandb/ - -# cli scripts that are symlinked from `examples/scripts` -trl/commands/scripts/ \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d718736307..13983328ea 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -33,7 +33,7 @@ For something slightly more challenging, you can also take a look at the [Good S Before you start contributing make sure you have installed all the dev tools: ```bash -make dev +pip install -e .[dev] ``` ## Fixing outstanding issues @@ -152,7 +152,7 @@ Follow these steps to start contributing: 4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library: ```bash - $ make dev + $ pip install -e .[dev] ``` (If TRL was already installed in the virtual environment, remove diff --git a/Makefile b/Makefile index 704cacbff2..cb913374c0 100644 --- a/Makefile +++ b/Makefile @@ -5,13 +5,6 @@ check_dirs := examples tests trl ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs COMMAND_FILES_PATH = `pwd`/commands - -dev: - @if [ -L "$(pwd)/trl/commands/scripts" ]; then unlink "$(pwd)/trl/commands/scripts"; fi - @if [ -e "$(pwd)/trl/commands/scripts" ] && [ ! -L "$(pwd)/trl/commands/scripts" ]; then rm -rf "$(pwd)/trl/commands/scripts"; fi - pip install -e ".[dev]" - ln -s `pwd`/examples/scripts/ `pwd`/trl/commands - test: python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' ./tests/ diff --git a/README.md b/README.md index f9895fe2eb..d28607fae3 100644 --- a/README.md +++ b/README.md @@ -198,7 +198,7 @@ If you want to contribute to `trl` or customize it to your needs make sure to re ```bash git clone https://github.com/huggingface/trl.git cd trl/ -make dev +pip install -e .[dev] ``` ## Citation diff --git a/commands/run_dpo.sh b/commands/run_dpo.sh index b394df5b65..f34b12cbb1 100644 --- a/commands/run_dpo.sh +++ b/commands/run_dpo.sh @@ -35,7 +35,7 @@ CMD=""" accelerate launch $EXTRA_ACCELERATE_ARGS \ --num_processes $NUM_GPUS \ --mixed_precision 'fp16' \ - `pwd`/examples/scripts/dpo.py \ + `pwd`/trl/scripts/dpo.py \ --model_name_or_path $MODEL_NAME \ --dataset_name $DATASET_NAME \ --output_dir $OUTPUT_DIR \ diff --git a/commands/run_sft.sh b/commands/run_sft.sh index f564370ab4..bdea77fcb6 100644 --- a/commands/run_sft.sh +++ b/commands/run_sft.sh @@ -36,7 +36,7 @@ CMD=""" accelerate launch $EXTRA_ACCELERATE_ARGS \ --num_processes $NUM_GPUS \ --mixed_precision 'fp16' \ - `pwd`/examples/scripts/sft.py \ + `pwd`/trl/scripts/sft.py \ --model_name $MODEL_NAME \ --dataset_name $DATASET_NAME \ --output_dir $OUTPUT_DIR \ diff --git a/docs/source/clis.mdx b/docs/source/clis.mdx index 0e600e5d99..9c7a2dfca8 100644 --- a/docs/source/clis.mdx +++ b/docs/source/clis.mdx @@ -4,8 +4,14 @@ You can use TRL to fine-tune your Language Model with Supervised Fine-Tuning (SF Currently supported CLIs are: -- `trl sft`: fine-tune a LLM on a text/instruction dataset -- `trl dpo`: fine-tune a LLM with DPO on a preference dataset +#### Training commands + +- `trl dpo`: fine-tune a LLM with DPO +- `trl kto`: fine-tune a LLM with KTO +- `trl sft`: fine-tune a LLM with SFT + +#### Other commands + - `trl chat`: quickly spin up a LLM fine-tuned for chatting - `trl env`: get the 
system information @@ -58,7 +64,7 @@ Follow the basic instructions above and run `trl sft --output_dir < trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb ``` -The SFT CLI is based on the `examples/scripts/sft.py` script. +The SFT CLI is based on the `trl/scripts/sft.py` script. ### Direct Policy Optimization (DPO) @@ -81,7 +87,7 @@ trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --datase ``` -The DPO CLI is based on the `examples/scripts/dpo.py` script. +The DPO CLI is based on the `trl/scripts/dpo.py` script. #### Custom preference dataset @@ -117,8 +123,6 @@ Besides talking to the model there are a few commands you can use: - `save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided - `exit`: closes the interface -The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONFIG_FILE` where you can also specify the default generation parameters. - ## Getting the system information You can get the system information by running the following command: diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index 068f18b312..78fb391240 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -112,12 +112,12 @@ For a complete example of fine-tuning a vision-language model, refer to the scri ## Example script -We provide an example script to train a model using the DPO method. The script is available in [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) +We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command: ```bash -accelerate launch examples/scripts/dpo.py \ +accelerate launch trl/scripts/dpo.py \ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ --dataset_name trl-lib/ultrafeedback_binarized \ --num_train_epochs 1 \ diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index d239199810..e7e3575762 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -31,23 +31,19 @@ Then, it is encouraged to launch jobs with `accelerate launch`! # Maintained Examples - +Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly. 
| File | Description | | ----------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. | | [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. | -| [`examples/scripts/chat.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/chat.py) | This script allows you to load and use a model as a chatbot. | | [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | | [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. | | [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. | -| [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a stable to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | -| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. | | [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | | [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language | | [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. 
| | [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. | -| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model or adapters into a target dataset. | | [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. | Here are also some easier-to-run colab notebooks that you can use to get started with TRL: diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx index 1ed6a33613..7b79268410 100644 --- a/docs/source/kto_trainer.mdx +++ b/docs/source/kto_trainer.mdx @@ -80,12 +80,12 @@ In theory, the dataset should contain at least one chosen and one rejected compl ## Example script -We provide an example script to train a model using the KTO method. The script is available in [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) +We provide an example script to train a model using the KTO method. The script is available in [`trl/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/kto.py) To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command: ```bash -accelerate launch examples/scripts/kto.py \ +accelerate launch trl/scripts/kto.py \ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ --dataset_name trl-lib/kto-mix-14k \ --num_train_epochs 1 \ diff --git a/docs/source/lora_tuning_peft.mdx b/docs/source/lora_tuning_peft.mdx index 531ee0fcd7..8906107c8e 100644 --- a/docs/source/lora_tuning_peft.mdx +++ b/docs/source/lora_tuning_peft.mdx @@ -140,5 +140,5 @@ python PATH_TO_SCRIPT You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB): ```bash -python examples/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2 +python trl/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2 ``` diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx index c45069d18c..4f33eff8aa 100644 --- a/docs/source/sft_trainer.mdx +++ b/docs/source/sft_trainer.mdx @@ -4,7 +4,7 @@ Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset. 
-Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py). +Check out a complete flexible example at [`trl/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/trl/scripts/sft.py). Experimental support for Vision Language Models is also included in the example [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py). ## Quickstart diff --git a/examples/scripts/chat.py b/examples/scripts/chat.py index 12e7c448d4..b81f3a3339 100644 --- a/examples/scripts/chat.py +++ b/examples/scripts/chat.py @@ -12,363 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy -import json -import os -import platform -import re -import sys -import time -from threading import Thread - -import torch -from rich.console import Console -from rich.live import Live -from rich.markdown import Markdown -from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer - -from trl import TrlParser, init_zero_verbose -from trl.commands.cli_utils import ChatArguments -from trl.trainer.utils import get_quantization_config - - -if platform.system() != "Windows": - import pwd - - -init_zero_verbose() - -HELP_STRING = """\ - -**TRL CHAT INTERFACE** - -The chat interface is a simple tool to try out a chat model. - -Besides talking to the model there are several commands: -- **clear**: clears the current conversation and start a new one -- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input -- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';'). -- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set** -- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided -- **exit**: closes the interface -""" - -SUPPORTED_GENERATION_KWARGS = [ - "max_new_tokens", - "do_sample", - "num_beams", - "temperature", - "top_p", - "top_k", - "repetition_penalty", -] - -SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$" - - -class RichInterface: - def __init__(self, model_name=None, user_name=None): - self._console = Console() - if model_name is None: - self.model_name = "assistant" - else: - self.model_name = model_name - if user_name is None: - self.user_name = "user" - else: - self.user_name = user_name - - def stream_output(self, output_stream): - """Stream output from a role.""" - # This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py - # Create a Live context for updating the console output - text = "" - self._console.print(f"[bold blue]<{self.model_name}>:") - with Live(console=self._console, refresh_per_second=4) as live: - # Read lines from the stream - for i, outputs in enumerate(output_stream): - if not outputs or i == 0: - continue - text += outputs - # Render the accumulated text as Markdown - # NOTE: this is a workaround for the rendering "unstandard markdown" - # in rich. The chatbots output treat "\n" as a new line for - # better compatibility with real-world text. However, rendering - # in markdown would break the format. 
It is because standard markdown - # treat a single "\n" in normal text as a space. - # Our workaround is adding two spaces at the end of each line. - # This is not a perfect solution, as it would - # introduce trailing spaces (only) in code block, but it works well - # especially for console output, because in general the console does not - # care about trailing spaces. - lines = [] - for line in text.splitlines(): - lines.append(line) - if line.startswith("```"): - # Code block marker - do not add trailing spaces, as it would - # break the syntax highlighting - lines.append("\n") - else: - lines.append(" \n") - markdown = Markdown("".join(lines).strip(), code_theme="github-dark") - # Update the Live console output - live.update(markdown) - self._console.print() - return text - - def input(self): - input = self._console.input(f"[bold red]<{self.user_name}>:\n") - self._console.print() - return input - - def clear(self): - self._console.clear() - - def print_user_message(self, text): - self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}") - self._console.print() - - def print_green(self, text): - self._console.print(f"[bold green]{text}") - self._console.print() - - def print_red(self, text): - self._console.print(f"[bold red]{text}") - self._console.print() - - def print_help(self): - self._console.print(Markdown(HELP_STRING)) - self._console.print() - - -def get_username(): - if platform.system() == "Windows": - return os.getlogin() - else: - return pwd.getpwuid(os.getuid()).pw_name - - -def create_default_filename(model_name): - time_str = time.strftime("%Y-%m-%d_%H-%M-%S") - return f"{model_name}/chat_{time_str}.json" - - -def save_chat(chat, args, filename): - output_dict = {} - output_dict["settings"] = vars(args) - output_dict["chat_history"] = chat - - folder = args.save_folder - - if filename is None: - filename = create_default_filename(args.model_name_or_path) - filename = os.path.join(folder, filename) - os.makedirs(os.path.dirname(filename), exist_ok=True) - - with open(filename, "w") as f: - json.dump(output_dict, f, indent=4) - return os.path.abspath(filename) - - -def clear_chat_history(system_prompt): - if system_prompt is None: - chat = [] - else: - chat = [{"role": "system", "content": system_prompt}] - return chat - - -def parse_settings(user_input, current_args, interface): - settings = user_input[4:].strip().split(";") - settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings] - settings = dict(settings) - error = False - - for name in settings: - if hasattr(current_args, name): - try: - if isinstance(getattr(current_args, name), bool): - if settings[name] == "True": - settings[name] = True - elif settings[name] == "False": - settings[name] = False - else: - raise ValueError - else: - settings[name] = type(getattr(current_args, name))(settings[name]) - except ValueError: - interface.print_red( - f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}." - ) - else: - interface.print_red(f"There is no '{name}' setting.") - - if error: - interface.print_red("There was an issue parsing the settings. 
No settings have been changed.") - return current_args, False - else: - for name in settings: - setattr(current_args, name, settings[name]) - interface.print_green(f"Set {name} to {settings[name]}.") - - time.sleep(1.5) # so the user has time to read the changes - return current_args, True - - -def load_model_and_tokenizer(args): - tokenizer = AutoTokenizer.from_pretrained( - args.model_name_or_path, - revision=args.model_revision, - trust_remote_code=args.trust_remote_code, - ) - - torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype) - quantization_config = get_quantization_config(args) - model_kwargs = dict( - revision=args.model_revision, - attn_implementation=args.attn_implementation, - torch_dtype=torch_dtype, - device_map="auto", - quantization_config=quantization_config, - ) - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs - ) - - if getattr(model, "hf_device_map", None) is None: - model = model.to(args.device) - - return model, tokenizer - - -def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids): - if tokenizer.pad_token_id is None: - pad_token_id = tokenizer.eos_token_id - else: - pad_token_id = tokenizer.pad_token_id - - all_eos_token_ids = [] - - if eos_tokens is not None: - all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(","))) - - if eos_token_ids is not None: - all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")]) - - if len(all_eos_token_ids) == 0: - all_eos_token_ids.append(tokenizer.eos_token_id) - - return pad_token_id, all_eos_token_ids - - -def chat_cli(): - parser = TrlParser(ChatArguments) - - if "--config" not in sys.argv: - sys.argv.append("--config") - sys.argv.append(os.path.join(os.path.dirname(__file__), "config/default_chat_config.yaml")) - args = parser.parse_args_and_config()[0] - if args.examples is None: - args.examples = {} - - current_args = copy.deepcopy(args) - - if args.user is None: - user = get_username() - else: - user = args.user - - model, tokenizer = load_model_and_tokenizer(args) - generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) - - pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids) - - interface = RichInterface(model_name=args.model_name_or_path, user_name=user) - interface.clear() - chat = clear_chat_history(current_args.system_prompt) - while True: - try: - user_input = interface.input() - - if user_input == "clear": - chat = clear_chat_history(current_args.system_prompt) - interface.clear() - continue - - if user_input == "help": - interface.print_help() - continue - - if user_input == "exit": - break - - if user_input == "reset": - interface.clear() - current_args = copy.deepcopy(args) - chat = clear_chat_history(current_args.system_prompt) - continue - - if user_input.startswith("save") and len(user_input.split()) < 2: - split_input = user_input.split() - - if len(split_input) == 2: - filename = split_input[1] - else: - filename = None - filename = save_chat(chat, current_args, filename) - interface.print_green(f"Chat saved in {filename}!") - continue - - if re.match(SETTING_RE, user_input): - current_args, success = parse_settings(user_input, current_args, interface) - if success: - chat = [] - interface.clear() - continue - - if user_input.startswith("example") and len(user_input.split()) == 2: - example_name = user_input.split()[1] - if example_name 
in current_args.examples: - interface.clear() - chat = [] - interface.print_user_message(current_args.examples[example_name]["text"]) - user_input = current_args.examples[example_name]["text"] - else: - interface.print_red( - f"Example {example_name} not found in list of available examples: {list(current_args.examples.keys())}." - ) - continue - - chat.append({"role": "user", "content": user_input}) - - inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( - model.device - ) - attention_mask = torch.ones_like(inputs) - generation_kwargs = dict( - inputs=inputs, - attention_mask=attention_mask, - streamer=generation_streamer, - max_new_tokens=current_args.max_new_tokens, - do_sample=current_args.do_sample, - num_beams=current_args.num_beams, - temperature=current_args.temperature, - top_k=current_args.top_k, - top_p=current_args.top_p, - repetition_penalty=current_args.repetition_penalty, - pad_token_id=pad_token_id, - eos_token_id=eos_token_ids, - ) - - thread = Thread(target=model.generate, kwargs=generation_kwargs) - thread.start() - model_output = interface.stream_output(generation_streamer) - thread.join() - chat.append({"role": "assistant", "content": model_output}) - - except KeyboardInterrupt: - break - - -if __name__ == "__main__": - chat_cli() +################################################################################################ +# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/chat.py # +################################################################################################ diff --git a/examples/scripts/config/default_chat_config.yaml b/examples/scripts/config/default_chat_config.yaml deleted file mode 100644 index 93195f9d7d..0000000000 --- a/examples/scripts/config/default_chat_config.yaml +++ /dev/null @@ -1,13 +0,0 @@ -examples: - llama: - text: There is a Llama in my lawn, how can I get rid of it? - code: - text: Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end]. - helicopter: - text: How many helicopters can a human eat in one sitting? - numbers: - text: Count to 10 but skip every number ending with an 'e' - birds: - text: Why aren't birds real? - socks: - text: Why is it important to eat socks after meditating? \ No newline at end of file diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 08b0f18db7..97425d3ef0 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -12,126 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""" -# Full training -python examples/scripts/dpo.py \ - --dataset_name trl-lib/ultrafeedback_binarized \ - --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ - --learning_rate 5.0e-7 \ - --num_train_epochs 1 \ - --per_device_train_batch_size 2 \ - --gradient_accumulation_steps 8 \ - --gradient_checkpointing \ - --logging_steps 25 \ - --eval_strategy steps \ - --eval_steps 50 \ - --output_dir Qwen2-0.5B-DPO \ - --no_remove_unused_columns - -# LoRA: -python examples/scripts/dpo.py \ - --dataset_name trl-lib/ultrafeedback_binarized \ - --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ - --learning_rate 5.0e-6 \ - --num_train_epochs 1 \ - --per_device_train_batch_size 2 \ - --gradient_accumulation_steps 8 \ - --gradient_checkpointing \ - --logging_steps 25 \ - --eval_strategy steps \ - --eval_steps 50 \ - --output_dir Qwen2-0.5B-DPO \ - --no_remove_unused_columns \ - --use_peft \ - --lora_r 32 \ - --lora_alpha 16 -""" - -import torch -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer - -from trl import ( - DPOConfig, - DPOTrainer, - ModelConfig, - ScriptArguments, - TrlParser, - get_kbit_device_map, - get_peft_config, - get_quantization_config, -) -from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE - - -if __name__ == "__main__": - parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) - script_args, training_args, model_args = parser.parse_args_and_config() - - ################ - # Model & Tokenizer - ################### - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) - quantization_config = get_quantization_config(model_args) - model_kwargs = dict( - revision=model_args.model_revision, - attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, - use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, - ) - model = AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs - ) - peft_config = get_peft_config(model_args) - if peft_config is None: - ref_model = AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs - ) - else: - ref_model = None - tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code - ) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - if tokenizer.chat_template is None: - tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE - if script_args.ignore_bias_buffers: - # torch distributed hack - model._ddp_params_and_buffers_to_ignore = [ - name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool - ] - - ################ - # Dataset - ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) - - ########## - # Training - ################ - trainer = DPOTrainer( - model, - ref_model, - args=training_args, - train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, - processing_class=tokenizer, - peft_config=peft_config, - ) - - trainer.train() - - if training_args.eval_strategy != "no": - metrics = trainer.evaluate() - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", 
metrics) - - # Save and push to hub - trainer.save_model(training_args.output_dir) - if training_args.push_to_hub: - trainer.push_to_hub(dataset_name=script_args.dataset_name) +############################################################################################### +# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py # +############################################################################################### diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index e093aa4d9d..38023b1459 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -77,9 +77,7 @@ else: ref_model = None processor = AutoProcessor.from_pretrained( - model_args.model_name_or_path, - trust_remote_code=model_args.trust_remote_code, - do_image_splitting=False, + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, do_image_splitting=False ) tokenizer = processor.tokenizer diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index 7ae26931e9..d68c0358dd 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -16,7 +16,7 @@ Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO. # Full training: -python examples/scripts/kto.py \ +python trl/scripts/kto.py \ --dataset_name trl-lib/kto-mix-14k \ --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ --per_device_train_batch_size 16 \ @@ -33,7 +33,7 @@ --logging_first_step # QLoRA: -python examples/scripts/kto.py \ +python trl/scripts/kto.py \ --dataset_name trl-lib/kto-mix-14k \ --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ --per_device_train_batch_size 8 \ diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index 71430bc536..eb17f728ae 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -110,9 +110,7 @@ judge = None tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, - padding_side="left", - trust_remote_code=model_args.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 9de5135635..2758c950c5 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -89,9 +89,7 @@ ) tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, - padding_side="left", - trust_remote_code=model_args.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index d0cd399a89..353a1493e3 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -96,9 +96,7 @@ ) tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, - padding_side="left", - trust_remote_code=model_args.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index 95eff811d4..85c443b7ae 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -71,9 +71,7 @@ # Model & 
Tokenizer ################ tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, - padding_side="left", - trust_remote_code=model_args.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 6ac7a6e86c..cf4265e921 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -73,9 +73,7 @@ # Model & Tokenizer ################ tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, - padding_side="left", - trust_remote_code=model_args.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index 4a73268977..4b43634d47 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -12,101 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -# Full training -python examples/scripts/sft.py \ - --model_name_or_path Qwen/Qwen2-0.5B \ - --dataset_name trl-lib/Capybara \ - --learning_rate 2.0e-5 \ - --num_train_epochs 1 \ - --packing \ - --per_device_train_batch_size 2 \ - --gradient_accumulation_steps 8 \ - --gradient_checkpointing \ - --logging_steps 25 \ - --eval_strategy steps \ - --eval_steps 100 \ - --output_dir Qwen2-0.5B-SFT \ - --push_to_hub - -# LoRA -python examples/scripts/sft.py \ - --model_name_or_path Qwen/Qwen2-0.5B \ - --dataset_name trl-lib/Capybara \ - --learning_rate 2.0e-4 \ - --num_train_epochs 1 \ - --packing \ - --per_device_train_batch_size 2 \ - --gradient_accumulation_steps 8 \ - --gradient_checkpointing \ - --logging_steps 25 \ - --eval_strategy steps \ - --eval_steps 100 \ - --use_peft \ - --lora_r 32 \ - --lora_alpha 16 \ - --output_dir Qwen2-0.5B-SFT \ - --push_to_hub -""" - -from datasets import load_dataset -from transformers import AutoTokenizer - -from trl import ( - ModelConfig, - ScriptArguments, - SFTConfig, - SFTTrainer, - TrlParser, - get_kbit_device_map, - get_peft_config, - get_quantization_config, -) - - -if __name__ == "__main__": - parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) - script_args, training_args, model_args = parser.parse_args_and_config() - - ################ - # Model init kwargs & Tokenizer - ################ - quantization_config = get_quantization_config(model_args) - model_kwargs = dict( - revision=model_args.model_revision, - trust_remote_code=model_args.trust_remote_code, - attn_implementation=model_args.attn_implementation, - torch_dtype=model_args.torch_dtype, - use_cache=False if training_args.gradient_checkpointing else True, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, - ) - training_args.model_init_kwargs = model_kwargs - tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True - ) - tokenizer.pad_token = tokenizer.eos_token - - ################ - # Dataset - ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) - - ################ - # Training - ################ - trainer = SFTTrainer( - 
model=model_args.model_name_or_path, - args=training_args, - train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, - processing_class=tokenizer, - peft_config=get_peft_config(model_args), - ) - - trainer.train() - - # Save and push to hub - trainer.save_model(training_args.output_dir) - if training_args.push_to_hub: - trainer.push_to_hub(dataset_name=script_args.dataset_name) +############################################################################################### +# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py # +############################################################################################### diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index b30241cb02..726b457b2e 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -95,9 +95,7 @@ judge = None tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, - padding_side="left", - trust_remote_code=model_args.trust_remote_code, + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/setup.py b/setup.py index 01b9fdb485..28d483e0c8 100644 --- a/setup.py +++ b/setup.py @@ -68,8 +68,6 @@ Then push the change with a message 'set dev version' """ -import os - from setuptools import find_packages, setup @@ -99,44 +97,41 @@ for reqs in EXTRAS.values(): EXTRAS["dev"].extend(reqs) -try: - file_path = os.path.dirname(os.path.abspath(__file__)) - os.symlink(os.path.join(file_path, "examples/scripts"), os.path.join(file_path, "trl/commands/scripts")) - - setup( - name="trl", - license="Apache 2.0", - classifiers=[ - "Development Status :: 2 - Pre-Alpha", - "Intended Audience :: Developers", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: Apache Software License", - "Natural Language :: English", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - ], - url="https://github.com/huggingface/trl", - entry_points={ - "console_scripts": ["trl=trl.commands.cli:main"], - }, - include_package_data=True, - package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*", "templates/*.md"]}, - packages=find_packages(exclude={"tests"}), - install_requires=REQUIRED_PKGS, - extras_require=EXTRAS, - python_requires=">=3.9", - long_description=open("README.md", encoding="utf-8").read(), - long_description_content_type="text/markdown", - zip_safe=False, - version=__version__, - description="Train transformer language models with reinforcement learning.", - keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf", - author="Leandro von Werra", - author_email="leandro.vonwerra@gmail.com", - ) -finally: - os.unlink(os.path.join(file_path, "trl/commands/scripts")) + +setup( + name="trl", + license="Apache 2.0", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming 
Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + ], + url="https://github.com/huggingface/trl", + entry_points={ + "console_scripts": ["trl=trl.cli:main"], + }, + include_package_data=True, + package_data={ + "trl": ["templates/*.md"], + }, + packages=find_packages(exclude={"tests", "tests.slow"}), + install_requires=REQUIRED_PKGS, + extras_require=EXTRAS, + python_requires=">=3.9", + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + zip_safe=False, + version=__version__, + description="Train transformer language models with reinforcement learning.", + keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf", + author="Leandro von Werra", + author_email="leandro.vonwerra@gmail.com", +) diff --git a/tests/test_cli.py b/tests/test_cli.py index 8330f5d5f8..f7719ecc16 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,34 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import sys + +import tempfile import unittest +from io import StringIO +from unittest.mock import patch + +from trl.cli import main + + +class TestCLI(unittest.TestCase): + def test_dpo(self): + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + command = f"trl dpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_preference --report_to none" + with patch("sys.argv", command.split(" ")): + main() + + @patch("sys.stdout", new_callable=StringIO) + def test_env(self, mock_stdout): + command = "trl env" + with patch("sys.argv", command.split(" ")): + main() + self.assertIn("TRL version: ", mock_stdout.getvalue().strip()) + + def test_kto(self): + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + command = f"trl kto --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_unpaired_preference --report_to none" + with patch("sys.argv", command.split(" ")): + main() + + def test_sft(self): + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + command = f"trl sft --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_language_modeling --report_to none" + with patch("sys.argv", command.split(" ")): + main() -class CLITester(unittest.TestCase): - @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") - def test_sft_cli(self): - try: - subprocess.run( - "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name stanfordnlp/imdb --learning_rate 1e-4 --lr_scheduler_type cosine", - shell=True, - check=True, - ) - except BaseException: - self.fail("An error occurred while running the CLI, please double check") - - @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") - def test_dpo_cli(self): - try: - subprocess.run( - "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine", - shell=True, - check=True, - ) - except BaseException: - self.fail("An 
error occurred while running the CLI, please double check") - - def test_env_cli(self): - output = subprocess.run("trl env", capture_output=True, text=True, shell=True, check=True) - self.assertIn("- Python version: ", output.stdout) +if __name__ == "__main__": + unittest.main() diff --git a/trl/__init__.py b/trl/__init__.py index b05f75cd8b..a27b00305a 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -20,7 +20,7 @@ _import_structure = { - "commands.cli_utils": ["DPOScriptArguments", "SFTScriptArguments", "TrlParser", "init_zero_verbose"], + "scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"], "core": ["set_seed"], "data_utils": [ "apply_chat_template", @@ -96,7 +96,6 @@ ], "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"], "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], - "utils": ["ScriptArguments"], } try: @@ -116,7 +115,6 @@ _import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"]) if TYPE_CHECKING: - from .commands.cli_utils import DPOScriptArguments, SFTScriptArguments, TrlParser, init_zero_verbose from .core import set_seed from .data_utils import ( apply_chat_template, @@ -138,6 +136,7 @@ create_reference_model, setup_chat_format, ) + from .scripts import ScriptArguments, TrlParser, init_zero_verbose from .trainer import ( AlignPropConfig, AlignPropTrainer, @@ -188,7 +187,6 @@ ) from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config - from .utils import ScriptArguments try: if not is_diffusers_available(): diff --git a/trl/cli.py b/trl/cli.py new file mode 100644 index 0000000000..d5a1421c51 --- /dev/null +++ b/trl/cli.py @@ -0,0 +1,81 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sys + +from accelerate.commands.launch import launch_command, launch_command_parser + +from .scripts.chat import main as chat_main +from .scripts.chat import make_parser as make_chat_parser +from .scripts.dpo import make_parser as make_dpo_parser +from .scripts.env import print_env +from .scripts.kto import make_parser as make_kto_parser +from .scripts.sft import make_parser as make_sft_parser +from .scripts.utils import TrlParser + + +def main(): + parser = TrlParser(prog="TRL CLI", usage="trl", allow_abbrev=False) + + # Add the subparsers + subparsers = parser.add_subparsers(help="available commands", dest="command", parser_class=TrlParser) + + # Add the subparsers for every script + make_chat_parser(subparsers) + make_dpo_parser(subparsers) + subparsers.add_parser("env", help="Print the environment information") + make_kto_parser(subparsers) + make_sft_parser(subparsers) + + # Parse the arguments + args = parser.parse_args() + + if args.command == "chat": + (chat_args,) = parser.parse_args_and_config() + chat_main(chat_args) + + if args.command == "dpo": + # Get the default args for the launch command + dpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "dpo.py") + args = launch_command_parser().parse_args([dpo_training_script]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "dpo" + launch_command(args) # launch training + + elif args.command == "env": + print_env() + + elif args.command == "kto": + # Get the default args for the launch command + kto_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "kto.py") + args = launch_command_parser().parse_args([kto_training_script]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "kto" + launch_command(args) # launch training + + elif args.command == "sft": + # Get the default args for the launch command + sft_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "sft.py") + args = launch_command_parser().parse_args([sft_training_script]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "sft" + launch_command(args) # launch training + + +if __name__ == "__main__": + main() diff --git a/trl/commands/__init__.py b/trl/scripts/__init__.py similarity index 73% rename from trl/commands/__init__.py rename to trl/scripts/__init__.py index c4da312cd9..2994ed5275 100644 --- a/trl/commands/__init__.py +++ b/trl/scripts/__init__.py @@ -14,15 +14,15 @@ from typing import TYPE_CHECKING -from ..import_utils import OptionalDependencyNotAvailable, _LazyModule +from ..import_utils import _LazyModule _import_structure = { - "cli_utils": ["DPOScriptArguments", "SFTScriptArguments", "TrlParser", "YamlConfigParser", "init_zero_verbose"], + "utils": ["init_zero_verbose", "ScriptArguments", "TrlParser"], } if TYPE_CHECKING: - from .cli_utils import DPOScriptArguments, SFTScriptArguments, TrlParser, YamlConfigParser, init_zero_verbose + from .utils import ScriptArguments, TrlParser, init_zero_verbose else: import sys diff --git a/trl/scripts/chat.py b/trl/scripts/chat.py new file mode 100644 index 0000000000..fa9eebc44e --- /dev/null +++ b/trl/scripts/chat.py @@ -0,0 +1,460 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import copy +import json +import os +import platform +import re +import time +from dataclasses import dataclass, field +from threading import Thread + +import torch +import yaml +from rich.console import Console +from rich.live import Live +from rich.markdown import Markdown +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer + +from trl import TrlParser, init_zero_verbose +from trl.trainer.utils import get_quantization_config + + +if platform.system() != "Windows": + import pwd + +init_zero_verbose() + +HELP_STRING = """\ + +**TRL CHAT INTERFACE** + +The chat interface is a simple tool to try out a chat model. + +Besides talking to the model there are several commands: +- **clear**: clears the current conversation and start a new one +- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input +- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';'). +- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set** +- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided +- **exit**: closes the interface +""" + +SUPPORTED_GENERATION_KWARGS = [ + "max_new_tokens", + "do_sample", + "num_beams", + "temperature", + "top_p", + "top_k", + "repetition_penalty", +] + +SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$" + + +DEFAULT_EXAMPLES = { + "llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"}, + "code": { + "text": "Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end]." 
+ }, + "helicopter": {"text": "How many helicopters can a human eat in one sitting?"}, + "numbers": {"text": "Count to 10 but skip every number ending with an 'e'"}, + "birds": {"text": "Why aren't birds real?"}, + "socks": {"text": "Why is it important to eat socks after meditating?"}, +} + + +@dataclass +class ChatArguments: + # general settings + model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model"}) + user: str = field(default=None, metadata={"help": "Username to display in chat interface"}) + system_prompt: str = field(default=None, metadata={"help": "System prompt"}) + save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history"}) + device: str = field( + default="cpu", + metadata={"help": "device to use for inference."}, + ) + examples_path: str = field(default=None, metadata={"help": "Path to a yaml file with examples"}) + # generation settings + max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate"}) + do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation"}) + num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search"}) + temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation"}) + top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling"}) + top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling"}) + repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty"}) + eos_tokens: str = field( + default=None, + metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated"}, + ) + eos_token_ids: str = field( + default=None, + metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated"}, + ) + # model loading + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + torch_dtype: str = field( + default=None, + metadata={ + "help": ( + "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " + "dtype will be automatically derived from the model's weights." 
+ ), + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) + trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) + attn_implementation: str = field( + default=None, + metadata={ + "help": ( + "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`" + ) + }, + ) + load_in_8bit: bool = field( + default=False, + metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}, + ) + load_in_4bit: bool = field( + default=False, + metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}, + ) + + bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}) + use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) + + +class RichInterface: + def __init__(self, model_name=None, user_name=None): + self._console = Console() + if model_name is None: + self.model_name = "assistant" + else: + self.model_name = model_name + if user_name is None: + self.user_name = "user" + else: + self.user_name = user_name + + def stream_output(self, output_stream): + """Stream output from a role.""" + # This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py + # Create a Live context for updating the console output + text = "" + self._console.print(f"[bold blue]<{self.model_name}>:") + with Live(console=self._console, refresh_per_second=4) as live: + # Read lines from the stream + for i, outputs in enumerate(output_stream): + if not outputs or i == 0: + continue + text += outputs + # Render the accumulated text as Markdown + # NOTE: this is a workaround for the rendering "unstandard markdown" + # in rich. The chatbots output treat "\n" as a new line for + # better compatibility with real-world text. However, rendering + # in markdown would break the format. It is because standard markdown + # treat a single "\n" in normal text as a space. + # Our workaround is adding two spaces at the end of each line. + # This is not a perfect solution, as it would + # introduce trailing spaces (only) in code block, but it works well + # especially for console output, because in general the console does not + # care about trailing spaces. 
+                lines = []
+                for line in text.splitlines():
+                    lines.append(line)
+                    if line.startswith("```"):
+                        # Code block marker - do not add trailing spaces, as it would
+                        # break the syntax highlighting
+                        lines.append("\n")
+                    else:
+                        lines.append("  \n")
+                markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
+                # Update the Live console output
+                live.update(markdown)
+        self._console.print()
+        return text
+
+    def input(self):
+        input = self._console.input(f"[bold red]<{self.user_name}>:\n")
+        self._console.print()
+        return input
+
+    def clear(self):
+        self._console.clear()
+
+    def print_user_message(self, text):
+        self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}")
+        self._console.print()
+
+    def print_green(self, text):
+        self._console.print(f"[bold green]{text}")
+        self._console.print()
+
+    def print_red(self, text):
+        self._console.print(f"[bold red]{text}")
+        self._console.print()
+
+    def print_help(self):
+        self._console.print(Markdown(HELP_STRING))
+        self._console.print()
+
+
+def get_username():
+    if platform.system() == "Windows":
+        return os.getlogin()
+    else:
+        return pwd.getpwuid(os.getuid()).pw_name
+
+
+def create_default_filename(model_name):
+    time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
+    return f"{model_name}/chat_{time_str}.json"
+
+
+def save_chat(chat, args, filename):
+    output_dict = {}
+    output_dict["settings"] = vars(args)
+    output_dict["chat_history"] = chat
+
+    folder = args.save_folder
+
+    if filename is None:
+        filename = create_default_filename(args.model_name_or_path)
+    filename = os.path.join(folder, filename)
+    os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+    with open(filename, "w") as f:
+        json.dump(output_dict, f, indent=4)
+    return os.path.abspath(filename)
+
+
+def clear_chat_history(system_prompt):
+    if system_prompt is None:
+        chat = []
+    else:
+        chat = [{"role": "system", "content": system_prompt}]
+    return chat
+
+
+def parse_settings(user_input, current_args, interface):
+    settings = user_input[4:].strip().split(";")
+    settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
+    settings = dict(settings)
+    error = False
+
+    for name in settings:
+        if hasattr(current_args, name):
+            try:
+                if isinstance(getattr(current_args, name), bool):
+                    if settings[name] == "True":
+                        settings[name] = True
+                    elif settings[name] == "False":
+                        settings[name] = False
+                    else:
+                        raise ValueError
+                else:
+                    settings[name] = type(getattr(current_args, name))(settings[name])
+            except ValueError:
+                error = True
+                interface.print_red(
+                    f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}."
+                )
+        else:
+            error = True
+            interface.print_red(f"There is no '{name}' setting.")
+
+    if error:
+        interface.print_red("There was an issue parsing the settings. No settings have been changed.")
+        return current_args, False
+    else:
+        for name in settings:
+            setattr(current_args, name, settings[name])
+            interface.print_green(f"Set {name} to {settings[name]}.")
+
+        time.sleep(1.5)  # so the user has time to read the changes
+        return current_args, True
+
+
+def load_model_and_tokenizer(args):
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model_name_or_path,
+        revision=args.model_revision,
+        trust_remote_code=args.trust_remote_code,
+    )
+
+    torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
+    quantization_config = get_quantization_config(args)
+    model_kwargs = dict(
+        revision=args.model_revision,
+        attn_implementation=args.attn_implementation,
+        torch_dtype=torch_dtype,
+        device_map="auto",
+        quantization_config=quantization_config,
+    )
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
+    )
+
+    if getattr(model, "hf_device_map", None) is None:
+        model = model.to(args.device)
+
+    return model, tokenizer
+
+
+def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids):
+    if tokenizer.pad_token_id is None:
+        pad_token_id = tokenizer.eos_token_id
+    else:
+        pad_token_id = tokenizer.pad_token_id
+
+    all_eos_token_ids = []
+
+    if eos_tokens is not None:
+        all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(",")))
+
+    if eos_token_ids is not None:
+        all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
+
+    if len(all_eos_token_ids) == 0:
+        all_eos_token_ids.append(tokenizer.eos_token_id)
+
+    return pad_token_id, all_eos_token_ids
+
+
+def main(args: ChatArguments):
+    if args.examples_path is None:
+        examples = DEFAULT_EXAMPLES
+    else:
+        with open(args.examples_path) as f:
+            examples = yaml.safe_load(f)
+
+    current_args = copy.deepcopy(args)
+
+    if args.user is None:
+        user = get_username()
+    else:
+        user = args.user
+
+    model, tokenizer = load_model_and_tokenizer(args)
+    generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
+
+    pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
+
+    interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
+    interface.clear()
+    chat = clear_chat_history(current_args.system_prompt)
+    while True:
+        try:
+            user_input = interface.input()
+
+            if user_input == "clear":
+                chat = clear_chat_history(current_args.system_prompt)
+                interface.clear()
+                continue
+
+            if user_input == "help":
+                interface.print_help()
+                continue
+
+            if user_input == "exit":
+                break
+
+            if user_input == "reset":
+                interface.clear()
+                current_args = copy.deepcopy(args)
+                chat = clear_chat_history(current_args.system_prompt)
+                continue
+
+            if user_input.startswith("save") and len(user_input.split()) <= 2:
+                split_input = user_input.split()
+
+                if len(split_input) == 2:
+                    filename = split_input[1]
+                else:
+                    filename = None
+                filename = save_chat(chat, current_args, filename)
+                interface.print_green(f"Chat saved in {filename}!")
+                continue
+
+            if re.match(SETTING_RE, user_input):
+                current_args, success = parse_settings(user_input, current_args, interface)
+                if success:
+                    chat = []
+                    interface.clear()
+                continue
+
+            if user_input.startswith("example") and len(user_input.split()) == 2:
+                example_name = user_input.split()[1]
+                if example_name in examples:
+                    interface.clear()
+                    chat = []
+                    interface.print_user_message(examples[example_name]["text"])
+                    user_input = 
examples[example_name]["text"] + else: + interface.print_red( + f"Example {example_name} not found in list of available examples: {list(examples.keys())}." + ) + continue + + chat.append({"role": "user", "content": user_input}) + + inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( + model.device + ) + attention_mask = torch.ones_like(inputs) + generation_kwargs = dict( + inputs=inputs, + attention_mask=attention_mask, + streamer=generation_streamer, + max_new_tokens=current_args.max_new_tokens, + do_sample=current_args.do_sample, + num_beams=current_args.num_beams, + temperature=current_args.temperature, + top_k=current_args.top_k, + top_p=current_args.top_p, + repetition_penalty=current_args.repetition_penalty, + pad_token_id=pad_token_id, + eos_token_id=eos_token_ids, + ) + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + model_output = interface.stream_output(generation_streamer) + thread.join() + chat.append({"role": "assistant", "content": model_output}) + + except KeyboardInterrupt: + break + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (ChatArguments,) + if subparsers is not None: + parser = subparsers.add_parser("chat", help=HELP_STRING, dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + (chat_args,) = parser.parse_args_and_config() + main(chat_args) diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py new file mode 100644 index 0000000000..69b779e391 --- /dev/null +++ b/trl/scripts/dpo.py @@ -0,0 +1,151 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +# Full training +python trl/scripts/dpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --learning_rate 5.0e-7 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 \ + --output_dir Qwen2-0.5B-DPO \ + --no_remove_unused_columns + +# LoRA: +python trl/scripts/dpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --learning_rate 5.0e-6 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 \ + --output_dir Qwen2-0.5B-DPO \ + --no_remove_unused_columns \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 +""" + +import argparse + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl import ( + DPOConfig, + DPOTrainer, + ModelConfig, + ScriptArguments, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +def main(script_args, training_args, model_args): + ################ + # Model & Tokenizer + ################### + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + peft_config = get_peft_config(model_args) + if peft_config is None: + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + else: + ref_model = None + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + ########## + # Training + ################ + trainer = DPOTrainer( + model, + ref_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=peft_config, + ) + + trainer.train() + + if training_args.eval_strategy != "no": + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if 
training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (ScriptArguments, DPOConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/trl/commands/cli.py b/trl/scripts/env.py similarity index 54% rename from trl/commands/cli.py rename to trl/scripts/env.py index 44a353acec..57432263c4 100644 --- a/trl/commands/cli.py +++ b/trl/scripts/env.py @@ -14,22 +14,16 @@ import os import platform -import subprocess -import sys from importlib.metadata import version -from subprocess import CalledProcessError import torch from accelerate.commands.config import default_config_file, load_config_from_file -from rich.console import Console from transformers import is_bitsandbytes_available from transformers.utils import is_liger_kernel_available, is_openai_available, is_peft_available -from .. import __version__, is_deepspeed_available, is_diffusers_available, is_llm_blender_available -from .cli_utils import get_git_commit_hash - - -SUPPORTED_COMMANDS = ["sft", "dpo", "chat", "kto", "env"] +from .. import __version__ +from ..import_utils import is_deepspeed_available, is_diffusers_available, is_llm_blender_available +from .utils import get_git_commit_hash def print_env(): @@ -74,71 +68,5 @@ def print_env(): print(f"\nCopy-paste the following information when reporting an issue:\n\n{info_str}\n") # noqa -def train(command_name): - console = Console() - # Make sure to import things locally to avoid verbose from third party libs. - with console.status("[bold purple]Welcome! Initializing the TRL CLI..."): - from trl.commands.cli_utils import init_zero_verbose - - init_zero_verbose() - command_name = sys.argv[1] - trl_examples_dir = os.path.dirname(__file__) - - command = f"accelerate launch {trl_examples_dir}/scripts/{command_name}.py {' '.join(sys.argv[2:])}" - - try: - subprocess.run( - command.split(), - text=True, - check=True, - encoding="utf-8", - cwd=os.getcwd(), - env=os.environ.copy(), - ) - except (CalledProcessError, ChildProcessError) as exc: - console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.") - raise ValueError("TRL CLI failed! Check the traceback above..") from exc - - -def chat(): - console = Console() - # Make sure to import things locally to avoid verbose from third party libs. - with console.status("[bold purple]Welcome! Initializing the TRL CLI..."): - from trl.commands.cli_utils import init_zero_verbose - - init_zero_verbose() - trl_examples_dir = os.path.dirname(__file__) - - command = f"python {trl_examples_dir}/scripts/chat.py {' '.join(sys.argv[2:])}" - - try: - subprocess.run( - command.split(), - text=True, - check=True, - encoding="utf-8", - cwd=os.getcwd(), - env=os.environ.copy(), - ) - except (CalledProcessError, ChildProcessError) as exc: - console.log("TRL - CHAT failed! See the logs above for further details.") - raise ValueError("TRL CLI failed! 
Check the traceback above..") from exc - - -def main(): - command_name = sys.argv[1] - - if command_name in ["sft", "dpo", "kto"]: - train(command_name) - elif command_name == "chat": - chat() - elif command_name == "env": - print_env() - else: - raise ValueError( - f"Please use one of the supported commands, got {command_name} - supported commands are {SUPPORTED_COMMANDS}" - ) - - if __name__ == "__main__": - main() + print_env() diff --git a/trl/scripts/kto.py b/trl/scripts/kto.py new file mode 100644 index 0000000000..9eb44ba09f --- /dev/null +++ b/trl/scripts/kto.py @@ -0,0 +1,128 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO. + +# Full training: +python trl/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ + --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ + --per_device_train_batch_size 16 \ + --num_train_epochs 1 \ + --learning_rate 5e-7 \ + --lr_scheduler_type=cosine \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir=kto-aligned-model \ + --warmup_ratio 0.1 \ + --report_to wandb \ + --bf16 \ + --logging_first_step + +# QLoRA: +python trl/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ + --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --learning_rate 5e-7 \ + --lr_scheduler_type=cosine \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir=kto-aligned-model-lora \ + --warmup_ratio 0.1 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --use_peft \ + --load_in_4bit \ + --lora_target_modules=all-linear \ + --lora_r=16 \ + --lora_alpha=16 +""" + +import argparse + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl import ( + KTOConfig, + KTOTrainer, + ModelConfig, + ScriptArguments, + TrlParser, + get_peft_config, + setup_chat_format, +) + + +def main(script_args, training_args, model_args): + # Load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # If we are aligning a base model, we use ChatML as the default template + if tokenizer.chat_template is None: + model, tokenizer = setup_chat_format(model, tokenizer) + + # Load the dataset + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + # Initialize the KTO trainer + trainer = KTOTrainer( + model, + ref_model, + args=training_args, + 
train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + # Train and push the model to the Hub + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (ScriptArguments, KTOConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py new file mode 100644 index 0000000000..f457b1dc68 --- /dev/null +++ b/trl/scripts/sft.py @@ -0,0 +1,126 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Full training +python trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --learning_rate 2.0e-5 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 100 \ + --output_dir Qwen2-0.5B-SFT \ + --push_to_hub + +# LoRA +python trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --learning_rate 2.0e-4 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 100 \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ + --output_dir Qwen2-0.5B-SFT \ + --push_to_hub +""" + +import argparse + +from datasets import load_dataset +from transformers import AutoTokenizer + +from trl import ( + ModelConfig, + ScriptArguments, + SFTConfig, + SFTTrainer, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +def main(script_args, training_args, model_args): + ################ + # Model init kwargs & Tokenizer + ################ + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + training_args.model_init_kwargs = model_kwargs + tokenizer = 
AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True + ) + tokenizer.pad_token = tokenizer.eos_token + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + ################ + # Training + ################ + trainer = SFTTrainer( + model=model_args.model_name_or_path, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (ScriptArguments, SFTConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/trl/commands/cli_utils.py b/trl/scripts/utils.py similarity index 72% rename from trl/commands/cli_utils.py rename to trl/scripts/utils.py index d899011a6f..a6637c02a3 100644 --- a/trl/commands/cli_utils.py +++ b/trl/scripts/utils.py @@ -19,7 +19,7 @@ import subprocess import sys import warnings -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Iterable, Optional, Union import yaml @@ -31,6 +31,35 @@ logger = logging.getLogger(__name__) +@dataclass +class ScriptArguments: + """ + Arguments common to all scripts. + + Args: + dataset_name (`str`): + Dataset name. + dataset_config (`str` or `None`, *optional*, defaults to `None`): + Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function. + dataset_train_split (`str`, *optional*, defaults to `"train"`): + Dataset split to use for training. + dataset_test_split (`str`, *optional*, defaults to `"test"`): + Dataset split to use for evaluation. + gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): + Whether to apply `use_reentrant` for gradient_checkpointing. + ignore_bias_buffers (`bool`, *optional*, defaults to `False`): + Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar + type, inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. 
+ """ + + dataset_name: str + dataset_config: Optional[str] = None + dataset_train_split: str = "train" + dataset_test_split: str = "test" + gradient_checkpointing_use_reentrant: bool = False + ignore_bias_buffers: bool = False + + class YamlConfigParser: """ """ @@ -90,84 +119,13 @@ def warning_handler(message, category, filename, lineno, file=None, line=None): warnings.showwarning = warning_handler -@dataclass -class ChatArguments: - # general settings - model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model"}) - user: str = field(default=None, metadata={"help": "Username to display in chat interface"}) - system_prompt: str = field(default=None, metadata={"help": "System prompt"}) - save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history"}) - device: str = field( - default="cpu", - metadata={"help": "device to use for inference."}, - ) - config: str = field( - default="default", - metadata={ - "help": "Config file used for setting the configs. If `default` uses examples/scripts/config/default_chat_config.yaml" - }, - ) - examples: str = field(default=None, metadata={"help": "Empty placeholder needs to be set via config."}) - # generation settings - max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate"}) - do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation"}) - num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search"}) - temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation"}) - top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling"}) - top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling"}) - repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty"}) - eos_tokens: str = field( - default=None, - metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated"}, - ) - eos_token_ids: str = field( - default=None, - metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated"}, - ) - # model loading - model_revision: str = field( - default="main", - metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, - ) - torch_dtype: str = field( - default=None, - metadata={ - "help": ( - "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " - "dtype will be automatically derived from the model's weights." 
- ), - "choices": ["auto", "bfloat16", "float16", "float32"], - }, - ) - trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) - attn_implementation: str = field( - default=None, - metadata={ - "help": ( - "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`" - ) - }, - ) - load_in_8bit: bool = field( - default=False, - metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}, - ) - load_in_4bit: bool = field( - default=False, - metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}, - ) - - bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}) - use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) - - class TrlParser(HfArgumentParser): """ A subclass of [`transformers.HfArgumentParser`] designed for parsing command-line arguments with dataclass-backed configurations, while also supporting configuration file loading and environment variable management. Args: - dataclass_types (`Union[DataClassType, Iterable[DataClassType]]`): + dataclass_types (`Union[DataClassType, Iterable[DataClassType]]` or `None`, *optional*, defaults to `None`): Dataclass types to use for argument parsing. **kwargs: Additional keyword arguments passed to the [`transformers.HfArgumentParser`] constructor. @@ -215,12 +173,15 @@ class MyArguments: ) def __init__( self, - dataclass_types: Union[DataClassType, Iterable[DataClassType]], + dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None, ignore_extra_args: Optional[bool] = None, **kwargs, ): - super().__init__(dataclass_types=dataclass_types, **kwargs) - self._ignore_extra_args = ignore_extra_args + # Make sure dataclass_types is an iterable + if dataclass_types is None: + dataclass_types = [] + elif not isinstance(dataclass_types, Iterable): + dataclass_types = [dataclass_types] # Check that none of the dataclasses have the "config" field for dataclass_type in dataclass_types: @@ -230,6 +191,9 @@ def __init__( f"config file path and should not be used in the dataclass." ) + super().__init__(dataclass_types=dataclass_types, **kwargs) + self._ignore_extra_args = ignore_extra_args + def post_process_dataclasses(self, dataclasses): """ Post process dataclasses to merge the TrainingArguments with the SFTScriptArguments or DPOScriptArguments. diff --git a/trl/utils.py b/trl/utils.py deleted file mode 100644 index 67b8db9488..0000000000 --- a/trl/utils.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Optional - - -@dataclass -class ScriptArguments: - """ - Arguments common to all scripts. - - Args: - dataset_name (`str`): - Dataset name. 
- dataset_config (`str` or `None`, *optional*, defaults to `None`): - Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function. - dataset_train_split (`str`, *optional*, defaults to `"train"`): - Dataset split to use for training. - dataset_test_split (`str`, *optional*, defaults to `"test"`): - Dataset split to use for evaluation. - gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): - Whether to apply `use_reentrant` for gradient_checkpointing. - ignore_bias_buffers (`bool`, *optional*, defaults to `False`): - Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar - type, inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. - """ - - dataset_name: str - dataset_config: Optional[str] = None - dataset_train_split: str = "train" - dataset_test_split: str = "test" - gradient_checkpointing_use_reentrant: bool = False - ignore_bias_buffers: bool = False
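The net effect of the last two hunks is that `ScriptArguments` now lives in `trl/scripts/utils.py` instead of the deleted `trl/utils.py`, and the relocated `sft.py`/`dpo.py`/`kto.py` scripts import it, together with `TrlParser`, from the top-level `trl` package. A minimal sketch of the same pattern for a custom script (the `MyArguments` dataclass and its `loss_margin` field are hypothetical, added only for illustration):

```python
from dataclasses import dataclass, field

from trl import ScriptArguments, TrlParser


@dataclass
class MyArguments:
    # Hypothetical, script-specific option used only for this sketch.
    loss_margin: float = field(default=0.1, metadata={"help": "Margin used by a custom loss."})


if __name__ == "__main__":
    # ScriptArguments contributes the shared dataset_name / dataset_config / split fields;
    # TrlParser parses both dataclasses from the CLI and, per its docstring, can also load
    # values from a config file.
    parser = TrlParser((ScriptArguments, MyArguments))
    script_args, my_args = parser.parse_args_and_config()
    print(script_args.dataset_name, script_args.dataset_train_split, my_args.loss_margin)
```

Such a script would be invoked, for instance, as `python my_script.py --dataset_name trl-lib/Capybara --loss_margin 0.2`.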