🎭 Deprecate [SFT/DPO/Reward]ScriptArguments in favour of `ScriptArguments` (#2145)

* `DPOScriptArguments` to `ScriptArguments`

* use dataset_train_split

* Use scriptarguments

* dataset names in command lines

* use `ScriptArguments` everywhere

* ignore bias buffer to end

* remove in v0.13

* rm comment

* update test commands

* Update docs/source/rloo_trainer.md

* Update tests/test_rloo_trainer.py

* Added dataset_train_split argument to ppo.py and rloo.py

* update scripts with dataset_train_split
qgallouedec authored Oct 14, 2024
1 parent 14f3613 commit 7e394b0
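
The old names are not removed outright: per the "remove in v0.13" bullet above, `SFTScriptArguments`, `DPOScriptArguments`, and `RewardScriptArguments` remain importable as deprecated aliases for one more release. The shims themselves sit outside the visible hunks; the sketch below is only a plausible shape for them, assuming thin subclasses that emit a `FutureWarning` (class body and warning wording are illustrative, not the actual implementation).

```python
# Hypothetical deprecation shim (not part of the visible diff): keep the old
# class importable, but point users at the unified ScriptArguments.
import warnings
from dataclasses import dataclass

from trl import ScriptArguments


@dataclass
class DPOScriptArguments(ScriptArguments):
    def __post_init__(self):
        # Warn once per instantiation, then behave exactly like ScriptArguments.
        warnings.warn(
            "`DPOScriptArguments` is deprecated and will be removed in TRL v0.13. "
            "Use `trl.ScriptArguments` instead.",
            FutureWarning,
        )
```

Keeping the subclasses empty means existing user scripts keep working for one release while emitting a clear migration warning.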
Showing 24 changed files with 165 additions and 131 deletions.
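
Across those files, the change is the same: each script drops its trainer-specific or locally defined arguments class, parses the shared `ScriptArguments` instead, and reads the dataset name and splits from it rather than hard-coding `"train"`/`"test"`. A minimal sketch of the resulting pattern, modelled on the `cpo.py` hunks below (the `load_dataset` and pad-token lines are assumptions, since they sit outside the visible hunks):

```python
# Sketch of the shared script skeleton after this commit. The parser tuple,
# trainer call, and split lookups mirror the cpo.py diff; the dataset-loading
# and pad-token lines are assumptions based on the surrounding scripts.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import CPOConfig, CPOTrainer, ModelConfig, ScriptArguments, get_peft_config

if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_into_dataclasses()

    model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --dataset_name, --dataset_train_split and --dataset_test_split now come
    # from ScriptArguments, so every script exposes the same dataset flags.
    dataset = load_dataset(script_args.dataset_name)

    trainer = CPOTrainer(
        model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split],
        processing_class=tokenizer,
        peft_config=get_peft_config(model_config),
    )
    trainer.train()
```

Invocation then becomes uniform across scripts (`--dataset_name ... --dataset_train_split ...`), as the updated docstrings and docs below show.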
2 changes: 2 additions & 0 deletions docs/source/ppo_trainer.md
@@ -16,6 +16,8 @@ To just run a PPO script to make sure the trainer can run, you can run the follo

```bash
python examples/scripts/ppo/ppo.py \
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
--dataset_train_split descriptiveness \
--learning_rate 3e-6 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
5 changes: 4 additions & 1 deletion docs/source/rloo_trainer.md
@@ -18,6 +18,8 @@ To just run a RLOO script to make sure the trainer can run, you can run the foll

```bash
python examples/scripts/rloo/rloo.py \
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
--dataset_train_split descriptiveness \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo \
--per_device_train_batch_size 64 \
@@ -210,8 +212,9 @@ To validate the RLOO implementation works, we ran experiment on the 1B model. He

```
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/rloo/rloo_tldr.py \
--output_dir models/minimal/rloo_tldr \
--dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
--dataset_test_split validation \
--num_ppo_epochs 2 \
--num_mini_batches 2 \
--learning_rate 3e-6 \
8 changes: 4 additions & 4 deletions examples/scripts/bco.py
@@ -76,7 +76,7 @@
from datasets import load_dataset
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, PreTrainedModel

from trl import BCOConfig, BCOTrainer, DPOScriptArguments, ModelConfig, get_peft_config, setup_chat_format
from trl import BCOConfig, BCOTrainer, ModelConfig, ScriptArguments, get_peft_config, setup_chat_format


def embed_prompt(input_ids: torch.LongTensor, attention_mask: torch.LongTensor, model: PreTrainedModel):
@@ -103,7 +103,7 @@ def mean_pooling(model_output, attention_mask):


if __name__ == "__main__":
parser = HfArgumentParser((DPOScriptArguments, BCOConfig, ModelConfig))
parser = HfArgumentParser((ScriptArguments, BCOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_into_dataclasses()

training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
@@ -150,8 +150,8 @@ def mean_pooling(model_output, attention_mask):
model,
ref_model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split],
processing_class=tokenizer,
peft_config=get_peft_config(model_args),
embedding_func=embedding_func,
18 changes: 5 additions & 13 deletions examples/scripts/cpo.py
@@ -17,6 +17,7 @@
# regular:
python examples/scripts/cpo.py \
--dataset_name trl-lib/ultrafeedback_binarized \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
@@ -33,6 +34,7 @@
# peft:
python examples/scripts/cpo.py \
--dataset_name trl-lib/ultrafeedback_binarized \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
@@ -52,23 +54,13 @@
--lora_alpha=16
"""

from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config
from trl import CPOConfig, CPOTrainer, ModelConfig, ScriptArguments, get_peft_config
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


@dataclass
class ScriptArguments:
dataset_name: str = field(
default="trl-lib/ultrafeedback_binarized",
metadata={"help": "The name of the dataset to use."},
)


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_into_dataclasses()
@@ -98,8 +90,8 @@ class ScriptArguments:
trainer = CPOTrainer(
model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split],
processing_class=tokenizer,
peft_config=get_peft_config(model_config),
)
4 changes: 2 additions & 2 deletions examples/scripts/dpo.py
@@ -52,9 +52,9 @@

from trl import (
DPOConfig,
DPOScriptArguments,
DPOTrainer,
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
@@ -64,7 +64,7 @@


if __name__ == "__main__":
parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_and_config()

################
4 changes: 2 additions & 2 deletions examples/scripts/dpo_online.py
@@ -44,11 +44,11 @@
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig

from trl import (
DPOScriptArguments,
LogCompletionsCallback,
ModelConfig,
OnlineDPOConfig,
OnlineDPOTrainer,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
@@ -58,7 +58,7 @@


if __name__ == "__main__":
parser = TrlParser((DPOScriptArguments, OnlineDPOConfig, ModelConfig))
parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_and_config()
script_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

4 changes: 2 additions & 2 deletions examples/scripts/dpo_vlm.py
@@ -33,9 +33,9 @@

from trl import (
DPOConfig,
DPOScriptArguments,
DPOTrainer,
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
@@ -44,7 +44,7 @@


if __name__ == "__main__":
parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_and_config()

################
4 changes: 2 additions & 2 deletions examples/scripts/gkd.py
@@ -53,7 +53,7 @@
GKDTrainer,
LogCompletionsCallback,
ModelConfig,
SFTScriptArguments,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
@@ -62,7 +62,7 @@


if __name__ == "__main__":
parser = TrlParser((SFTScriptArguments, GKDConfig, ModelConfig))
parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_and_config()

################
28 changes: 13 additions & 15 deletions examples/scripts/kto.py
@@ -17,6 +17,7 @@
# Full training:
python examples/scripts/kto.py \
--dataset_name trl-lib/kto-mix-14k \
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
--per_device_train_batch_size 16 \
--num_train_epochs 1 \
@@ -33,6 +34,7 @@
# QLoRA:
python examples/scripts/kto.py \
--dataset_name trl-lib/kto-mix-14k \
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
@@ -53,23 +55,19 @@
--lora_alpha=16
"""

from dataclasses import dataclass

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format


# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the KTO training script.
"""

dataset_name: str = "trl-lib/kto-mix-14k"
from trl import (
KTOConfig,
KTOTrainer,
ModelConfig,
ScriptArguments,
get_peft_config,
maybe_unpair_preference_dataset,
setup_chat_format,
)


if __name__ == "__main__":
@@ -120,8 +118,8 @@ def format_dataset(example):
model,
ref_model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split],
processing_class=tokenizer,
peft_config=get_peft_config(model_args),
)
4 changes: 2 additions & 2 deletions examples/scripts/nash_md.py
@@ -50,11 +50,11 @@
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig

from trl import (
DPOScriptArguments,
LogCompletionsCallback,
ModelConfig,
NashMDConfig,
NashMDTrainer,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_quantization_config,
@@ -63,7 +63,7 @@


if __name__ == "__main__":
parser = TrlParser((DPOScriptArguments, NashMDConfig, ModelConfig))
parser = TrlParser((ScriptArguments, NashMDConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_and_config()
script_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

18 changes: 5 additions & 13 deletions examples/scripts/orpo.py
@@ -17,6 +17,7 @@
# regular:
python examples/scripts/orpo.py \
--dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
@@ -33,6 +34,7 @@
# peft:
python examples/scripts/orpo.py \
--dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
@@ -52,23 +54,13 @@
--lora_alpha=16
"""

from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config
from trl import ModelConfig, ORPOConfig, ORPOTrainer, ScriptArguments, get_peft_config
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


@dataclass
class ScriptArguments:
dataset_name: str = field(
default="trl-internal-testing/hh-rlhf-helpful-base-trl-style",
metadata={"help": "The name of the dataset to use."},
)


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_into_dataclasses()
@@ -98,8 +90,8 @@ class ScriptArguments:
trainer = ORPOTrainer(
model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split],
processing_class=tokenizer,
peft_config=get_peft_config(model_config),
)
14 changes: 9 additions & 5 deletions examples/scripts/ppo/ppo.py
@@ -23,12 +23,14 @@
HfArgumentParser,
)

from trl import ModelConfig, PPOConfig, PPOTrainer
from trl import ModelConfig, PPOConfig, PPOTrainer, ScriptArguments
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


"""
python -i examples/scripts/ppo/ppo.py \
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
--dataset_train_split descriptiveness \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 64 \
@@ -39,6 +41,8 @@
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/ppo/ppo.py \
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
--dataset_train_split descriptiveness \
--output_dir models/minimal/ppo \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
@@ -55,8 +59,8 @@


if __name__ == "__main__":
parser = HfArgumentParser((PPOConfig, ModelConfig))
training_args, model_config = parser.parse_args_into_dataclasses()
parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_into_dataclasses()
# remove output_dir if exists
shutil.rmtree(training_args.output_dir, ignore_errors=True)

@@ -86,7 +90,7 @@
################
# Dataset
################
dataset = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split)
eval_samples = 100
train_dataset = dataset.select(range(len(dataset) - eval_samples))
eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset)))
@@ -133,6 +137,6 @@ def tokenize(element):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name="trl-internal-testing/descriptiveness-sentiment-trl-style")
trainer.push_to_hub(dataset_name=script_args.dataset_name)

trainer.generate_completions()
(remaining changed files not shown)
