Skip to content

Commit

Permalink
prepare_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Dec 19, 2024
1 parent 3e67ccb commit 3938a52
Showing 1 changed file with 52 additions and 2 deletions.
54 changes: 52 additions & 2 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import transformers
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available
from datasets import Dataset, concatenate_datasets
from datasets import Dataset, IterableDataset, concatenate_datasets
from packaging import version
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
Expand Down Expand Up @@ -513,6 +513,19 @@ def make_inputs_require_grad(module, input, output):
# issued.
model.warnings_issued["estimate_tokens"] = True

# 4. Handle the dataset - UNCOMMENT WHEN _prepare_dataset READY
# preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
# if preprocess_dataset:
# train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
# if eval_dataset is not None:
# if isinstance(eval_dataset, dict):
# eval_dataset = {
# key: self._prepare_dataset(dataset, processing_class, args, key)
# for key, dataset in eval_dataset.items()
# }
# else:
# eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
Expand Down Expand Up @@ -662,7 +675,7 @@ def make_inputs_require_grad(module, input, output):
UserWarning,
)

train_dataset= train_dataset.remove_columns(
train_dataset = train_dataset.remove_columns(
[
"prompt",
"completion",
Expand Down Expand Up @@ -712,6 +725,43 @@ def make_inputs_require_grad(module, input, output):
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
args: KTOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
# Build the kwargs for the `map` function
map_kwargs = {}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
map_kwargs["num_proc"] = args.dataset_num_proc

with PartialState().local_main_process_first():
# Extract prompt if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)

# Unpair the dataset if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Unpairing {dataset_name} dataset"
dataset = maybe_unpair_preference_dataset(dataset, **map_kwargs)

# Apply the chat template if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, **map_kwargs)

# HERE

# # Tokenize the dataset
# if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
# map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
# dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)

return dataset

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
Expand Down

0 comments on commit 3938a52

Please sign in to comment.