🔧 Use standard unittest assertion methods (#2283)
* WIP: Partial unit test update

* Update unittest format

* Update tests/slow/test_sft_slow.py comment

* Refactor unit tests: replace pytest.raises with self.assertRaises

* Fix: Restore accidentally deleted 'ref_model' parameter in DPOTrainer

* Re-run pre-commit

* fix: Incorrectly replacing non-TestCase assert

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
ccs96307 and qgallouedec authored Oct 31, 2024
1 parent bb56c6e commit 24fb327
Showing 20 changed files with 543 additions and 501 deletions.
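
The change applied across these files is mechanical: bare `assert` statements (which report no detail on failure and are stripped entirely under `python -O`) are replaced by the corresponding `unittest.TestCase` methods, which show both operands when a check fails. As a rough summary of the mapping — a runnable sketch, not an excerpt from the commit:

import unittest

class AssertionMappingSketch(unittest.TestCase):
    """Illustrates the assert -> TestCase-method substitutions used in this commit."""

    def test_mapping(self):
        x, items = 4, [1, 2]
        self.assertEqual(x, 4)                    # was: assert x == 4
        self.assertIsInstance(items, list)        # was: assert isinstance(items, list)
        self.assertIsNotNone(x)                   # was: assert x is not None
        self.assertIn(1, items)                   # was: assert 1 in items
        self.assertLess(x, 5)                     # was: assert x < 5
        self.assertFalse(x == 5)                  # was: assert not x == 5
        self.assertTrue(x > 0, "custom message")  # was: assert x > 0, "custom message"

if __name__ == "__main__":
    unittest.main()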
8 changes: 4 additions & 4 deletions tests/slow/test_dpo_slow.py
@@ -148,8 +148,8 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_
             peft_config=self.peft_config,
         )

-        assert isinstance(trainer.model, PeftModel)
-        assert trainer.ref_model is None
+        self.assertIsInstance(trainer.model, PeftModel)
+        self.assertIsNone(trainer.ref_model)

         # train the model
         trainer.train()
@@ -212,8 +212,8 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra
             peft_config=self.peft_config,
         )

-        assert isinstance(trainer.model, PeftModel)
-        assert trainer.ref_model is None
+        self.assertIsInstance(trainer.model, PeftModel)
+        self.assertIsNone(trainer.ref_model)

         # train the model
         trainer.train()
8 changes: 4 additions & 4 deletions tests/slow/test_sft_slow.py
@@ -148,7 +148,7 @@ def test_sft_trainer_peft(self, model_name, packing):
             peft_config=self.peft_config,
         )

-        assert isinstance(trainer.model, PeftModel)
+        self.assertIsInstance(trainer.model, PeftModel)

         trainer.train()
@@ -255,7 +255,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
             peft_config=self.peft_config,
         )

-        assert isinstance(trainer.model, PeftModel)
+        self.assertIsInstance(trainer.model, PeftModel)

         trainer.train()
@@ -337,7 +337,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
             peft_config=self.peft_config,
         )

-        assert isinstance(trainer.model, PeftModel)
+        self.assertIsInstance(trainer.model, PeftModel)

         trainer.train()
@@ -380,7 +380,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
             peft_config=self.peft_config,
         )

-        assert isinstance(trainer.model, PeftModel)
+        self.assertIsInstance(trainer.model, PeftModel)

         trainer.train()
10 changes: 5 additions & 5 deletions tests/test_alignprop_trainer.py
@@ -70,8 +70,8 @@ def tearDown(self) -> None:
     def test_generate_samples(self, use_lora):
         trainer = self.trainer_with_lora if use_lora else self.trainer_without_lora
         output_pairs = trainer._generate_samples(2, with_grad=True)
-        assert len(output_pairs.keys()) == 3
-        assert len(output_pairs["images"]) == 2
+        self.assertEqual(len(output_pairs.keys()), 3)
+        self.assertEqual(len(output_pairs["images"]), 2)

     @parameterized.expand([True, False])
     def test_calculate_loss(self, use_lora):
@@ -81,10 +81,10 @@ def test_calculate_loss(self, use_lora):
         images = sample["images"]
         prompts = sample["prompts"]

-        assert images.shape == (2, 3, 128, 128)
-        assert len(prompts) == 2
+        self.assertTupleEqual(images.shape, (2, 3, 128, 128))
+        self.assertEqual(len(prompts), 2)

         rewards = trainer.compute_rewards(sample)
         loss = trainer.calculate_loss(rewards)

-        assert torch.isfinite(loss.cpu())
+        self.assertTrue(torch.isfinite(loss.cpu()))
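
One subtlety in that last hunk, noted as an aside rather than part of the diff: `torch.isfinite(loss.cpu())` returns a 0-dim boolean tensor, not a Python bool, and `assertTrue` still works because single-element tensors define truthiness. A minimal sketch:

import torch
import unittest

class FiniteLossSketch(unittest.TestCase):
    def test_scalar_tensor_truthiness(self):
        loss = torch.tensor(0.25)  # hypothetical scalar loss
        # torch.isfinite yields tensor(True); bool() on a 0-dim tensor is allowed
        self.assertTrue(torch.isfinite(loss))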
6 changes: 3 additions & 3 deletions tests/test_best_of_n_sampler.py
@@ -73,8 +73,8 @@ def test_different_input_types(self):
         for q, expected_length in various_queries_formats:
             results = best_of_n.generate(q)
-            assert isinstance(results, list)
-            assert len(results) == expected_length
+            self.assertIsInstance(results, list)
+            self.assertEqual(len(results), expected_length)

     def test_different_sample_sizes_and_n_candidates_values(self):
         r"""
@@ -109,4 +109,4 @@ def test_different_sample_sizes_and_n_candidates_values(self):
         tokenized_queries = [self.tokenizer.encode(query) for query in queries]
         results = best_of_n.generate(tokenized_queries)
         for result in results:
-            assert len(result) == expected
+            self.assertEqual(len(result), expected)
49 changes: 24 additions & 25 deletions tests/test_cli.py
@@ -16,30 +16,29 @@
 import unittest


-@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
-def test_sft_cli():
-    try:
-        subprocess.run(
-            "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name stanfordnlp/imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
-            shell=True,
-            check=True,
-        )
-    except BaseException as exc:
-        raise AssertionError("An error occurred while running the CLI, please double check") from exc
+class CLITester(unittest.TestCase):
+    @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
+    def test_sft_cli(self):
+        try:
+            subprocess.run(
+                "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name stanfordnlp/imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
+                shell=True,
+                check=True,
+            )
+        except BaseException:
+            self.fail("An error occurred while running the CLI, please double check")

-@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
-def test_dpo_cli():
-    try:
-        subprocess.run(
-            "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine",
-            shell=True,
-            check=True,
-        )
-    except BaseException as exc:
-        raise AssertionError("An error occurred while running the CLI, please double check") from exc
+    @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
+    def test_dpo_cli(self):
+        try:
+            subprocess.run(
+                "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine",
+                shell=True,
+                check=True,
+            )
+        except BaseException:
+            self.fail("An error occurred while running the CLI, please double check")

-
-def test_env_cli():
-    output = subprocess.run("trl env", capture_output=True, text=True, shell=True, check=True)
-    assert "- Python version: " in output.stdout
+    def test_env_cli(self):
+        output = subprocess.run("trl env", capture_output=True, text=True, shell=True, check=True)
+        self.assertIn("- Python version: ", output.stdout)
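
The CLI tests also change shape, not just assertion style: `self.fail(...)` and the other `self.assert*` helpers only exist on `unittest.TestCase`, so the formerly module-level functions move into a `CLITester` class. A minimal sketch of the resulting pattern (the `echo ok` command is a stand-in, not a TRL invocation):

import subprocess
import unittest

class CLISketch(unittest.TestCase):
    def test_command_runs(self):
        try:
            # check=True raises CalledProcessError on a non-zero exit code
            subprocess.run("echo ok", shell=True, check=True)
        except BaseException:
            # self.fail turns any exception into a clean test failure
            self.fail("An error occurred while running the CLI, please double check")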
6 changes: 3 additions & 3 deletions tests/test_core.py
@@ -29,13 +29,13 @@ def setUp(self):
         self.test_input_unmasked = self.test_input[1:3]

     def test_masked_mean(self):
-        assert torch.mean(self.test_input_unmasked) == masked_mean(self.test_input, self.test_mask)
+        self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask))

     def test_masked_var(self):
-        assert torch.var(self.test_input_unmasked) == masked_var(self.test_input, self.test_mask)
+        self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask))

     def test_masked_whiten(self):
         whiten_unmasked = whiten(self.test_input_unmasked)
         whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3]
         diffs = (whiten_unmasked - whiten_masked).sum()
-        assert abs(diffs.item()) < 0.00001
+        self.assertLess(abs(diffs.item()), 0.00001)
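
An aside on that last hunk: a float tolerance like this can also be written with `assertAlmostEqual`, which makes the intent explicit (note its `delta` comparison is `<=` rather than the strict `<` above). A hypothetical equivalent:

import unittest

class ToleranceSketch(unittest.TestCase):
    def test_whiten_residual(self):
        diffs = 3e-6  # hypothetical residual from the whitening comparison
        self.assertAlmostEqual(diffs, 0.0, delta=0.00001)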
8 changes: 4 additions & 4 deletions tests/test_cpo_trainer.py
@@ -84,14 +84,14 @@ def test_cpo_trainer(self, name, loss_type, config_name):

         trainer.train()

-        assert trainer.state.log_history[-1]["train_loss"] is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

         # check the params have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             # check the params have changed - ignore 0 biases
             if param.sum() != 0:
-                assert not torch.equal(param, new_param)
+                self.assertFalse(torch.equal(param, new_param))

     @require_peft
     @parameterized.expand(
@@ -142,12 +142,12 @@ def test_cpo_trainer_with_lora(self, config_name):

         trainer.train()

-        assert trainer.state.log_history[-1]["train_loss"] is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

         # check the params have changed
         for n, param in previous_trainable_params.items():
             if "lora" in n:
                 new_param = trainer.model.get_parameter(n)
                 # check the params have changed - ignore 0 biases
                 if param.sum() != 0:
-                    assert not torch.equal(param, new_param)
+                    self.assertFalse(torch.equal(param, new_param))
8 changes: 4 additions & 4 deletions tests/test_data_collator_completion_only.py
@@ -45,7 +45,7 @@ def test_data_collator_finds_response_template_llama2_tokenizer(self):
         self.tokenized_response_w_context = self.tokenizer.encode(self.response_template, add_special_tokens=False)[2:]

         # Plain check on string
-        assert self.response_template in self.instruction
+        self.assertIn(self.response_template, self.instruction)
         self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False)

         # Test the fix for #598
@@ -80,7 +80,7 @@ def test_data_collator_finds_response_template_llama2_tokenizer(self):
             collator_output["labels"][torch.where(collator_output["labels"] != -100)]
         )
         expected_text = " First response\n\n Second response" ""
-        assert collator_text == expected_text
+        self.assertEqual(collator_text, expected_text)

     def test_data_collator_handling_of_long_sequences(self):
         self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
@@ -94,7 +94,7 @@ def test_data_collator_handling_of_long_sequences(self):
         self.collator = DataCollatorForCompletionOnlyLM(self.response_template, tokenizer=self.tokenizer)
         encoded_instance = self.collator.torch_call([self.tokenized_instruction])
         result = torch.all(encoded_instance["labels"] == -100)
-        assert result, "Not all values in the tensor are -100."
+        self.assertTrue(result, "Not all values in the tensor are -100.")

         # check DataCollatorForCompletionOnlyLM using response template and instruction template
         self.instruction_template = "\n### User:"
@@ -103,7 +103,7 @@ def test_data_collator_handling_of_long_sequences(self):
         )
         encoded_instance = self.collator.torch_call([self.tokenized_instruction])
         result = torch.all(encoded_instance["labels"] == -100)
-        assert result, "Not all values in the tensor are -100."
+        self.assertTrue(result, "Not all values in the tensor are -100.")

     def test_padding_free(self):
         tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
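
A detail these hunks rely on, noted as an observation rather than part of the diff: the message in `assert cond, msg` maps directly onto the optional `msg` parameter that every `TestCase` assertion accepts, so the custom failure messages survive the conversion unchanged:

import unittest

class MessageSketch(unittest.TestCase):
    def test_custom_failure_message(self):
        result = True  # stand-in for the torch.all(labels == -100) check above
        # msg is reported only if the assertion fails
        self.assertTrue(result, "Not all values in the tensor are -100.")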
60 changes: 30 additions & 30 deletions tests/test_dataset_formatting.py
@@ -42,20 +42,20 @@ def test_get_formatting_func_from_dataset_with_chatml_messages(self):

         # Llama tokenizer
         formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
-        assert isinstance(formatting_func, Callable)
+        self.assertIsInstance(formatting_func, Callable)
         formatted_text = formatting_func(dataset[0])
         expected = "<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>"
-        assert formatted_text == expected
+        self.assertEqual(formatted_text, expected)
         formatted_text = formatting_func(dataset[0:1])
-        assert formatted_text == [expected]
+        self.assertListEqual(formatted_text, [expected])

         # ChatML tokenizer
         formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
         formatted_text = formatting_func(dataset[0])
         expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
-        assert formatted_text == expected
+        self.assertEqual(formatted_text, expected)
         formatted_text = formatting_func(dataset[0:1])
-        assert formatted_text == [expected]
+        self.assertListEqual(formatted_text, [expected])

     def test_get_formatting_func_from_dataset_with_chatml_conversations(self):
         dataset = Dataset.from_dict(
@@ -71,48 +71,48 @@ def test_get_formatting_func_from_dataset_with_chatml_conversations(self):
         )
         # Llama tokenizer
         formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
-        assert isinstance(formatting_func, Callable)
+        self.assertIsInstance(formatting_func, Callable)
         formatted_text = formatting_func(dataset[0])
         expected = "<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>"
-        assert formatted_text == expected
+        self.assertEqual(formatted_text, expected)
         formatted_text = formatting_func(dataset[0:1])
-        assert formatted_text == [expected]
+        self.assertListEqual(formatted_text, [expected])

         # ChatML tokenizer
         formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
         formatted_text = formatting_func(dataset[0])
         expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
-        assert formatted_text == expected
+        self.assertEqual(formatted_text, expected)
         formatted_text = formatting_func(dataset[0:1])
-        assert formatted_text == [expected]
+        self.assertListEqual(formatted_text, [expected])

     def test_get_formatting_func_from_dataset_with_instruction(self):
         dataset = Dataset.from_list(
             [{"prompt": "What is 2+2?", "completion": "4"}, {"prompt": "What is 3+3?", "completion": "6"}]
         )
         formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
-        assert formatting_func is not None
-        assert isinstance(formatting_func, Callable)
+        self.assertIsNotNone(formatting_func)
+        self.assertIsInstance(formatting_func, Callable)
         formatted_text = formatting_func(dataset[0])
-        assert formatted_text == "<s>[INST] What is 2+2? [/INST] 4 </s>"
+        self.assertEqual(formatted_text, "<s>[INST] What is 2+2? [/INST] 4 </s>")
         formatted_text = formatting_func(dataset[0:1])
-        assert formatted_text == ["<s>[INST] What is 2+2? [/INST] 4 </s>"]
+        self.assertListEqual(formatted_text, ["<s>[INST] What is 2+2? [/INST] 4 </s>"])

     def test_get_formatting_func_from_dataset_from_hub(self):
         ds_1 = load_dataset("philschmid/trl-test-instruction", split="train")
         ds_2 = load_dataset("philschmid/dolly-15k-oai-style", split="train")
         for ds in [ds_1, ds_2]:
             formatting_func = get_formatting_func_from_dataset(ds, self.llama_tokenizer)
-            assert formatting_func is not None
-            assert isinstance(formatting_func, Callable)
+            self.assertIsNotNone(formatting_func)
+            self.assertIsInstance(formatting_func, Callable)
         ds_3 = load_dataset("philschmid/guanaco-sharegpt-style", split="train")
         formatting_func = get_formatting_func_from_dataset(ds_3, self.llama_tokenizer)
-        assert formatting_func is None
+        self.assertIsNone(formatting_func)

     def test_get_formatting_func_from_dataset_with_unknown_format(self):
         dataset = Dataset.from_dict({"text": "test"})
         formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
-        assert formatting_func is None
+        self.assertIsNone(formatting_func)


 class SetupChatFormatTestCase(unittest.TestCase):
@@ -130,15 +130,15 @@ def test_setup_chat_format(self):

         _chatml = ChatMlSpecialTokens()
         # Check if special tokens are correctly set
-        assert modified_tokenizer.eos_token == "<|im_end|>"
-        assert modified_tokenizer.pad_token == "<|im_end|>"
-        assert modified_tokenizer.bos_token == "<|im_start|>"
-        assert modified_tokenizer.eos_token == _chatml.eos_token
-        assert modified_tokenizer.pad_token == _chatml.pad_token
-        assert modified_tokenizer.bos_token == _chatml.bos_token
-        assert len(modified_tokenizer) == (original_tokenizer_len + 2)
-        assert (self.model.get_input_embeddings().weight.shape[0] % 64) == 0
-        assert self.model.get_input_embeddings().weight.shape[0] == (original_tokenizer_len + 64)
+        self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")
+        self.assertEqual(modified_tokenizer.pad_token, "<|im_end|>")
+        self.assertEqual(modified_tokenizer.bos_token, "<|im_start|>")
+        self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token)
+        self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token)
+        self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token)
+        self.assertEqual(len(modified_tokenizer), (original_tokenizer_len + 2))
+        self.assertEqual((self.model.get_input_embeddings().weight.shape[0] % 64), 0)
+        self.assertEqual(self.model.get_input_embeddings().weight.shape[0], (original_tokenizer_len + 64))

     def test_example_with_setup_model(self):
         modified_model, modified_tokenizer = setup_chat_format(
@@ -152,7 +152,7 @@ def test_example_with_setup_model(self):
         ]
         prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)

-        assert (
-            prompt
-            == "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
+        self.assertEqual(
+            prompt,
+            "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n",
         )
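
A final aside: `assertEqual` already dispatches to `assertListEqual` when both operands are lists, so the explicit `assertListEqual` calls above mainly document intent and additionally pin down the container type. A quick illustration:

import unittest

class DispatchSketch(unittest.TestCase):
    def test_list_type_is_enforced(self):
        self.assertEqual([1, 2], [1, 2])      # dispatches to assertListEqual internally
        self.assertListEqual([1, 2], [1, 2])  # same comparison, type-explicit
        with self.assertRaises(AssertionError):
            self.assertListEqual((1, 2), [1, 2])  # first arg must be a list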
[Diffs for the remaining 11 changed files are not shown.]
