From 646094f463ae09084708efa5dcdc361b996292ac Mon Sep 17 00:00:00 2001
From: leandro
Date: Fri, 10 Nov 2023 15:48:44 +0100
Subject: [PATCH] precompute packed iterable into a dataset

---
 trl/trainer/sft_trainer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 87ba1d166c..5eafbdd33c 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -323,17 +323,18 @@ def _prepare_dataset(
                     "You need to pass a tokenizer when using the SFT Trainer when passing a `dataset_text_field`."
                 )
 
-            return ConstantLengthDataset(
+            constant_length_iterator = ConstantLengthDataset(
                 tokenizer,
                 dataset,
                 dataset_text_field=dataset_text_field,
                 formatting_func=formatting_func,
                 seq_length=max_seq_length,
-                infinite=infinite,
+                infinite=False,
                 num_of_sequences=num_of_sequences,
                 chars_per_token=chars_per_token,
                 eos_token_id=tokenizer.eos_token_id,
             )
+            return Dataset.from_generator(constant_length_iterator)
 
         raise ValueError(
             "You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`."
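
A minimal sketch of the idea behind this patch, not the TRL implementation itself: with infinite=False the packing iterator terminates, so datasets.Dataset.from_generator can consume it once, cache the packed examples, and return an ordinary map-style Dataset with a known length. The generator name and the toy examples below are hypothetical and assume only the Hugging Face `datasets` library:

    from datasets import Dataset

    def packed_examples():
        # Hypothetical stand-in for iterating over a ConstantLengthDataset:
        # each yielded item is a dict holding one fixed-length block of token ids.
        for start in range(0, 12, 4):
            yield {"input_ids": list(range(start, start + 4))}

    # from_generator consumes the finite generator once, writes the examples to
    # an Arrow cache, and returns a map-style Dataset with a known length; this
    # is why the patch switches infinite=infinite to infinite=False first.
    packed_dataset = Dataset.from_generator(packed_examples)
    print(len(packed_dataset))   # 3
    print(packed_dataset[0])     # {'input_ids': [0, 1, 2, 3]}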