diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 87ba1d166c..5eafbdd33c 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -323,17 +323,18 @@ def _prepare_dataset( "You need to pass a tokenizer when using the SFT Trainer when passing a `dataset_text_field`." ) - return ConstantLengthDataset( + constant_length_iterator = ConstantLengthDataset( tokenizer, dataset, dataset_text_field=dataset_text_field, formatting_func=formatting_func, seq_length=max_seq_length, - infinite=infinite, + infinite=False, num_of_sequences=num_of_sequences, chars_per_token=chars_per_token, eos_token_id=tokenizer.eos_token_id, ) + return Dataset.from_generator(constant_length_iterator) raise ValueError( "You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`."