Skip to content

Commit

Permalink
Fix packing test (#2111)
Browse files Browse the repository at this point in the history
* Fix pack test

* same for eval
  • Loading branch information
qgallouedec authored Sep 24, 2024
1 parent 80038a5 commit a84fc5d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ def test_sft_trainer_only_train_packing(self):
eval_dataset=self.conversational_lm_dataset["test"],
)

assert len(trainer.train_dataset["input_ids"]) == 16 # with the used dataset, we end up with 16 sequences
assert len(trainer.train_dataset["input_ids"]) == 21 # with the used dataset, we end up with 21 sequences
assert len(trainer.eval_dataset["input_ids"]) == len(self.conversational_lm_dataset["test"])

def test_sft_trainer_eval_packing(self):
Expand All @@ -1131,8 +1131,8 @@ def test_sft_trainer_eval_packing(self):
eval_dataset=self.conversational_lm_dataset["test"],
)

assert len(trainer.train_dataset["input_ids"]) == 16 # with the used dataset, we end up with 16 sequences
assert len(trainer.eval_dataset["input_ids"]) == 1 # with the used dataset, we end up with 1 sequence
assert len(trainer.train_dataset["input_ids"]) == 21 # with the used dataset, we end up with 21 sequences
assert len(trainer.eval_dataset["input_ids"]) == 2 # with the used dataset, we end up with 2 sequence

def test_sft_trainer_no_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down

0 comments on commit a84fc5d

Please sign in to comment.