add eval_packing (#1369)
younesbelkada authored Feb 27, 2024
1 parent 7c2213b commit 7712d42
Showing 3 changed files with 59 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/sft_trainer.mdx
@@ -271,6 +271,7 @@ trainer.train()
```

Note that if you use a packed dataset and pass `max_steps` in the training arguments, you will probably train your model for more than a few epochs, depending on how you configured the packed dataset and the training protocol. Double-check that the resulting number of epochs is what you intend.
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTTrainer` init method.
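For example (a minimal sketch; the model id and dataset below are illustrative stand-ins, not part of this commit):

```python
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train[:1%]")

trainer = SFTTrainer(
    "facebook/opt-350m",        # any causal LM checkpoint works here
    train_dataset=dataset,
    eval_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    packing=True,               # pack the training split
    eval_packing=False,         # leave the eval split unpacked
)
```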

#### Customize your prompts using packed dataset

50 changes: 50 additions & 0 deletions tests/test_sft_trainer.py
@@ -930,3 +930,53 @@ def test_sft_trainer_tag(self):
        )

        assert trainer.model.model_tags == trainer._tag_names

    def test_sft_trainer_eval_packing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                dataloader_drop_last=True,
                evaluation_strategy="steps",
                max_steps=4,
                eval_steps=2,
                save_steps=2,
                per_device_train_batch_size=2,
                gradient_checkpointing=True,
            )

            trainer = SFTTrainer(
                model=self.model_id,
                args=training_args,
                train_dataset=self.dummy_chatml_dataset,
                eval_dataset=self.dummy_chatml_dataset,
                packing=True,
                max_seq_length=32,  # make sure there is at least 1 packed sequence
                eval_packing=False,
            )

            assert len(trainer.train_dataset["input_ids"]) == 1
            assert len(trainer.eval_dataset["input_ids"]) != 1

            trainer = SFTTrainer(
                model=self.model_id,
                args=training_args,
                train_dataset=self.dummy_chatml_dataset,
                eval_dataset=self.dummy_chatml_dataset,
                max_seq_length=32,  # make sure there is at least 1 packed sequence
                packing=True,
            )

            assert len(trainer.train_dataset["input_ids"]) == 1
            assert len(trainer.eval_dataset["input_ids"]) == 1

            trainer = SFTTrainer(
                model=self.model_id,
                args=training_args,
                train_dataset=self.dummy_chatml_dataset,
                eval_dataset=self.dummy_chatml_dataset,
                max_seq_length=32,  # make sure there is at least 1 packed sequence
                packing=False,
            )

            assert len(trainer.train_dataset["input_ids"]) != 1
            assert len(trainer.eval_dataset["input_ids"]) != 1
10 changes: 8 additions & 2 deletions trl/trainer/sft_trainer.py
@@ -116,6 +116,8 @@ class SFTTrainer(Trainer):
            Dict of Optional kwargs to pass when instantiating the model from a string
        dataset_kwargs: (`Optional[Dict]`, *optional*):
            Dict of Optional kwargs to pass when creating packed or non-packed datasets
        eval_packing: (`Optional[bool]`, *optional*):
            Whether to pack the eval dataset as well. Defaults to `packing` if `None` is passed.
    """

    _tag_names = ["trl", "sft"]
@@ -124,7 +126,7 @@ def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[TrainingArguments] = None,
-        data_collator: Optional[DataCollator] = None,
+        data_collator: Optional[DataCollator] = None,  # type: ignore
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
@@ -146,6 +148,7 @@ def __init__(
        neftune_noise_alpha: Optional[float] = None,
        model_init_kwargs: Optional[Dict] = None,
        dataset_kwargs: Optional[Dict] = None,
        eval_packing: Optional[bool] = None,
    ):
        if model_init_kwargs is None:
            model_init_kwargs = {}
@@ -274,11 +277,14 @@ def make_inputs_require_grad(module, input, output):
        if eval_dataset is not None:
            _multiple = isinstance(eval_dataset, dict)
            _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}

            eval_packing = packing if eval_packing is None else eval_packing

            for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
                _eval_datasets[_eval_dataset_name] = self._prepare_dataset(
                    _eval_dataset,
                    tokenizer,
-                    packing,
+                    eval_packing,
                    dataset_text_field,
                    max_seq_length,
                    formatting_func,
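The resolved flag is applied to every entry when `eval_dataset` is a dict of datasets, e.g. (a hypothetical usage sketch; the model id and dataset splits are invented for illustration):

```python
from datasets import load_dataset
from trl import SFTTrainer

train_ds = load_dataset("imdb", split="train[:1%]")
val_ds = load_dataset("imdb", split="test[:1%]")

trainer = SFTTrainer(
    "facebook/opt-350m",           # illustrative model id
    train_dataset=train_ds,
    eval_dataset={"val": val_ds},  # dict form takes the _multiple branch above
    dataset_text_field="text",
    packing=True,                  # train split is packed
    eval_packing=False,            # every dataset in the dict stays unpacked
)
```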
