diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py
index 6538c56b74..8981af04d3 100644
--- a/tests/slow/test_dpo_slow.py
+++ b/tests/slow/test_dpo_slow.py
@@ -148,8 +148,8 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_
peft_config=self.peft_config,
)
- assert isinstance(trainer.model, PeftModel)
- assert trainer.ref_model is None
+ self.assertIsInstance(trainer.model, PeftModel)
+ self.assertIsNone(trainer.ref_model)
# train the model
trainer.train()
@@ -212,8 +212,8 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra
peft_config=self.peft_config,
)
- assert isinstance(trainer.model, PeftModel)
- assert trainer.ref_model is None
+ self.assertIsInstance(trainer.model, PeftModel)
+ self.assertIsNone(trainer.ref_model)
# train the model
trainer.train()
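The two hunks above, like most of this patch, apply one mechanical mapping from bare `assert` statements to the equivalent `unittest.TestCase` methods. A minimal, self-contained sketch of that mapping, using illustrative stand-ins rather than the real TRL trainer objects:

    import unittest

    class MappingSketch(unittest.TestCase):
        def test_mapping(self):
            model = {"kind": "peft"}  # stand-in for trainer.model
            ref_model = None          # stand-in for trainer.ref_model
            # assert isinstance(model, dict)  becomes:
            self.assertIsInstance(model, dict)
            # assert ref_model is None  becomes:
            self.assertIsNone(ref_model)

Unlike bare `assert`, which only produces rich failure output under pytest's assertion rewriting, the dedicated methods report both operands in the failure message under any runner.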
diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py
index 5035de4848..1d8989b8b3 100644
--- a/tests/slow/test_sft_slow.py
+++ b/tests/slow/test_sft_slow.py
@@ -148,7 +148,7 @@ def test_sft_trainer_peft(self, model_name, packing):
peft_config=self.peft_config,
)
- assert isinstance(trainer.model, PeftModel)
+ self.assertIsInstance(trainer.model, PeftModel)
trainer.train()
@@ -255,7 +255,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
peft_config=self.peft_config,
)
- assert isinstance(trainer.model, PeftModel)
+ self.assertIsInstance(trainer.model, PeftModel)
trainer.train()
@@ -337,7 +337,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
peft_config=self.peft_config,
)
- assert isinstance(trainer.model, PeftModel)
+ self.assertIsInstance(trainer.model, PeftModel)
trainer.train()
@@ -380,7 +380,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
peft_config=self.peft_config,
)
- assert isinstance(trainer.model, PeftModel)
+ self.assertIsInstance(trainer.model, PeftModel)
trainer.train()
diff --git a/tests/test_alignprop_trainer.py b/tests/test_alignprop_trainer.py
index 995b91a750..a7446639d1 100644
--- a/tests/test_alignprop_trainer.py
+++ b/tests/test_alignprop_trainer.py
@@ -70,8 +70,8 @@ def tearDown(self) -> None:
def test_generate_samples(self, use_lora):
trainer = self.trainer_with_lora if use_lora else self.trainer_without_lora
output_pairs = trainer._generate_samples(2, with_grad=True)
- assert len(output_pairs.keys()) == 3
- assert len(output_pairs["images"]) == 2
+ self.assertEqual(len(output_pairs.keys()), 3)
+ self.assertEqual(len(output_pairs["images"]), 2)
@parameterized.expand([True, False])
def test_calculate_loss(self, use_lora):
@@ -81,10 +81,10 @@ def test_calculate_loss(self, use_lora):
images = sample["images"]
prompts = sample["prompts"]
- assert images.shape == (2, 3, 128, 128)
- assert len(prompts) == 2
+ self.assertTupleEqual(images.shape, (2, 3, 128, 128))
+ self.assertEqual(len(prompts), 2)
rewards = trainer.compute_rewards(sample)
loss = trainer.calculate_loss(rewards)
- assert torch.isfinite(loss.cpu())
+ self.assertTrue(torch.isfinite(loss.cpu()))
diff --git a/tests/test_best_of_n_sampler.py b/tests/test_best_of_n_sampler.py
index 76ce85ae0a..584c917a04 100644
--- a/tests/test_best_of_n_sampler.py
+++ b/tests/test_best_of_n_sampler.py
@@ -73,8 +73,8 @@ def test_different_input_types(self):
for q, expected_length in various_queries_formats:
results = best_of_n.generate(q)
- assert isinstance(results, list)
- assert len(results) == expected_length
+ self.assertIsInstance(results, list)
+ self.assertEqual(len(results), expected_length)
def test_different_sample_sizes_and_n_candidates_values(self):
r"""
@@ -109,4 +109,4 @@ def test_different_sample_sizes_and_n_candidates_values(self):
tokenized_queries = [self.tokenizer.encode(query) for query in queries]
results = best_of_n.generate(tokenized_queries)
for result in results:
- assert len(result) == expected
+ self.assertEqual(len(result), expected)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 2fc0e6eb03..0a34995ecd 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -16,30 +16,29 @@
import unittest
-@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
-def test_sft_cli():
- try:
- subprocess.run(
- "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name stanfordnlp/imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
- shell=True,
- check=True,
- )
- except BaseException as exc:
- raise AssertionError("An error occurred while running the CLI, please double check") from exc
+class CLITester(unittest.TestCase):
+ @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
+ def test_sft_cli(self):
+ try:
+ subprocess.run(
+ "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name stanfordnlp/imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
+ shell=True,
+ check=True,
+ )
+ except BaseException:
+ self.fail("An error occurred while running the CLI, please double check")
+ @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
+ def test_dpo_cli(self):
+ try:
+ subprocess.run(
+ "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine",
+ shell=True,
+ check=True,
+ )
+ except BaseException:
+ self.fail("An error occurred while running the CLI, please double check")
-@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
-def test_dpo_cli():
- try:
- subprocess.run(
- "trl dpo --max_steps 1 --output_dir tmp-dpo --model_name_or_path trl-internal-testing/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/tiny-ultrafeedback-binarized --learning_rate 1e-4 --lr_scheduler_type cosine",
- shell=True,
- check=True,
- )
- except BaseException as exc:
- raise AssertionError("An error occurred while running the CLI, please double check") from exc
-
-
-def test_env_cli():
- output = subprocess.run("trl env", capture_output=True, text=True, shell=True, check=True)
- assert "- Python version: " in output.stdout
+ def test_env_cli(self):
+ output = subprocess.run("trl env", capture_output=True, text=True, shell=True, check=True)
+ self.assertIn("- Python version: ", output.stdout)
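Beyond the assertion swap, this file also moves module-level pytest functions into a `unittest.TestCase`, so `raise AssertionError(...) from exc` becomes `self.fail(...)`. A sketch of that migration shape, with a harmless placeholder command standing in for the `trl` CLI:

    import subprocess
    import unittest

    class CLISketch(unittest.TestCase):
        def test_command_runs(self):
            try:
                # Placeholder command; the real tests invoke the `trl` CLI.
                subprocess.run("python --version", shell=True, check=True)
            except BaseException:
                self.fail("An error occurred while running the CLI, please double check")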
diff --git a/tests/test_core.py b/tests/test_core.py
index 34fe81064f..2d8531d591 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -29,13 +29,13 @@ def setUp(self):
self.test_input_unmasked = self.test_input[1:3]
def test_masked_mean(self):
- assert torch.mean(self.test_input_unmasked) == masked_mean(self.test_input, self.test_mask)
+ self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask))
def test_masked_var(self):
- assert torch.var(self.test_input_unmasked) == masked_var(self.test_input, self.test_mask)
+ self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask))
def test_masked_whiten(self):
whiten_unmasked = whiten(self.test_input_unmasked)
whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3]
diffs = (whiten_unmasked - whiten_masked).sum()
- assert abs(diffs.item()) < 0.00001
+ self.assertLess(abs(diffs.item()), 0.00001)
diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py
index 7bbccbb248..38a0f695fd 100644
--- a/tests/test_cpo_trainer.py
+++ b/tests/test_cpo_trainer.py
@@ -84,14 +84,14 @@ def test_cpo_trainer(self, name, loss_type, config_name):
trainer.train()
- assert trainer.state.log_history[-1]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
- assert not torch.equal(param, new_param)
+ self.assertFalse(torch.equal(param, new_param))
@require_peft
@parameterized.expand(
@@ -142,7 +142,7 @@ def test_cpo_trainer_with_lora(self, config_name):
trainer.train()
- assert trainer.state.log_history[-1]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
@@ -150,4 +150,4 @@ def test_cpo_trainer_with_lora(self, config_name):
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
- assert not torch.equal(param, new_param)
+ self.assertFalse(torch.equal(param, new_param))
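Note that the tensor comparisons above are wrapped as `assertFalse(torch.equal(...))` rather than converted to `assertNotEqual`: unittest's equality methods truth-test the result of `==`/`!=`, which is elementwise for tensors and raises for more than one element. A short sketch, assuming torch is available:

    import unittest
    import torch

    class TensorSketch(unittest.TestCase):
        def test_params_changed(self):
            param = torch.tensor([1.0, 2.0])
            new_param = torch.tensor([1.0, 3.0])
            # self.assertNotEqual(param, new_param) would call bool(param != new_param),
            # which raises for tensors with more than one element.
            self.assertFalse(torch.equal(param, new_param))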
diff --git a/tests/test_data_collator_completion_only.py b/tests/test_data_collator_completion_only.py
index b5bcc1ed32..4ae0c7f3a6 100644
--- a/tests/test_data_collator_completion_only.py
+++ b/tests/test_data_collator_completion_only.py
@@ -45,7 +45,7 @@ def test_data_collator_finds_response_template_llama2_tokenizer(self):
self.tokenized_response_w_context = self.tokenizer.encode(self.response_template, add_special_tokens=False)[2:]
# Plain check on string
- assert self.response_template in self.instruction
+ self.assertIn(self.response_template, self.instruction)
self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False)
# Test the fix for #598
@@ -80,7 +80,7 @@ def test_data_collator_finds_response_template_llama2_tokenizer(self):
collator_output["labels"][torch.where(collator_output["labels"] != -100)]
)
expected_text = " First response\n\n Second response" ""
- assert collator_text == expected_text
+ self.assertEqual(collator_text, expected_text)
def test_data_collator_handling_of_long_sequences(self):
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
@@ -94,7 +94,7 @@ def test_data_collator_handling_of_long_sequences(self):
self.collator = DataCollatorForCompletionOnlyLM(self.response_template, tokenizer=self.tokenizer)
encoded_instance = self.collator.torch_call([self.tokenized_instruction])
result = torch.all(encoded_instance["labels"] == -100)
- assert result, "Not all values in the tensor are -100."
+ self.assertTrue(result, "Not all values in the tensor are -100.")
# check DataCollatorForCompletionOnlyLM using response template and instruction template
self.instruction_template = "\n### User:"
@@ -103,7 +103,7 @@ def test_data_collator_handling_of_long_sequences(self):
)
encoded_instance = self.collator.torch_call([self.tokenized_instruction])
result = torch.all(encoded_instance["labels"] == -100)
- assert result, "Not all values in the tensor are -100."
+ self.assertTrue(result, "Not all values in the tensor are -100.")
def test_padding_free(self):
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py
index f1e9bcb4d8..15f6c63a67 100644
--- a/tests/test_dataset_formatting.py
+++ b/tests/test_dataset_formatting.py
@@ -42,20 +42,20 @@ def test_get_formatting_func_from_dataset_with_chatml_messages(self):
# Llama tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
- assert isinstance(formatting_func, Callable)
+ self.assertIsInstance(formatting_func, Callable)
formatted_text = formatting_func(dataset[0])
expected = "[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? "
- assert formatted_text == expected
+ self.assertEqual(formatted_text, expected)
formatted_text = formatting_func(dataset[0:1])
- assert formatted_text == [expected]
+ self.assertListEqual(formatted_text, [expected])
# ChatML tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
formatted_text = formatting_func(dataset[0])
expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
- assert formatted_text == expected
+ self.assertEqual(formatted_text, expected)
formatted_text = formatting_func(dataset[0:1])
- assert formatted_text == [expected]
+ self.assertListEqual(formatted_text, [expected])
def test_get_formatting_func_from_dataset_with_chatml_conversations(self):
dataset = Dataset.from_dict(
@@ -71,48 +71,48 @@ def test_get_formatting_func_from_dataset_with_chatml_conversations(self):
)
# Llama tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
- assert isinstance(formatting_func, Callable)
+ self.assertIsInstance(formatting_func, Callable)
formatted_text = formatting_func(dataset[0])
expected = "[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? "
- assert formatted_text == expected
+ self.assertEqual(formatted_text, expected)
formatted_text = formatting_func(dataset[0:1])
- assert formatted_text == [expected]
+ self.assertListEqual(formatted_text, [expected])
# ChatML tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
formatted_text = formatting_func(dataset[0])
expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
- assert formatted_text == expected
+ self.assertEqual(formatted_text, expected)
formatted_text = formatting_func(dataset[0:1])
- assert formatted_text == [expected]
+ self.assertListEqual(formatted_text, [expected])
def test_get_formatting_func_from_dataset_with_instruction(self):
dataset = Dataset.from_list(
[{"prompt": "What is 2+2?", "completion": "4"}, {"prompt": "What is 3+3?", "completion": "6"}]
)
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
- assert formatting_func is not None
- assert isinstance(formatting_func, Callable)
+ self.assertIsNotNone(formatting_func)
+ self.assertIsInstance(formatting_func, Callable)
formatted_text = formatting_func(dataset[0])
- assert formatted_text == "[INST] What is 2+2? [/INST] 4 "
+ self.assertEqual(formatted_text, "[INST] What is 2+2? [/INST] 4 ")
formatted_text = formatting_func(dataset[0:1])
- assert formatted_text == ["[INST] What is 2+2? [/INST] 4 "]
+ self.assertListEqual(formatted_text, ["[INST] What is 2+2? [/INST] 4 "])
def test_get_formatting_func_from_dataset_from_hub(self):
ds_1 = load_dataset("philschmid/trl-test-instruction", split="train")
ds_2 = load_dataset("philschmid/dolly-15k-oai-style", split="train")
for ds in [ds_1, ds_2]:
formatting_func = get_formatting_func_from_dataset(ds, self.llama_tokenizer)
- assert formatting_func is not None
- assert isinstance(formatting_func, Callable)
+ self.assertIsNotNone(formatting_func)
+ self.assertIsInstance(formatting_func, Callable)
ds_3 = load_dataset("philschmid/guanaco-sharegpt-style", split="train")
formatting_func = get_formatting_func_from_dataset(ds_3, self.llama_tokenizer)
- assert formatting_func is None
+ self.assertIsNone(formatting_func)
def test_get_formatting_func_from_dataset_with_unknown_format(self):
dataset = Dataset.from_dict({"text": "test"})
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
- assert formatting_func is None
+ self.assertIsNone(formatting_func)
class SetupChatFormatTestCase(unittest.TestCase):
@@ -130,15 +130,15 @@ def test_setup_chat_format(self):
_chatml = ChatMlSpecialTokens()
# Check if special tokens are correctly set
- assert modified_tokenizer.eos_token == "<|im_end|>"
- assert modified_tokenizer.pad_token == "<|im_end|>"
- assert modified_tokenizer.bos_token == "<|im_start|>"
- assert modified_tokenizer.eos_token == _chatml.eos_token
- assert modified_tokenizer.pad_token == _chatml.pad_token
- assert modified_tokenizer.bos_token == _chatml.bos_token
- assert len(modified_tokenizer) == (original_tokenizer_len + 2)
- assert (self.model.get_input_embeddings().weight.shape[0] % 64) == 0
- assert self.model.get_input_embeddings().weight.shape[0] == (original_tokenizer_len + 64)
+ self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")
+ self.assertEqual(modified_tokenizer.pad_token, "<|im_end|>")
+ self.assertEqual(modified_tokenizer.bos_token, "<|im_start|>")
+ self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token)
+ self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token)
+ self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token)
+ self.assertEqual(len(modified_tokenizer), (original_tokenizer_len + 2))
+ self.assertEqual((self.model.get_input_embeddings().weight.shape[0] % 64), 0)
+ self.assertEqual(self.model.get_input_embeddings().weight.shape[0], (original_tokenizer_len + 64))
def test_example_with_setup_model(self):
modified_model, modified_tokenizer = setup_chat_format(
@@ -152,7 +152,7 @@ def test_example_with_setup_model(self):
]
prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
- assert (
- prompt
- == "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
+ self.assertEqual(
+ prompt,
+ "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n",
)
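For the long template strings in this file, `assertEqual` has a concrete payoff: for `str` operands it delegates to `assertMultiLineEqual`, which prints a unified diff of the two values on failure instead of a bare `AssertionError`. A small sketch:

    import unittest

    class StringDiffSketch(unittest.TestCase):
        def test_template(self):
            rendered = "<|im_start|>user\nHello<|im_end|>\n"
            # On mismatch, unittest would print a line-by-line diff here.
            self.assertEqual(rendered, "<|im_start|>user\nHello<|im_end|>\n")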
diff --git a/tests/test_ddpo_trainer.py b/tests/test_ddpo_trainer.py
index 8db0b0747a..65a626589a 100644
--- a/tests/test_ddpo_trainer.py
+++ b/tests/test_ddpo_trainer.py
@@ -69,13 +69,13 @@ def test_loss(self):
clip_range = 0.0001
ratio = torch.tensor([1.0])
loss = self.trainer.loss(advantage, clip_range, ratio)
- assert loss.item() == 1.0
+ self.assertEqual(loss.item(), 1.0)
def test_generate_samples(self):
samples, output_pairs = self.trainer._generate_samples(1, 2)
- assert len(samples) == 1
- assert len(output_pairs) == 1
- assert len(output_pairs[0][0]) == 2
+ self.assertEqual(len(samples), 1)
+ self.assertEqual(len(output_pairs), 1)
+ self.assertEqual(len(output_pairs[0][0]), 2)
def test_calculate_loss(self):
samples, _ = self.trainer._generate_samples(1, 2)
@@ -88,16 +88,16 @@ def test_calculate_loss(self):
prompt_embeds = sample["prompt_embeds"]
advantage = torch.tensor([1.0], device=prompt_embeds.device)
- assert latents.shape == (1, 4, 64, 64)
- assert next_latents.shape == (1, 4, 64, 64)
- assert log_probs.shape == (1,)
- assert timesteps.shape == (1,)
- assert prompt_embeds.shape == (2, 77, 32)
+ self.assertTupleEqual(latents.shape, (1, 4, 64, 64))
+ self.assertTupleEqual(next_latents.shape, (1, 4, 64, 64))
+ self.assertTupleEqual(log_probs.shape, (1,))
+ self.assertTupleEqual(timesteps.shape, (1,))
+ self.assertTupleEqual(prompt_embeds.shape, (2, 77, 32))
loss, approx_kl, clipfrac = self.trainer.calculate_loss(
latents, timesteps, next_latents, log_probs, advantage, prompt_embeds
)
- assert torch.isfinite(loss.cpu())
+ self.assertTrue(torch.isfinite(loss.cpu()))
@require_diffusers
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 156e422a85..3983da8d46 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -17,7 +17,6 @@
from unittest.mock import MagicMock
import numpy as np
-import pytest
import torch
from datasets import Dataset, features, load_dataset
from parameterized import parameterized
@@ -234,14 +233,14 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):
trainer.train()
- assert trainer.state.log_history[-1]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
- assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
+ self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
def test_dpo_trainer_with_weighting(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -275,14 +274,14 @@ def test_dpo_trainer_with_weighting(self):
trainer.train()
- assert trainer.state.log_history[-1]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
- assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
+ self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
@parameterized.expand(
[
@@ -321,14 +320,14 @@ def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha, _):
trainer.train()
- assert trainer.state.log_history[-1]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
- assert not torch.equal(param, new_param)
+ self.assertFalse(torch.equal(param, new_param))
def test_dpo_trainer_with_ref_model_is_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -392,7 +391,7 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):
trainer.train()
- assert trainer.state.log_history[-1]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
@@ -400,7 +399,7 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
- assert not torch.equal(param, new_param)
+ self.assertFalse(torch.equal(param, new_param))
def test_dpo_trainer_padding_token_is_none(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -508,14 +507,14 @@ def test_tr_dpo_trainer(self):
trainer.train()
- assert trainer.state.log_history[-1]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.ref_model.get_parameter(n)
# check the ref model's params have changed - ignore 0 biases
if param.sum() != 0:
- assert not torch.equal(param, new_param)
+ self.assertFalse(torch.equal(param, new_param))
@require_no_wandb
def test_dpo_trainer_generate_during_eval_no_wandb(self):
@@ -598,7 +597,6 @@ def test_dpo_lora_save(self):
# save peft adapter
trainer.save_model()
- # assert that the model is loaded without giving OSError
try:
AutoModelForCausalLM.from_pretrained(tmp_dir)
except OSError:
@@ -915,8 +913,8 @@ def test_dpo_trainer_torch_dtype(self):
args=training_args,
train_dataset=dummy_dataset["train"],
)
- assert trainer.model.config.torch_dtype == torch.float16
- assert trainer.ref_model.config.torch_dtype == torch.float16
+ self.assertEqual(trainer.model.config.torch_dtype, torch.float16)
+ self.assertEqual(trainer.ref_model.config.torch_dtype, torch.float16)
# Now test when `torch_dtype` is provided but is wrong to either the model or the ref_model
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -928,10 +926,7 @@ def test_dpo_trainer_torch_dtype(self):
report_to="none",
)
- with pytest.raises(
- ValueError,
- match="Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
- ):
+ with self.assertRaises(ValueError) as context:
_ = DPOTrainer(
model=self.model_id,
processing_class=self.tokenizer,
@@ -939,6 +934,11 @@ def test_dpo_trainer_torch_dtype(self):
train_dataset=dummy_dataset["train"],
)
+ self.assertIn(
+ "Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
+ str(context.exception),
+ )
+
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
@@ -948,10 +948,7 @@ def test_dpo_trainer_torch_dtype(self):
report_to="none",
)
- with pytest.raises(
- ValueError,
- match="Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
- ):
+ with self.assertRaises(ValueError) as context:
_ = DPOTrainer(
model=self.model_id,
ref_model=self.model_id,
@@ -960,6 +957,11 @@ def test_dpo_trainer_torch_dtype(self):
train_dataset=dummy_dataset["train"],
)
+ self.assertIn(
+ "Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
+ str(context.exception),
+ )
+
def test_dpo_loss_alpha_div_f(self):
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -1001,7 +1003,7 @@ def test_dpo_loss_alpha_div_f(self):
losses, _, _ = trainer.dpo_loss(
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
- assert torch.isfinite(losses).cpu().numpy().all()
+ self.assertTrue(torch.isfinite(losses).cpu().numpy().all())
def test_dpo_loss_js_div_f(self):
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
@@ -1044,7 +1046,7 @@ def test_dpo_loss_js_div_f(self):
losses, _, _ = trainer.dpo_loss(
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
- assert torch.isfinite(losses).cpu().numpy().all()
+ self.assertTrue(torch.isfinite(losses).cpu().numpy().all())
class DPOVisionTrainerTester(unittest.TestCase):
@@ -1119,7 +1121,7 @@ def test_vdpo_trainer(self, model_id):
trainer.train()
- assert trainer.state.log_history[-1]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the trainable params have changed
for n, param in previous_trainable_params.items():
@@ -1132,7 +1134,7 @@ def test_vdpo_trainer(self, model_id):
# For some reason, these params are not updated. This is probably not related to TRL, but to
# the model itself. We should investigate this further, but for now we just skip these params.
continue
- assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
+ self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
if __name__ == "__main__":
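The `pytest.raises(..., match=...)` conversions in this file deserve a note: `match` is a regex searched against the message, while the replacement captures the exception with `assertRaises` and checks the literal text via `assertIn`, which avoids escaping the periods and other metacharacters in the message. A sketch of the pattern:

    import unittest

    class RaisesSketch(unittest.TestCase):
        def test_error_message(self):
            with self.assertRaises(ValueError) as context:
                raise ValueError("Invalid `torch_dtype` passed to the DPOConfig.")
            # Literal substring check; no regex escaping required.
            self.assertIn("Invalid `torch_dtype`", str(context.exception))

`assertRaisesRegex(ValueError, re.escape(message))` would be the one-line equivalent.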
diff --git a/tests/test_environments.py b/tests/test_environments.py
index b24aa83952..d2ee42f040 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -38,12 +38,12 @@ def test_text_history_init(self):
tokens = torch.tensor([1, 2, 3])
history = TextHistory(text, tokens)
- assert history.text == text
- assert torch.equal(history.tokens, tokens)
- assert torch.equal(history.token_masks, torch.zeros_like(tokens))
+ self.assertEqual(history.text, text)
+ self.assertTrue(torch.equal(history.tokens, tokens))
+ self.assertTrue(torch.equal(history.token_masks, torch.zeros_like(tokens)))
history = TextHistory(text, tokens, system=False)
- assert torch.equal(history.token_masks, torch.ones_like(tokens))
+ self.assertTrue(torch.equal(history.token_masks, torch.ones_like(tokens)))
def test_text_history_append_segment(self):
text = "Hello there!"
@@ -51,26 +51,26 @@ def test_text_history_append_segment(self):
history = TextHistory(text, tokens)
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False)
- assert history.text == (text + "General Kenobi!")
- assert torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6]))
- assert torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1]))
+ self.assertEqual(history.text, (text + "General Kenobi!"))
+ self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6])))
+ self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1])))
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]))
- assert history.text == ((text + "General Kenobi!") + "You are a bold one!")
- assert torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]))
- assert torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0]))
+ self.assertEqual(history.text, ((text + "General Kenobi!") + "You are a bold one!"))
+ self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])))
+ self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0])))
def test_text_history_complete(self):
text = "Hello there!"
tokens = torch.tensor([1, 2, 3])
history = TextHistory(text, tokens)
history.complete()
- assert history.completed
- assert not history.truncated
+ self.assertTrue(history.completed)
+ self.assertFalse(history.truncated)
history.complete(truncated=True)
- assert history.completed
- assert history.truncated
+ self.assertTrue(history.completed)
+ self.assertTrue(history.truncated)
def test_text_history_last_segment(self):
text = "Hello there!"
@@ -78,7 +78,7 @@ def test_text_history_last_segment(self):
history = TextHistory(text, tokens)
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]))
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]))
- assert history.last_text_segment == "You are a bold one!"
+ self.assertEqual(history.last_text_segment, "You are a bold one!")
def test_text_history_split_query_response(self):
text = "Hello there!"
@@ -88,9 +88,9 @@ def test_text_history_split_query_response(self):
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]), system=True)
query, response, mask = history.split_query_response_tokens()
- assert torch.equal(query, torch.tensor([1, 2, 3]))
- assert torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9]))
- assert torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0]))
+ self.assertTrue(torch.equal(query, torch.tensor([1, 2, 3])))
+ self.assertTrue(torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9])))
+ self.assertTrue(torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0])))
class TextEnvironmentTester(unittest.TestCase):
@@ -111,10 +111,10 @@ def test_text_environment_setup(self):
reward_fn=lambda x: torch.tensor(1),
prompt="I am a prompt!\n",
)
- assert env.prompt == "I am a prompt!\n"
- assert list(env.tools.keys()) == ["DummyTool"]
- assert isinstance(env.tools["DummyTool"], DummyTool)
- assert env.reward_fn("Hello there!") == 1
+ self.assertEqual(env.prompt, "I am a prompt!\n")
+ self.assertListEqual(list(env.tools.keys()), ["DummyTool"])
+ self.assertIsInstance(env.tools["DummyTool"], DummyTool)
+ self.assertEqual(env.reward_fn("Hello there!"), 1)
def test_text_environment_generate(self):
generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id}
@@ -137,7 +137,7 @@ def test_text_environment_generate(self):
generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs]
generations_single = self.gpt2_tokenizer.batch_decode(generations_single)
- assert generations_single == generations_batched
+ self.assertEqual(generations_single, generations_batched)
def test_text_environment_tool_call_parsing(self):
string_valid = "Something something Hello there!"
@@ -154,24 +154,24 @@ def test_text_environment_tool_call_parsing(self):
prompt="I am a prompt!\n",
)
tool, response = env.parse_tool_call(string_valid)
- assert tool == "Tool1"
- assert response == "Hello there!"
+ self.assertEqual(tool, "Tool1")
+ self.assertEqual(response, "Hello there!")
tool, response = env.parse_tool_call(string_invalid_request)
- assert tool is None
- assert response is None
+ self.assertIsNone(tool)
+ self.assertIsNone(response)
tool, response = env.parse_tool_call(string_invalid_call)
- assert tool is None
- assert response is None
+ self.assertIsNone(tool)
+ self.assertIsNone(response)
tool, response = env.parse_tool_call(string_invalid_tool)
- assert tool is None
- assert response is None
+ self.assertIsNone(tool)
+ self.assertIsNone(response)
tool, response = env.parse_tool_call(string_invalid_random)
- assert tool is None
- assert response is None
+ self.assertIsNone(tool)
+ self.assertIsNone(response)
def test_text_environment_tool_truncation(self):
env = TextEnvironment(
@@ -184,19 +184,19 @@ def test_text_environment_tool_truncation(self):
env.max_tool_response = 100
history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3])))
- assert (len(history.last_text_segment) - len(env.response_token)) == 100
+ self.assertEqual((len(history.last_text_segment) - len(env.response_token)), 100)
env.max_tool_response = 500
history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3])))
- assert (len(history.last_text_segment) - len(env.response_token)) == 500
+ self.assertEqual((len(history.last_text_segment) - len(env.response_token)), 500)
env.max_tool_response = 1001
history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3])))
- assert (len(history.last_text_segment) - len(env.response_token)) == 1000
+ self.assertEqual((len(history.last_text_segment) - len(env.response_token)), 1000)
env.max_tool_response = 2000
history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3])))
- assert (len(history.last_text_segment) - len(env.response_token)) == 1000
+ self.assertEqual((len(history.last_text_segment) - len(env.response_token)), 1000)
@patch.object(TextEnvironment, "generate", side_effect=dummy_generate)
def test_text_environment_max_calls(self, mock_generate):
@@ -210,20 +210,23 @@ def test_text_environment_max_calls(self, mock_generate):
env.max_turns = 1
_, _, _, _, histories = env.run(["test"])
- assert histories[0].text == (
- ("I am a prompt!\n" + "test") + (1 * "testtest")
+ self.assertEqual(
+ histories[0].text,
+ ("I am a prompt!\n" + "test") + (1 * "testtest"),
)
env.max_turns = 2
_, _, _, _, histories = env.run(["test"])
- assert histories[0].text == (
- ("I am a prompt!\n" + "test") + (2 * "testtest")
+ self.assertEqual(
+ histories[0].text,
+ ("I am a prompt!\n" + "test") + (2 * "testtest"),
)
env.max_turns = 4
_, _, _, _, histories = env.run(["test"])
- assert histories[0].text == (
- ("I am a prompt!\n" + "test") + (4 * "testtest")
+ self.assertEqual(
+ histories[0].text,
+ ("I am a prompt!\n" + "test") + (4 * "testtest"),
)
def test_text_environment_compute_rewards(self):
@@ -239,7 +242,7 @@ def test_text_environment_compute_rewards(self):
histories = env.compute_reward(histories)
for i in range(8):
- assert histories[i].reward == i
+ self.assertEqual(histories[i].reward, i)
@patch.object(TextEnvironment, "generate", side_effect=dummy_generate)
def test_text_environment_run(self, mock_generate):
@@ -255,20 +258,21 @@ def test_text_environment_run(self, mock_generate):
task_2 = "Hello there! General Kenobi!"
query, response, response_mask, reward, histories = env.run([task_1, task_2])
- assert len(query[0]) == 9
- assert len(query[1]) == 12
- assert len(response[0]) == 14
- assert len(response[1]) == 14
- assert response_mask[0].sum() == (2 * 3)
+ self.assertEqual(len(query[0]), 9)
+ self.assertEqual(len(query[1]), 12)
+ self.assertEqual(len(response[0]), 14)
+ self.assertEqual(len(response[1]), 14)
+ self.assertEqual(response_mask[0].sum(), (2 * 3))
# mocked generate always adds 3 tokens
- assert response_mask[1].sum() == (2 * 3)
+ self.assertEqual(response_mask[1].sum(), (2 * 3))
# mocked generate always adds 3 tokens
- assert reward[0] == 0
- assert reward[1] == 1
- assert histories[0].text == (
- ("I am a prompt!\n" + "Hello there!") + (2 * "testtest")
+ self.assertEqual(reward[0], 0)
+ self.assertEqual(reward[1], 1)
+ self.assertEqual(
+ histories[0].text,
+ ("I am a prompt!\n" + "Hello there!") + (2 * "testtest"),
)
- assert histories[1].text == (
+ self.assertEqual(
+ histories[1].text,
("I am a prompt!\n" + "Hello there! General Kenobi!")
- + (2 * "testtest")
+ + (2 * "testtest"),
)
diff --git a/tests/test_iterative_sft_trainer.py b/tests/test_iterative_sft_trainer.py
index da3978466c..6248a2b90a 100644
--- a/tests/test_iterative_sft_trainer.py
+++ b/tests/test_iterative_sft_trainer.py
@@ -113,4 +113,4 @@ def test_iterative_step_from_tensor(self, model_name, input_name):
iterative_trainer.step(**inputs)
for param in iterative_trainer.model.parameters():
- assert param.grad is not None
+ self.assertIsNotNone(param.grad)
diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py
index a1d4503540..5298e133af 100644
--- a/tests/test_modeling_value_head.py
+++ b/tests/test_modeling_value_head.py
@@ -16,7 +16,6 @@
import tempfile
import unittest
-import pytest
import torch
from parameterized import parameterized
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig
@@ -60,136 +59,152 @@
]
-class VHeadModelTester:
- all_model_names = None
- trl_model_class = None
- transformers_model_class = None
-
- def test_value_head(self):
- r"""
- Test if the v-head is added to the model successfully
- """
- for model_name in self.all_model_names:
- model = self.trl_model_class.from_pretrained(model_name)
- assert hasattr(model, "v_head")
-
- def test_value_head_shape(self):
- r"""
- Test if the v-head has the correct shape
- """
- for model_name in self.all_model_names:
- model = self.trl_model_class.from_pretrained(model_name)
- assert model.v_head.summary.weight.shape[0] == 1
-
- def test_value_head_init_random(self):
- r"""
- Test if the v-head has been randomly initialized.
- We can check that by making sure the bias is different
- than zeros by default.
- """
- for model_name in self.all_model_names:
- model = self.trl_model_class.from_pretrained(model_name)
- assert not torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias))
-
- def test_value_head_not_str(self):
- r"""
- Test if the v-head is added to the model successfully, by passing a non `PretrainedModel`
- as an argument to `from_pretrained`.
- """
- for model_name in self.all_model_names:
- pretrained_model = self.transformers_model_class.from_pretrained(model_name)
- model = self.trl_model_class.from_pretrained(pretrained_model)
- assert hasattr(model, "v_head")
-
- @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
- def test_from_save_trl(self):
- """
- Test if the model can be saved and loaded from a directory and get the same weights
- Including the additional modules (e.g. v_head)
- """
- for model_name in self.all_model_names:
- model = self.trl_model_class.from_pretrained(model_name)
-
- with tempfile.TemporaryDirectory() as tmp_dir:
- model.save_pretrained(tmp_dir)
-
- model_from_save = self.trl_model_class.from_pretrained(tmp_dir)
-
- # Check if the weights are the same
- for key in model_from_save.state_dict():
- assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])
-
- @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
- def test_from_save_trl_sharded(self):
- """
- Test if the model can be saved and loaded from a directory and get the same weights - sharded case
- """
- for model_name in self.all_model_names:
- model = self.trl_model_class.from_pretrained(model_name)
-
- with tempfile.TemporaryDirectory() as tmp_dir:
- model.save_pretrained(tmp_dir)
-
- model_from_save = self.trl_model_class.from_pretrained(tmp_dir)
-
- # Check if the weights are the same
- for key in model_from_save.state_dict():
- assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])
-
- @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
- def test_from_save_transformers_sharded(self):
- """
- Test if the model can be saved and loaded using transformers and get the same weights - sharded case
- """
- for model_name in self.all_model_names:
- transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name)
-
- trl_model = self.trl_model_class.from_pretrained(model_name)
-
- with tempfile.TemporaryDirectory() as tmp_dir:
- trl_model.save_pretrained(tmp_dir, max_shard_size="1MB")
- transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained(tmp_dir)
-
- # Check if the weights are the same
- for key in transformers_model.state_dict():
- assert torch.allclose(
- transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
+class BaseTester:
+ class VHeadModelTester(unittest.TestCase):
+ all_model_names = None
+ trl_model_class = None
+ transformers_model_class = None
+
+ def test_value_head(self):
+ r"""
+ Test if the v-head is added to the model successfully
+ """
+ for model_name in self.all_model_names:
+ model = self.trl_model_class.from_pretrained(model_name)
+ self.assertTrue(hasattr(model, "v_head"))
+
+ def test_value_head_shape(self):
+ r"""
+ Test if the v-head has the correct shape
+ """
+ for model_name in self.all_model_names:
+ model = self.trl_model_class.from_pretrained(model_name)
+ self.assertEqual(model.v_head.summary.weight.shape[0], 1)
+
+ def test_value_head_init_random(self):
+ r"""
+ Test if the v-head has been randomly initialized.
+ We can check that by making sure the bias is different
+ than zeros by default.
+ """
+ for model_name in self.all_model_names:
+ model = self.trl_model_class.from_pretrained(model_name)
+ self.assertFalse(
+ torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias))
)
- @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
- def test_from_save_transformers(self):
- """
- Test if the model can be saved and loaded using transformers and get the same weights.
- We override the test of the super class to check if the weights are the same.
- """
- for model_name in self.all_model_names:
- transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name)
-
- trl_model = self.trl_model_class.from_pretrained(model_name)
-
- with tempfile.TemporaryDirectory() as tmp_dir:
- trl_model.save_pretrained(tmp_dir)
- transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained(tmp_dir)
-
- # Check if the weights are the same
- for key in transformers_model.state_dict():
- assert torch.allclose(
- transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
+ def test_value_head_not_str(self):
+ r"""
+ Test if the v-head is added to the model successfully, by passing a non `PretrainedModel`
+ as an argument to `from_pretrained`.
+ """
+ for model_name in self.all_model_names:
+ pretrained_model = self.transformers_model_class.from_pretrained(model_name)
+ model = self.trl_model_class.from_pretrained(pretrained_model)
+ self.assertTrue(hasattr(model, "v_head"))
+
+ @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
+ def test_from_save_trl(self):
+ """
+ Test if the model can be saved and loaded from a directory and get the same weights
+ Including the additional modules (e.g. v_head)
+ """
+ for model_name in self.all_model_names:
+ model = self.trl_model_class.from_pretrained(model_name)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.save_pretrained(tmp_dir)
+
+ model_from_save = self.trl_model_class.from_pretrained(tmp_dir)
+
+ # Check if the weights are the same
+ for key in model_from_save.state_dict():
+ self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]))
+
+ @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
+ def test_from_save_trl_sharded(self):
+ """
+ Test if the model can be saved and loaded from a directory and get the same weights - sharded case
+ """
+ for model_name in self.all_model_names:
+ model = self.trl_model_class.from_pretrained(model_name)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.save_pretrained(tmp_dir)
+
+ model_from_save = self.trl_model_class.from_pretrained(tmp_dir)
+
+ # Check if the weights are the same
+ for key in model_from_save.state_dict():
+ self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]))
+
+ @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
+ def test_from_save_transformers_sharded(self):
+ """
+ Test if the model can be saved and loaded using transformers and get the same weights - sharded case
+ """
+ for model_name in self.all_model_names:
+ transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name)
+
+ trl_model = self.trl_model_class.from_pretrained(model_name)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ trl_model.save_pretrained(tmp_dir, max_shard_size="1MB")
+ transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained(
+ tmp_dir
+ )
+
+ # Check if the weights are the same
+ for key in transformers_model.state_dict():
+ self.assertTrue(
+ torch.allclose(
+ transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
+ )
+ )
+
+ @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
+ def test_from_save_transformers(self):
+ """
+ Test if the model can be saved and loaded using transformers and get the same weights.
+ We override the test of the super class to check if the weights are the same.
+ """
+ for model_name in self.all_model_names:
+ transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name)
+
+ trl_model = self.trl_model_class.from_pretrained(model_name)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ trl_model.save_pretrained(tmp_dir)
+ transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained(
+ tmp_dir
+ )
+
+ # Check if the weights are the same
+ for key in transformers_model.state_dict():
+ self.assertTrue(
+ torch.allclose(
+ transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
+ )
+ )
+
+ # Check if the trl model has the same keys as the transformers model
+ # except the v_head
+ for key in trl_model.state_dict():
+ if "v_head" not in key:
+ self.assertIn(key, transformers_model.state_dict())
+ # check if the weights are the same
+ self.assertTrue(
+ torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key])
+ )
+
+ # check if they have the same modules
+ self.assertEqual(
+ set(transformers_model_from_save.state_dict().keys()),
+ set(transformers_model.state_dict().keys()),
)
- # Check if the trl model has the same keys as the transformers model
- # except the v_head
- for key in trl_model.state_dict():
- if "v_head" not in key:
- assert key in transformers_model.state_dict()
- # check if the weights are the same
- assert torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key])
-
- # check if they have the same modules
- assert set(transformers_model_from_save.state_dict().keys()) == set(transformers_model.state_dict().keys())
-
-class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
+class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, unittest.TestCase):
"""
Testing suite for v-head models.
"""
@@ -216,7 +231,7 @@ def test_inference(self):
# Check if the outputs are of the right size - here
# we always output 3 values - logits, loss, and value states
- assert len(outputs) == EXPECTED_OUTPUT_SIZE
+ self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE)
def test_dropout_config(self):
r"""
@@ -229,7 +244,7 @@ def test_dropout_config(self):
model = self.trl_model_class.from_pretrained(pretrained_model)
# Check if v head of the model has the same dropout as the config
- assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob
+ self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob)
def test_dropout_kwargs(self):
r"""
@@ -242,12 +257,12 @@ def test_dropout_kwargs(self):
model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs)
# Check if v head of the model has the same dropout as the config
- assert model.v_head.dropout.p == 0.5
+ self.assertEqual(model.v_head.dropout.p, 0.5)
model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5)
# Check if v head of the model has the same dropout as the config
- assert model.v_head.dropout.p == 0.5
+ self.assertEqual(model.v_head.dropout.p, 0.5)
@parameterized.expand(ALL_CAUSAL_LM_MODELS)
def test_generate(self, model_name):
@@ -265,7 +280,7 @@ def test_raise_error_not_causallm(self):
# Test with a model without a LM head
model_id = "trl-internal-testing/tiny-random-GPT2Model"
# This should raise a ValueError
- with pytest.raises(ValueError):
+ with self.assertRaises(ValueError):
pretrained_model = AutoModelForCausalLM.from_pretrained(model_id)
_ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer)
@@ -281,11 +296,13 @@ def test_transformers_bf16_kwargs(self):
lm_head_namings = self.trl_model_class.lm_head_namings
- assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
+ self.assertTrue(
+ any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
+ )
for lm_head_naming in lm_head_namings:
if hasattr(trl_model.pretrained_model, lm_head_naming):
- assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16
+ self.assertEqual(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype, torch.bfloat16)
dummy_input = torch.LongTensor([[0, 1, 0, 1]])
@@ -303,15 +320,16 @@ def test_push_to_hub(self):
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(model_name + "-ppo")
# check all keys
- assert model.state_dict().keys() == model_from_pretrained.state_dict().keys()
+ self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys())
for name, param in model.state_dict().items():
- assert torch.allclose(
- param, model_from_pretrained.state_dict()[name]
- ), f"Parameter {name} is not the same after push_to_hub and from_pretrained"
+ self.assertTrue(
+ torch.allclose(param, model_from_pretrained.state_dict()[name]),
+ f"Parameter {name} is not the same after push_to_hub and from_pretrained",
+ )
-class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
+class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, unittest.TestCase):
"""
Testing suite for v-head models.
"""
@@ -339,7 +357,7 @@ def test_inference(self):
# Check if the outputs are of the right size - here
# we always output 3 values - logits, loss, and value states
- assert len(outputs) == EXPECTED_OUTPUT_SIZE
+ self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE)
def test_dropout_config(self):
r"""
@@ -352,7 +370,7 @@ def test_dropout_config(self):
model = self.trl_model_class.from_pretrained(pretrained_model)
# Check if v head of the model has the same dropout as the config
- assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob
+ self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob)
def test_dropout_kwargs(self):
r"""
@@ -365,12 +383,12 @@ def test_dropout_kwargs(self):
model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs)
# Check if v head of the model has the same dropout as the config
- assert model.v_head.dropout.p == 0.5
+ self.assertEqual(model.v_head.dropout.p, 0.5)
model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5)
# Check if v head of the model has the same dropout as the config
- assert model.v_head.dropout.p == 0.5
+ self.assertEqual(model.v_head.dropout.p, 0.5)
@parameterized.expand(ALL_SEQ2SEQ_MODELS)
def test_generate(self, model_name):
@@ -389,7 +407,7 @@ def test_raise_error_not_causallm(self):
# Test with a model without a LM head
model_id = "trl-internal-testing/tiny-random-T5Model"
# This should raise a ValueError
- with pytest.raises(ValueError):
+ with self.assertRaises(ValueError):
pretrained_model = AutoModel.from_pretrained(model_id)
_ = self.trl_model_class.from_pretrained(pretrained_model)
@@ -404,12 +422,13 @@ def test_push_to_hub(self):
model_from_pretrained = self.trl_model_class.from_pretrained(model_name + "-ppo")
# check all keys
- assert model.state_dict().keys() == model_from_pretrained.state_dict().keys()
+ self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys())
for name, param in model.state_dict().items():
- assert torch.allclose(
- param, model_from_pretrained.state_dict()[name]
- ), f"Parameter {name} is not the same after push_to_hub and from_pretrained"
+ self.assertTrue(
+ torch.allclose(param, model_from_pretrained.state_dict()[name]),
+ f"Parameter {name} is not the same after push_to_hub and from_pretrained",
+ )
def test_transformers_bf16_kwargs(self):
r"""
@@ -427,11 +446,13 @@ def test_transformers_bf16_kwargs(self):
# skip the test for FSMT as it does not support mixed-prec
continue
- assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
+ self.assertTrue(
+ any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
+ )
for lm_head_naming in lm_head_namings:
if hasattr(trl_model.pretrained_model, lm_head_naming):
- assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16
+ self.assertEqual(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype, torch.bfloat16)
dummy_input = torch.LongTensor([[0, 1, 0, 1]])
@@ -471,14 +492,16 @@ def test_independent_reference(self):
last_ref_layer_after = ref_model.get_parameter(layer_5).data.clone()
# before optimization ref and model are identical
- assert (first_layer_before == first_ref_layer_before).all()
- assert (last_layer_before == last_ref_layer_before).all()
+ self.assertTrue((first_layer_before == first_ref_layer_before).all())
+ self.assertTrue((last_layer_before == last_ref_layer_before).all())
+
# ref model stays identical after optimization
- assert (first_ref_layer_before == first_ref_layer_after).all()
- assert (last_ref_layer_before == last_ref_layer_after).all()
+ self.assertTrue((first_ref_layer_before == first_ref_layer_after).all())
+ self.assertTrue((last_ref_layer_before == last_ref_layer_after).all())
+
# optimized model changes
- assert not (first_layer_before == first_layer_after).all()
- assert not (last_layer_before == last_layer_after).all()
+ self.assertFalse((first_layer_before == first_layer_after).all())
+ self.assertFalse((last_layer_before == last_layer_after).all())
def test_shared_layers(self):
layer_0 = self.layer_format.format(layer=0)
@@ -503,12 +526,15 @@ def test_shared_layers(self):
second_ref_layer_after = ref_model.get_parameter(layer_1).data.clone()
# before optimization ref and model are identical
- assert (first_layer_before == first_ref_layer_before).all()
- assert (second_layer_before == second_ref_layer_before).all()
+ self.assertTrue((first_layer_before == first_ref_layer_before).all())
+ self.assertTrue((second_layer_before == second_ref_layer_before).all())
+
# ref model stays identical after optimization
- assert (first_ref_layer_before == first_ref_layer_after).all()
- assert (second_ref_layer_before == second_ref_layer_after).all()
+ self.assertTrue((first_ref_layer_before == first_ref_layer_after).all())
+ self.assertTrue((second_ref_layer_before == second_ref_layer_after).all())
+
# first layer of optimized model stays the same
- assert (first_layer_before == first_layer_after).all()
+ self.assertTrue((first_layer_before == first_layer_after).all())
+
# other layers in optimized model change
- assert not (second_layer_before == second_layer_after).all()
+ self.assertFalse((second_layer_before == second_layer_after).all())
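The `BaseTester` wrapper introduced above is the standard trick for abstract unittest suites: making `VHeadModelTester` inherit `unittest.TestCase` gives it the `self.assert*` helpers, while nesting it inside a plain holder class keeps the loader, which only collects `TestCase` subclasses at module scope, from running it with its `None` class attributes. A minimal sketch with hypothetical names:

    import unittest

    class BaseTester:
        # Not itself a TestCase, so test discovery skips the nested suite.
        class SharedTests(unittest.TestCase):
            model_names = None  # filled in by concrete subclasses

            def test_has_models(self):
                self.assertTrue(self.model_names)

    class ConcreteTests(BaseTester.SharedTests):  # collected and run
        model_names = ["tiny-model-a", "tiny-model-b"]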
diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py
index 962fbd2885..133e8ac6eb 100644
--- a/tests/test_orpo_trainer.py
+++ b/tests/test_orpo_trainer.py
@@ -79,14 +79,14 @@ def test_orpo_trainer(self, name, config_name):
trainer.train()
- assert trainer.state.log_history[-1]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
- assert not torch.equal(param, new_param)
+ self.assertFalse(torch.equal(param, new_param))
@require_peft
@parameterized.expand(
@@ -136,7 +136,7 @@ def test_orpo_trainer_with_lora(self, config_name):
trainer.train()
- assert trainer.state.log_history[-1]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
@@ -144,4 +144,4 @@ def test_orpo_trainer_with_lora(self, config_name):
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
- assert not torch.equal(param, new_param)
+ self.assertFalse(torch.equal(param, new_param))
diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py
index 04539dbcf4..d37058b839 100644
--- a/tests/test_peft_models.py
+++ b/tests/test_peft_models.py
@@ -61,7 +61,7 @@ def test_peft_requires_grad(self):
model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model)
# Check that the value head has requires_grad=True
- assert model.v_head.summary.weight.requires_grad
+ self.assertTrue(model.v_head.summary.weight.requires_grad)
def test_check_peft_model_nb_trainable_params(self):
r"""
@@ -74,12 +74,12 @@ def test_check_peft_model_nb_trainable_params(self):
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
- assert nb_trainable_params == 10273
+ self.assertEqual(nb_trainable_params, 10273)
# Check that the number of trainable param for the non-peft model is correct
non_peft_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id)
nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad)
- assert nb_trainable_params == 99578
+ self.assertEqual(nb_trainable_params, 99578)
def test_create_peft_model_from_config(self):
r"""
@@ -90,13 +90,13 @@ def test_create_peft_model_from_config(self):
)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
- assert nb_trainable_params == 10273
+ self.assertEqual(nb_trainable_params, 10273)
causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id)
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
- assert nb_trainable_params == 10273
+ self.assertEqual(nb_trainable_params, 10273)
@require_torch_gpu_if_bnb_not_multi_backend_enabled
def test_create_bnb_peft_model_from_config(self):
@@ -110,8 +110,8 @@ def test_create_bnb_peft_model_from_config(self):
)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
- assert nb_trainable_params == 10273
- assert trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt
+ self.assertEqual(nb_trainable_params, 10273)
+ self.assertEqual(trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__, Linear8bitLt)
causal_lm_model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id, load_in_8bit=True, device_map="auto"
@@ -119,8 +119,8 @@ def test_create_bnb_peft_model_from_config(self):
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
- assert nb_trainable_params == 10273
- assert trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt
+ self.assertEqual(nb_trainable_params, 10273)
+ self.assertEqual(trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__, Linear8bitLt)
def test_save_pretrained_peft(self):
r"""
@@ -135,23 +135,31 @@ def test_save_pretrained_peft(self):
model.save_pretrained(tmp_dir)
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
- assert os.path.isfile(
- f"{tmp_dir}/adapter_model.safetensors"
- ), f"{tmp_dir}/adapter_model.safetensors does not exist"
- assert os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist"
+ self.assertTrue(
+ os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"),
+ f"{tmp_dir}/adapter_model.safetensors does not exist",
+ )
+ self.assertTrue(
+ os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist"
+ )
+
# check also for `pytorch_model.bin` and make sure it only contains `v_head` weights
- assert os.path.exists(f"{tmp_dir}/pytorch_model.bin"), f"{tmp_dir}/pytorch_model.bin does not exist"
- maybe_v_head = torch.load(f"{tmp_dir}/pytorch_model.bin", weights_only=True)
+ self.assertTrue(
+ os.path.exists(f"{tmp_dir}/pytorch_model.bin"), f"{tmp_dir}/pytorch_model.bin does not exist"
+ )
+
# check that only keys that starts with `v_head` are in the dict
- assert all(
- k.startswith("v_head") for k in maybe_v_head.keys()
- ), f"keys in {tmp_dir}/pytorch_model.bin do not start with `v_head`"
+ maybe_v_head = torch.load(f"{tmp_dir}/pytorch_model.bin", weights_only=True)
+ self.assertTrue(
+ all(k.startswith("v_head") for k in maybe_v_head.keys()),
+ f"keys in {tmp_dir}/pytorch_model.bin do not start with `v_head`",
+ )
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir)
# check all the weights are the same
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
- assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"
+ self.assertTrue(torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}")
def test_load_pretrained_peft(self):
r"""
@@ -167,15 +175,18 @@ def test_load_pretrained_peft(self):
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir)
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
- assert os.path.isfile(
- f"{tmp_dir}/adapter_model.safetensors"
- ), f"{tmp_dir}/adapter_model.safetensors does not exist"
- assert os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist"
+ self.assertTrue(
+ os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"),
+ f"{tmp_dir}/adapter_model.safetensors does not exist",
+ )
+ self.assertTrue(
+ os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist"
+ )
# check all the weights are the same
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]:
- assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"
+ self.assertTrue(torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}")
def test_continue_training_peft_model(self):
r"""
@@ -190,4 +201,4 @@ def test_continue_training_peft_model(self):
model = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir, is_trainable=True)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
- assert nb_trainable_params == 10273
+ self.assertEqual(nb_trainable_params, 10273)
diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index 4e03dd6ff9..a56b037bef 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -39,7 +39,7 @@ def setUp(self):
def test_accuracy_metrics(self):
dummy_eval_predictions = EvalPrediction(torch.FloatTensor([[0.1, 0.9], [0.9, 0.1]]), torch.LongTensor([0, 0]))
accuracy = compute_accuracy(dummy_eval_predictions)
- assert accuracy["accuracy"] == 0.5
+ self.assertEqual(accuracy["accuracy"], 0.5)
def test_preprocessing_conversational(self):
with tempfile.TemporaryDirectory() as tmp_dir:
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index 7901bf20b4..e02a2cb418 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -51,38 +51,6 @@ def test():
)
-def test_rloo_reward():
- local_batch_size = 3
- rloo_k = 4
- # fmt: off
- rlhf_reward = torch.tensor([
- 1, 2, 3, # first rlhf reward for three prompts
- 2, 3, 4, # second rlhf reward for three prompts
- 5, 6, 7, # third rlhf reward for three prompts
- 8, 9, 10, # fourth rlhf reward for three prompts
- ]).float()
- # fmt: on
-
- baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
- advantages = torch.zeros_like(rlhf_reward)
- for i in range(0, len(advantages), local_batch_size):
- other_response_rlhf_rewards = []
- for j in range(0, len(advantages), local_batch_size):
- if i != j:
- other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size])
- advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(
- other_response_rlhf_rewards
- ).mean(0)
- assert (1 - (2 + 5 + 8) / 3 - advantages[0].item()) < 1e-6
- assert (6 - (3 + 2 + 9) / 3 - advantages[7].item()) < 1e-6
-
- # vectorized impl
- rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
- baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
- vec_advantages = rlhf_reward - baseline
- torch.testing.assert_close(vec_advantages.flatten(), advantages)
-
-
class RLOOTrainerTester(unittest.TestCase):
def setUp(self):
self.sft_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
@@ -120,3 +88,34 @@ def test_rloo_checkpoint(self):
)
trainer._save_checkpoint(trainer.model, trial=None)
+
+ def test_rloo_reward(self):
+ local_batch_size = 3
+ rloo_k = 4
+ # fmt: off
+ rlhf_reward = torch.tensor([
+ 1, 2, 3, # first rlhf reward for three prompts
+ 2, 3, 4, # second rlhf reward for three prompts
+ 5, 6, 7, # third rlhf reward for three prompts
+ 8, 9, 10, # fourth rlhf reward for three prompts
+ ]).float()
+ # fmt: on
+
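+ # each advantage is the reward minus the leave-one-out mean of the other rloo_k - 1 rewards for the same prompt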
+ advantages = torch.zeros_like(rlhf_reward)
+ for i in range(0, len(advantages), local_batch_size):
+ other_response_rlhf_rewards = []
+ for j in range(0, len(advantages), local_batch_size):
+ if i != j:
+ other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size])
+ advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(
+ other_response_rlhf_rewards
+ ).mean(0)
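+ # spot-check two entries by hand, e.g. advantages[0] = 1 - mean(2, 5, 8) = -4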
+ self.assertLess(abs(1 - (2 + 5 + 8) / 3 - advantages[0].item()), 1e-6)
+ self.assertLess(abs(6 - (3 + 2 + 9) / 3 - advantages[7].item()), 1e-6)
+
+ # vectorized impl
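+ # reshape to (rloo_k, local_batch_size): rows index the completions, columns the prompts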
+ rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
+ baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
+ vec_advantages = rlhf_reward - baseline
+ torch.testing.assert_close(vec_advantages.flatten(), advantages)
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 61490b7f5a..e615caf043 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -17,7 +17,6 @@
import unittest
import numpy as np
-import pytest
import torch
from datasets import Dataset, Image, Sequence, load_dataset
from transformers import (
@@ -166,18 +165,18 @@ def test_constant_length_dataset(self):
formatting_func=formatting_prompts_func,
)
- assert len(formatted_dataset) == len(self.dummy_dataset)
- assert len(formatted_dataset) > 0
+ self.assertEqual(len(formatted_dataset), len(self.dummy_dataset))
+ self.assertGreater(len(formatted_dataset), 0)
for example in formatted_dataset:
- assert "input_ids" in example
- assert "labels" in example
+ self.assertIn("input_ids", example)
+ self.assertIn("labels", example)
- assert len(example["input_ids"]) == formatted_dataset.seq_length
- assert len(example["labels"]) == formatted_dataset.seq_length
+ self.assertEqual(len(example["input_ids"]), formatted_dataset.seq_length)
+ self.assertEqual(len(example["labels"]), formatted_dataset.seq_length)
decoded_text = self.tokenizer.decode(example["input_ids"])
- assert ("Question" in decoded_text) and ("Answer" in decoded_text)
+ self.assertTrue(("Question" in decoded_text) and ("Answer" in decoded_text))
def test_sft_trainer_backward_compatibility(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -200,14 +199,14 @@ def test_sft_trainer_backward_compatibility(self):
formatting_func=formatting_prompts_func,
)
- assert trainer.args.hub_token == training_args.hub_token
+ self.assertEqual(trainer.args.hub_token, training_args.hub_token)
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
- assert trainer.state.log_history[0]["eval_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
def test_sft_trainer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -232,10 +231,10 @@ def test_sft_trainer(self):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
- assert trainer.state.log_history[0]["eval_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
def test_sft_trainer_uncorrect_data(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -341,7 +340,7 @@ def test_sft_trainer_uncorrect_data(self):
packing=True,
report_to="none",
)
- with pytest.raises(ValueError):
+ with self.assertRaises(ValueError):
_ = SFTTrainer(
model=self.model,
args=training_args,
@@ -350,7 +349,7 @@ def test_sft_trainer_uncorrect_data(self):
)
# This should not work as well
- with pytest.raises(ValueError):
+ with self.assertRaises(ValueError):
training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
@@ -409,10 +408,10 @@ def test_sft_trainer_with_model_num_train_epochs(self):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
- assert trainer.state.log_history[0]["eval_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
@@ -435,9 +434,9 @@ def test_sft_trainer_with_model_num_train_epochs(self):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
@@ -458,9 +457,9 @@ def test_sft_trainer_with_model_num_train_epochs(self):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-1"))
def test_sft_trainer_with_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -484,10 +483,10 @@ def test_sft_trainer_with_model(self):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
- assert trainer.state.log_history[0]["eval_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
@@ -509,9 +508,9 @@ def test_sft_trainer_with_model(self):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
# with formatting_func + packed
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -535,9 +534,9 @@ def test_sft_trainer_with_model(self):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
# with formatting_func + packed
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -559,9 +558,9 @@ def test_sft_trainer_with_model(self):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
@@ -581,9 +580,9 @@ def test_sft_trainer_with_model(self):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-1"))
def test_sft_trainer_with_multiple_eval_datasets(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -611,11 +610,11 @@ def test_sft_trainer_with_multiple_eval_datasets(self):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
- assert trainer.state.log_history[0]["eval_data1_loss"] is not None
- assert trainer.state.log_history[1]["eval_data2_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ self.assertIsNotNone(trainer.state.log_history[0]["eval_data1_loss"])
+ self.assertIsNotNone(trainer.state.log_history[1]["eval_data2_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-1"))
def test_data_collator_completion_lm(self):
response_template = "### Response:\n"
@@ -630,7 +629,7 @@ def test_data_collator_completion_lm(self):
labels = batch["labels"]
last_pad_idx = np.where(labels == -100)[1][-1]
result_text = self.tokenizer.decode(batch["input_ids"][0, last_pad_idx + 1 :])
- assert result_text == "I have not been masked correctly."
+ self.assertEqual(result_text, "I have not been masked correctly.")
def test_data_collator_completion_lm_with_multiple_text(self):
tokenizer = copy.deepcopy(self.tokenizer)
@@ -653,7 +652,7 @@ def test_data_collator_completion_lm_with_multiple_text(self):
labels = batch["labels"][i]
last_pad_idx = np.where(labels == -100)[0][-1]
result_text = tokenizer.decode(batch["input_ids"][i, last_pad_idx + 1 :])
- assert result_text == "I have not been masked correctly."
+ self.assertEqual(result_text, "I have not been masked correctly.")
def test_data_collator_chat_completion_lm(self):
instruction_template = "### Human:"
@@ -674,7 +673,7 @@ def test_data_collator_chat_completion_lm(self):
labels = batch["labels"]
non_masked_tokens = batch["input_ids"][labels != -100]
result_text = self.tokenizer.decode(non_masked_tokens)
- assert result_text == " I should not be masked. I should not be masked too."
+ self.assertEqual(result_text, " I should not be masked. I should not be masked too.")
def test_data_collator_chat_completion_lm_with_multiple_text(self):
tokenizer = copy.deepcopy(self.tokenizer)
@@ -702,11 +701,11 @@ def test_data_collator_chat_completion_lm_with_multiple_text(self):
non_masked_tokens1 = input_ids[0][labels[0] != -100]
result_text1 = tokenizer.decode(non_masked_tokens1)
- assert result_text1 == " I should not be masked."
+ self.assertEqual(result_text1, " I should not be masked.")
non_masked_tokens2 = input_ids[1][labels[1] != -100]
result_text2 = tokenizer.decode(non_masked_tokens2)
- assert result_text2 == " I should not be masked. I should not be masked too."
+ self.assertEqual(result_text2, " I should not be masked. I should not be masked too.")
def test_sft_trainer_infinite_with_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -729,15 +728,15 @@ def test_sft_trainer_infinite_with_model(self):
eval_dataset=self.eval_dataset,
)
- assert trainer.train_dataset.infinite
+ self.assertTrue(trainer.train_dataset.infinite)
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
- assert trainer.state.log_history[0]["eval_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
# make sure the trainer did 5 steps
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-5")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-5"))
def test_sft_trainer_infinite_with_model_epochs(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -758,14 +757,14 @@ def test_sft_trainer_infinite_with_model_epochs(self):
eval_dataset=self.eval_dataset,
)
- assert not trainer.train_dataset.infinite
+ self.assertFalse(trainer.train_dataset.infinite)
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# make sure the trainer did 5 steps
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-4")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-4"))
def test_sft_trainer_with_model_neftune(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -799,8 +798,8 @@ def test_sft_trainer_with_model_neftune(self):
torch.random.manual_seed(24)
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))
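+ # NEFTune adds noise to the embeddings during training, so two forward passes should differ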
- assert not torch.allclose(embeds_neftune, embeds_neftune_2)
- assert len(trainer.model.get_input_embeddings()._forward_hooks) > 0
+ self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
+ self.assertGreater(len(trainer.model.get_input_embeddings()._forward_hooks), 0)
trainer.neftune_hook_handle.remove()
@@ -808,7 +807,7 @@ def test_sft_trainer_with_model_neftune(self):
# Make sure forward pass works fine
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
- assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0
+ self.assertEqual(len(trainer.model.get_input_embeddings()._forward_hooks), 0)
@require_peft
def test_peft_sft_trainer_str(self):
@@ -866,16 +865,16 @@ def test_peft_sft_trainer(self):
peft_config=peft_config,
)
- assert isinstance(trainer.model, PeftModel)
+ self.assertIsInstance(trainer.model, PeftModel)
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
- assert trainer.state.log_history[0]["eval_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
- assert "adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
- assert "adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")
- assert "model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("adapter_model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
+ self.assertIn("adapter_config.json", os.listdir(tmp_dir + "/checkpoint-2"))
+ self.assertNotIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
@require_peft
def test_peft_sft_trainer_gc(self):
@@ -909,16 +908,16 @@ def test_peft_sft_trainer_gc(self):
peft_config=peft_config,
)
- assert isinstance(trainer.model, PeftModel)
+ self.assertIsInstance(trainer.model, PeftModel)
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
- assert trainer.state.log_history[0]["eval_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
- assert "adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
- assert "adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")
- assert "model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("adapter_model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
+ self.assertIn("adapter_config.json", os.listdir(tmp_dir + "/checkpoint-2"))
+ self.assertNotIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
@require_peft
def test_peft_sft_trainer_neftune(self):
@@ -954,7 +953,7 @@ def test_peft_sft_trainer_neftune(self):
trainer.model = trainer._activate_neftune(trainer.model)
- assert isinstance(trainer.model, PeftModel)
+ self.assertIsInstance(trainer.model, PeftModel)
device = trainer.model.get_input_embeddings().weight.device
trainer.model.train()
@@ -965,23 +964,23 @@ def test_peft_sft_trainer_neftune(self):
torch.random.manual_seed(24)
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))
- assert not torch.allclose(embeds_neftune, embeds_neftune_2)
- assert len(trainer.model.get_input_embeddings()._forward_hooks) > 0
+ self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
+ self.assertGreater(len(trainer.model.get_input_embeddings()._forward_hooks), 0)
trainer.neftune_hook_handle.remove()
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
- assert trainer.state.log_history[0]["eval_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
- assert "adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
- assert "adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")
- assert "model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("adapter_model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
+ self.assertIn("adapter_config.json", os.listdir(tmp_dir + "/checkpoint-2"))
+ self.assertNotIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
# Make sure forward pass works fine to check if embeddings forward is not broken.
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
- assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0
+ self.assertEqual(len(trainer.model.get_input_embeddings()._forward_hooks), 0)
@require_peft
def test_peft_sft_trainer_tag(self):
@@ -1068,8 +1067,8 @@ def test_sft_trainer_only_train_packing(self):
eval_dataset=self.conversational_lm_dataset["test"],
)
- assert len(trainer.train_dataset["input_ids"]) == 21 # with the used dataset, we end up with 21 sequences
- assert len(trainer.eval_dataset["input_ids"]) == len(self.conversational_lm_dataset["test"])
+ self.assertEqual(len(trainer.train_dataset["input_ids"]), 21)  # with the used dataset, we end up with 21 sequences
+ self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))
def test_sft_trainer_eval_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1093,8 +1092,8 @@ def test_sft_trainer_eval_packing(self):
eval_dataset=self.conversational_lm_dataset["test"],
)
- assert len(trainer.train_dataset["input_ids"]) == 21 # with the used dataset, we end up with 21 sequences
- assert len(trainer.eval_dataset["input_ids"]) == 2 # with the used dataset, we end up with 2 sequence
+ self.assertEqual(len(trainer.train_dataset["input_ids"]), 21)  # with the used dataset, we end up with 21 sequences
+ self.assertEqual(len(trainer.eval_dataset["input_ids"]), 2)  # with the used dataset, we end up with 2 sequences
def test_sft_trainer_no_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1118,8 +1117,8 @@ def test_sft_trainer_no_packing(self):
eval_dataset=self.conversational_lm_dataset["test"],
)
- assert len(trainer.train_dataset["input_ids"]) == len(self.conversational_lm_dataset["train"])
- assert len(trainer.eval_dataset["input_ids"]) == len(self.conversational_lm_dataset["test"])
+ self.assertEqual(len(trainer.train_dataset["input_ids"]), len(self.conversational_lm_dataset["train"]))
+ self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))
@require_vision
def test_sft_trainer_skip_prepare_dataset(self):
@@ -1144,8 +1143,8 @@ def test_sft_trainer_skip_prepare_dataset(self):
train_dataset=self.dummy_vsft_instruction_dataset,
eval_dataset=self.dummy_vsft_instruction_dataset,
)
- assert trainer.train_dataset.features == self.dummy_vsft_instruction_dataset.features
- assert trainer.eval_dataset.features == self.dummy_vsft_instruction_dataset.features
+ self.assertEqual(trainer.train_dataset.features, self.dummy_vsft_instruction_dataset.features)
+ self.assertEqual(trainer.eval_dataset.features, self.dummy_vsft_instruction_dataset.features)
def test_sft_trainer_skip_prepare_dataset_with_no_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1168,7 +1167,7 @@ def test_sft_trainer_skip_prepare_dataset_with_no_packing(self):
args=training_args,
train_dataset=self.dummy_dataset,
)
- assert trainer.train_dataset.features == self.dummy_dataset.features
+ self.assertEqual(trainer.train_dataset.features, self.dummy_dataset.features)
@require_vision
def test_sft_trainer_llava(self):
@@ -1218,10 +1217,10 @@ def collate_fn(examples):
trainer.train()
- assert trainer.state.log_history[(-1)]["train_loss"] is not None
- assert trainer.state.log_history[0]["eval_loss"] is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+ self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
- assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+ self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
def test_sft_trainer_torch_dtype(self):
# See https://github.com/huggingface/trl/issues/1751
@@ -1243,7 +1242,7 @@ def test_sft_trainer_torch_dtype(self):
eval_dataset=self.eval_dataset,
formatting_func=formatting_prompts_func,
)
- assert trainer.model.config.torch_dtype == torch.float16
+ self.assertEqual(trainer.model.config.torch_dtype, torch.float16)
# Now test when `torch_dtype` is provided but is wrong
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1257,13 +1256,15 @@ def test_sft_trainer_torch_dtype(self):
model_init_kwargs={"torch_dtype": -1},
report_to="none",
)
- with pytest.raises(
- ValueError,
- match="Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
- ):
+ with self.assertRaises(ValueError) as context:
_ = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
)
+
+ self.assertIn(
+ "Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
+ str(context.exception),
+ )
diff --git a/tests/test_utils.py b/tests/test_utils.py
index ba605ab086..1c97f50137 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -151,14 +151,14 @@ def test_full(self):
paper_id="1234.56789",
)
card_text = str(model_card)
- assert "[username/my_base_model](https://huggingface.co/username/my_base_model)" in card_text
- assert "my_model" in card_text
- assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text
- assert "datasets: username/my_dataset" in card_text
- assert "](https://wandb.ai/username/project_id/runs/abcd1234)" in card_text
- assert "My Trainer" in card_text
- assert "```bibtex\n@article{my_trainer, ...}\n```" in card_text
- assert "[My Paper](https://huggingface.co/papers/1234.56789)" in card_text
+ self.assertIn("[username/my_base_model](https://huggingface.co/username/my_base_model)", card_text)
+ self.assertIn("my_model", card_text)
+ self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text)
+ self.assertIn("datasets: username/my_dataset", card_text)
+ self.assertIn("](https://wandb.ai/username/project_id/runs/abcd1234)", card_text)
+ self.assertIn("My Trainer", card_text)
+ self.assertIn("```bibtex\n@article{my_trainer, ...}\n```", card_text)
+ self.assertIn("[My Paper](https://huggingface.co/papers/1234.56789)", card_text)
def test_val_none(self):
model_card = generate_model_card(
@@ -174,9 +174,9 @@ def test_val_none(self):
paper_id=None,
)
card_text = str(model_card)
- assert "my_model" in card_text
- assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text
- assert "My Trainer" in card_text
+ self.assertIn("my_model", card_text)
+ self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text)
+ self.assertIn("My Trainer", card_text)
class TestDataCollatorForChatML(unittest.TestCase):