Commit

Add note about special tokens in chat templates for LoRA SFT (#2414)
lewtun authored Nov 29, 2024
1 parent e1d7813 commit 2c6e0d9
Showing 1 changed file with 10 additions and 5 deletions.
docs/source/sft_trainer.mdx (10 additions, 5 deletions)
@@ -331,33 +331,38 @@ Note that all keyword arguments of `from_pretrained()` are supported.

### Training adapters

-We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
+We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model.

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig

dataset = load_dataset("stanfordnlp/imdb", split="train")
dataset = load_dataset("trl-lib/Capybara", split="train")

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
+    target_modules="all-linear",
+    modules_to_save=["lm_head", "embed_token"],
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
-    "EleutherAI/gpt-neo-125m",
+    "Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
-    args=SFTConfig(output_dir="/tmp"),
+    args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT"),
    peft_config=peft_config
)

trainer.train()
```

+> [!WARNING]
+> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsense generations. If the chat template doesn't contain special tokens (e.g. Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
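
As a rough illustration of this note (not spelled out in the diff itself), one way to decide whether `modules_to_save` is needed is to inspect the tokenizer's chat template directly. The sketch below assumes the Qwen2.5 tokenizer from the example above and its ChatML-style `<|im_start|>` marker:

```python
# Minimal sketch (assumes the Qwen2.5 tokenizer used above): check the chat
# template for special tokens before deciding whether to train the embedding
# layer and LM head alongside the LoRA adapters.
from peft import LoraConfig
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

# Qwen2.5 uses a ChatML-style template, which relies on the <|im_start|> marker.
uses_special_tokens = "<|im_start|>" in (tokenizer.chat_template or "")

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
    # Only make the embeddings and LM head trainable when the template
    # depends on special tokens; otherwise plain LoRA adapters suffice.
    modules_to_save=["lm_head", "embed_token"] if uses_special_tokens else None,
    task_type="CAUSAL_LM",
)
```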

You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed.
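
For instance, a minimal sketch of resuming from a previously saved adapter, assuming the output directory from the example above holds the trained adapter:

```python
# Minimal sketch: load the PeftModel yourself and pass it to SFTTrainer
# without a peft_config, so training continues on the existing adapter.
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# Loads the base model and attaches the previously trained adapter weights.
model = AutoPeftModelForCausalLM.from_pretrained("Qwen2.5-0.5B-SFT", is_trainable=True)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT-continued"),
)

trainer.train()
```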

### Training adapters with base 8 bit models
