From 6a05feff02a0e8f8f93819da02ef8a14060c8662 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:09:26 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=86=94=20Add=20`datast=5Fconfig`=20to=20`?= =?UTF-8?q?ScriptArguments`=20(#2440)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * datast_config_name * Update trl/utils.py [ci skip] * sort import * typo [ci skip] * Trigger CI * Rename `dataset_config_name` to `dataset_config` --- examples/scripts/bco.py | 2 +- examples/scripts/cpo.py | 2 +- examples/scripts/dpo.py | 2 +- examples/scripts/dpo_online.py | 2 +- examples/scripts/dpo_vlm.py | 2 +- examples/scripts/gkd.py | 2 +- examples/scripts/kto.py | 2 +- examples/scripts/nash_md.py | 2 +- examples/scripts/orpo.py | 2 +- examples/scripts/ppo/ppo.py | 4 +++- examples/scripts/ppo/ppo_tldr.py | 2 +- examples/scripts/reward_modeling.py | 2 +- examples/scripts/rloo/rloo.py | 4 +++- examples/scripts/rloo/rloo_tldr.py | 2 +- examples/scripts/sft.py | 2 +- examples/scripts/sft_video_llm.py | 2 +- examples/scripts/sft_vlm.py | 2 +- examples/scripts/sft_vlm_smol_vlm.py | 2 +- examples/scripts/xpo.py | 2 +- trl/utils.py | 4 ++++ 20 files changed, 27 insertions(+), 19 deletions(-) diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py index b4302148be..38a5d35a38 100644 --- a/examples/scripts/bco.py +++ b/examples/scripts/bco.py @@ -126,7 +126,7 @@ def mean_pooling(model_output, attention_mask): if tokenizer.chat_template is None: model, tokenizer = setup_chat_format(model, tokenizer) - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) accelerator = Accelerator() embedding_model = AutoModel.from_pretrained( diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index 5dfb39dcd4..20ea85c925 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -81,7 +81,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 9552a0ae08..cbaba95a1a 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -111,7 +111,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ########## # Training diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index 6ea730c9fc..4859056259 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -121,7 +121,7 @@ if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) trainer = OnlineDPOTrainer( model=model, diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py index 2b19e35fcc..a58dfc4152 100644 --- a/examples/scripts/dpo_vlm.py +++ b/examples/scripts/dpo_vlm.py @@ -104,7 +104,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ################ # Training diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index a89327d963..bac694b9be 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -104,7 +104,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) with PartialState().local_main_process_first(): dataset = dataset.map( diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index e6a09fc539..7ae26931e9 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -91,7 +91,7 @@ model, tokenizer = setup_chat_format(model, tokenizer) # Load the dataset - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) # Initialize the KTO trainer trainer = KTOTrainer( diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index 0fc80fa43e..bde51b32e3 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -121,7 +121,7 @@ if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) trainer = NashMDTrainer( model=model, diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index 886db53d51..2d0fefd494 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -81,7 +81,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index d43c91b3d7..05b5870dae 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -119,7 +119,9 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split) + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split + ) eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index b7c19a25e3..fabacf5b98 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -126,7 +126,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) train_dataset = dataset[script_args.dataset_train_split] eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 766c7e44f8..3a2e311800 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -107,7 +107,7 @@ ############## # Load dataset ############## - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ########## # Training diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index 4f6732b2c6..6a56aac8c8 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -90,7 +90,9 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split) + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split + ) eval_samples = 100 train_dataset = dataset.select(range(len(dataset) - eval_samples)) eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 56ea24080d..8e89570963 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -92,7 +92,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) train_dataset = dataset[script_args.dataset_train_split] eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index d68356798b..751be63771 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -90,7 +90,7 @@ ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ################ # Training diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py index 8cb8864e44..2a3f08812b 100644 --- a/examples/scripts/sft_video_llm.py +++ b/examples/scripts/sft_video_llm.py @@ -166,7 +166,7 @@ class CustomScriptArguments(ScriptArguments): training_args.dataset_kwargs = {"skip_prepare_dataset": True} # Load dataset - dataset = load_dataset(script_args.dataset_name, split="train") + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train") # Setup model torch_dtype = ( diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index 330c67b50d..ca17ec8a09 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -109,7 +109,7 @@ def collate_fn(examples): ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ################ # Training diff --git a/examples/scripts/sft_vlm_smol_vlm.py b/examples/scripts/sft_vlm_smol_vlm.py index ad95253edc..eb08a8d7da 100644 --- a/examples/scripts/sft_vlm_smol_vlm.py +++ b/examples/scripts/sft_vlm_smol_vlm.py @@ -121,7 +121,7 @@ def collate_fn(examples): ################ # Dataset ################ - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ################ # Training diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index ea57322a61..2ddc532374 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -106,7 +106,7 @@ if tokenizer.chat_template is None: tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE - dataset = load_dataset(script_args.dataset_name) + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) trainer = XPOTrainer( model=model, diff --git a/trl/utils.py b/trl/utils.py index b63742d2ac..67b8db9488 100644 --- a/trl/utils.py +++ b/trl/utils.py @@ -13,6 +13,7 @@ # limitations under the License. from dataclasses import dataclass +from typing import Optional @dataclass @@ -23,6 +24,8 @@ class ScriptArguments: Args: dataset_name (`str`): Dataset name. + dataset_config (`str` or `None`, *optional*, defaults to `None`): + Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function. dataset_train_split (`str`, *optional*, defaults to `"train"`): Dataset split to use for training. dataset_test_split (`str`, *optional*, defaults to `"test"`): @@ -35,6 +38,7 @@ class ScriptArguments: """ dataset_name: str + dataset_config: Optional[str] = None dataset_train_split: str = "train" dataset_test_split: str = "test" gradient_checkpointing_use_reentrant: bool = False