Skip to content

Commit

Permalink
precompute packed iterable into a dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
leandro committed Nov 10, 2023
1 parent c2884b5 commit 646094f
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,17 +323,18 @@ def _prepare_dataset(
"You need to pass a tokenizer when using the SFT Trainer when passing a `dataset_text_field`."
)

return ConstantLengthDataset(
constant_length_iterator = ConstantLengthDataset(
tokenizer,
dataset,
dataset_text_field=dataset_text_field,
formatting_func=formatting_func,
seq_length=max_seq_length,
infinite=infinite,
infinite=False,
num_of_sequences=num_of_sequences,
chars_per_token=chars_per_token,
eos_token_id=tokenizer.eos_token_id,
)
return Dataset.from_generator(constant_length_iterator)

raise ValueError(
"You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`."
Expand Down

0 comments on commit 646094f

Please sign in to comment.