Skip to content

Commit

Permalink
🆔 Add dataset_config to ScriptArguments (#2440)
Browse files Browse the repository at this point in the history
* datast_config_name

* Update trl/utils.py [ci skip]

* sort import

* typo [ci skip]

* Trigger CI

* Rename `dataset_config_name` to `dataset_config`
  • Loading branch information
qgallouedec authored Dec 10, 2024
1 parent 2f72f47 commit 6a05fef
Show file tree
Hide file tree
Showing 20 changed files with 27 additions and 19 deletions.
2 changes: 1 addition & 1 deletion examples/scripts/bco.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def mean_pooling(model_output, attention_mask):
if tokenizer.chat_template is None:
model, tokenizer = setup_chat_format(model, tokenizer)

dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

accelerator = Accelerator()
embedding_model = AutoModel.from_pretrained(
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/cpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

##########
# Training
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/dpo_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

trainer = OnlineDPOTrainer(
model=model,
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/dpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

################
# Training
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/gkd.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

with PartialState().local_main_process_first():
dataset = dataset.map(
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/kto.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
model, tokenizer = setup_chat_format(model, tokenizer)

# Load the dataset
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

# Initialize the KTO trainer
trainer = KTOTrainer(
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/nash_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

trainer = NashMDTrainer(
model=model,
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/orpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

Expand Down
4 changes: 3 additions & 1 deletion examples/scripts/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split)
dataset = load_dataset(
script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split
)
eval_samples = 100
train_dataset = dataset.select(range(len(dataset) - eval_samples))
eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset)))
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ppo/ppo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
train_dataset = dataset[script_args.dataset_train_split]
eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/reward_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
##############
# Load dataset
##############
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

##########
# Training
Expand Down
4 changes: 3 additions & 1 deletion examples/scripts/rloo/rloo.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split)
dataset = load_dataset(
script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split
)
eval_samples = 100
train_dataset = dataset.select(range(len(dataset) - eval_samples))
eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset)))
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/rloo/rloo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
train_dataset = dataset[script_args.dataset_train_split]
eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

################
# Training
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/sft_video_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class CustomScriptArguments(ScriptArguments):
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

# Load dataset
dataset = load_dataset(script_args.dataset_name, split="train")
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train")

# Setup model
torch_dtype = (
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/sft_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def collate_fn(examples):
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

################
# Training
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/sft_vlm_smol_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def collate_fn(examples):
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

################
# Training
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/xpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

dataset = load_dataset(script_args.dataset_name)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

trainer = XPOTrainer(
model=model,
Expand Down
4 changes: 4 additions & 0 deletions trl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Optional


@dataclass
Expand All @@ -23,6 +24,8 @@ class ScriptArguments:
Args:
dataset_name (`str`):
Dataset name.
dataset_config (`str` or `None`, *optional*, defaults to `None`):
Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function.
dataset_train_split (`str`, *optional*, defaults to `"train"`):
Dataset split to use for training.
dataset_test_split (`str`, *optional*, defaults to `"test"`):
Expand All @@ -35,6 +38,7 @@ class ScriptArguments:
"""

dataset_name: str
dataset_config: Optional[str] = None
dataset_train_split: str = "train"
dataset_test_split: str = "test"
gradient_checkpointing_use_reentrant: bool = False
Expand Down

0 comments on commit 6a05fef

Please sign in to comment.