♾️ Fix test generation max_new_tokens (#2272)
* `eval_strategy="steps" if eval_dataset else "no"`

* tmp skip test

* drop `eval_strategy` in `test_sft_trainer_uncorrect_data`

* remove eval strategy

* Add parameterized test for generate method

* Revert "`eval_strategy="steps" if eval_dataset else "no"`"

This reverts commit 1e8b331.

* Revert "tmp skip test"

This reverts commit 44558f8.

* Revert "drop `eval_strategy` in `test_sft_trainer_uncorrect_data`"

This reverts commit a1ef701.

* Revert "remove eval strategy"

This reverts commit cb7fafa.

* style

* Refactor test_generate method in test_modeling_value_head.py

* `max_new_tokens=9`
qgallouedec authored Oct 24, 2024
1 parent c2bb1ee commit e615974
Showing 1 changed file with 17 additions and 14 deletions.
tests/test_modeling_value_head.py (17 additions, 14 deletions)
```diff
@@ -18,7 +18,8 @@

 import pytest
 import torch
-from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
+from parameterized import parameterized
+from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig

 from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, create_reference_model

@@ -248,16 +249,17 @@ def test_dropout_kwargs(self):
         # Check if v head of the model has the same dropout as the config
         assert model.v_head.dropout.p == 0.5

-    def test_generate(self):
+    @parameterized.expand(ALL_CAUSAL_LM_MODELS)
+    def test_generate(self, model_name):
         r"""
         Test if `generate` works for every model
         """
-        for model_name in self.all_model_names:
-            model = self.trl_model_class.from_pretrained(model_name)
-            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+        generation_config = GenerationConfig(max_new_tokens=9)
+        model = self.trl_model_class.from_pretrained(model_name)
+        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

-            # Just check if the generation works
-            _ = model.generate(input_ids)
+        # Just check if the generation works
+        _ = model.generate(input_ids, generation_config=generation_config)

     def test_raise_error_not_causallm(self):
         # Test with a model without a LM head
@@ -370,17 +372,18 @@ def test_dropout_kwargs(self):
         # Check if v head of the model has the same dropout as the config
         assert model.v_head.dropout.p == 0.5

-    def test_generate(self):
+    @parameterized.expand(ALL_SEQ2SEQ_MODELS)
+    def test_generate(self, model_name):
         r"""
         Test if `generate` works for every model
         """
-        for model_name in self.all_model_names:
-            model = self.trl_model_class.from_pretrained(model_name)
-            input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
-            decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+        generation_config = GenerationConfig(max_new_tokens=9)
+        model = self.trl_model_class.from_pretrained(model_name)
+        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+        decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

-            # Just check if the generation works
-            _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids)
+        # Just check if the generation works
+        _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids, generation_config=generation_config)
```
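The diff relies on two mechanisms: `parameterized.expand`, which replaces the shared `for model_name in self.all_model_names` loop with one independent test case per model, and `GenerationConfig(max_new_tokens=9)`, which caps how many tokens `generate` may produce. A minimal, self-contained sketch of both, using a hypothetical `TINY_MODELS` list and a plain `unittest.TestCase` in place of the real TRL value-head test classes:

```python
import unittest

import torch
from parameterized import parameterized
from transformers import AutoModelForCausalLM, GenerationConfig

# Hypothetical model list for illustration; the real test module defines
# ALL_CAUSAL_LM_MODELS / ALL_SEQ2SEQ_MODELS.
TINY_MODELS = [
    "hf-internal-testing/tiny-random-gpt2",  # assumed tiny checkpoint for fast tests
]


class GenerateSmokeTest(unittest.TestCase):
    # Each entry in TINY_MODELS becomes its own test case (test_generate_0,
    # test_generate_1, ...), so one broken model no longer hides the others.
    @parameterized.expand(TINY_MODELS)
    def test_generate(self, model_name):
        model = AutoModelForCausalLM.from_pretrained(model_name)
        input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

        # max_new_tokens=9 bounds the generation length; without it,
        # generate() falls back to the model's default length settings.
        generation_config = GenerationConfig(max_new_tokens=9)
        output = model.generate(input_ids, generation_config=generation_config)

        # At most 9 tokens are appended to the prompt (fewer if EOS is hit).
        self.assertLessEqual(output.shape[-1], input_ids.shape[-1] + 9)


if __name__ == "__main__":
    unittest.main()
```

Splitting the loop this way reports failures per model instead of aborting the remaining ones, and the token cap keeps the smoke test fast regardless of each model's default generation settings.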
