🙅 Ensure dependency optionality (#2301)
* Add conditional check for LLMBlender availability in test_judges.py

* Fix import issues and update test requirements

* Remove unused imports

* Add require_peft decorator to test cases

* Fix import_utils module to use correct package name for llm_blender
qgallouedec authored Oct 31, 2024
1 parent 013a32b commit 73c3970
Showing 20 changed files with 78 additions and 107 deletions.
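
The changes all follow the same optional-dependency pattern: a lightweight availability check in trl.import_utils, plus a skipUnless-based decorator in tests/testing_utils.py that tests opt into. Below is a minimal sketch of that pattern, for illustration only — the importlib-based check is an assumption, not necessarily how trl.import_utils implements is_llm_blender_available.

# Sketch of the optional-dependency gating pattern this commit applies
# (illustrative, not the exact TRL implementation).
# Assumption: the pip package "llm-blender" exposes the importable module
# "llm_blender", which is the "correct package name" fix mentioned in the
# commit message.
import importlib.util
import unittest


def is_llm_blender_available() -> bool:
    # True only if the llm_blender module can be resolved in the current environment.
    return importlib.util.find_spec("llm_blender") is not None


def require_llm_blender(test_case):
    # Skip the decorated test when llm-blender is missing, so the suite still
    # passes in environments that install only the core dependencies.
    return unittest.skipUnless(is_llm_blender_available(), "test requires llm-blender")(test_case)

Tests then gate themselves with @require_llm_blender (or @require_sklearn), as the diffs below show, instead of failing at collection time when the optional package is absent.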
2 changes: 1 addition & 1 deletion tests/slow/test_sft_slow.py
@@ -23,6 +23,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.testing_utils import (
require_bitsandbytes,
require_liger_kernel,
require_peft,
require_torch_accelerator,
require_torch_multi_accelerator,
@@ -32,7 +33,6 @@
from trl import SFTConfig, SFTTrainer
from trl.models.utils import setup_chat_format

from ..testing_utils import require_liger_kernel
from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS


10 changes: 9 additions & 1 deletion tests/test_bco_trainer.py
@@ -25,7 +25,7 @@
from trl import BCOConfig, BCOTrainer
from trl.trainer.bco_trainer import _process_tokens, _tokenize

from .testing_utils import require_no_wandb
from .testing_utils import require_no_wandb, require_sklearn


class BCOTrainerTester(unittest.TestCase):
@@ -56,6 +56,7 @@ def setUp(self):
["gpt2", True, True, "conversational_unpaired_preference"],
]
)
@require_sklearn
def test_bco_trainer(self, name, pre_compute, eval_dataset, config_name):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = BCOConfig(
@@ -103,6 +104,7 @@ def test_bco_trainer(self, name, pre_compute, eval_dataset, config_name):
if param.sum() != 0:
self.assertFalse(torch.equal(param.cpu(), new_param.cpu()))

@require_sklearn
def test_bco_trainer_with_ref_model_is_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = BCOConfig(
@@ -123,6 +125,7 @@ def test_bco_trainer_with_ref_model_is_model(self):
train_dataset=dummy_dataset["train"],
)

@require_sklearn
def test_tokenize_and_process_tokens(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = BCOConfig(
@@ -185,6 +188,7 @@ def test_tokenize_and_process_tokens(self):
processed_dataset["completion_labels"][0], [-100, -100, -100, 318, 1365, 621, 8253, 13, 50256]
)

@require_sklearn
def test_bco_trainer_without_providing_ref_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = BCOConfig(
@@ -222,6 +226,7 @@ def test_bco_trainer_without_providing_ref_model(self):
if param.sum() != 0:
self.assertFalse(torch.equal(param.cpu(), new_param.cpu()))

@require_sklearn
def test_bco_trainer_udm(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = BCOConfig(
@@ -269,6 +274,7 @@ def embed_prompt(input_ids, attention_mask, model):
if param.sum() != 0:
self.assertFalse(torch.equal(param.cpu(), new_param.cpu()))

@require_sklearn
@require_peft
def test_bco_trainer_without_providing_ref_model_with_lora(self):
from peft import LoraConfig
@@ -319,6 +325,7 @@ def test_bco_trainer_without_providing_ref_model_with_lora(self):
if param.sum() != 0:
self.assertFalse(torch.equal(param.cpu(), new_param.cpu()))

@require_sklearn
@require_no_wandb
def test_bco_trainer_generate_during_eval_no_wandb(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -350,6 +357,7 @@ def test_bco_trainer_generate_during_eval_no_wandb(self):
eval_dataset=dummy_dataset["test"],
)

@require_sklearn
@require_peft
def test_bco_lora_save(self):
from peft import LoraConfig, get_peft_model
9 changes: 7 additions & 2 deletions tests/test_callbacks.py
@@ -18,13 +18,17 @@
import unittest

from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
from transformers.testing_utils import require_wandb
from transformers.testing_utils import require_peft, require_wandb
from transformers.utils import is_peft_available

from trl import BasePairwiseJudge, LogCompletionsCallback, WinRateCallback


if is_peft_available():
from peft import LoraConfig


class HalfPairwiseJudge(BasePairwiseJudge):
"""Naive pairwise judge that always returns [1, 0]"""

@@ -128,6 +132,7 @@ def test_without_ref_model(self):
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
self.assertListEqual(winrate_history, self.expected_winrates)

@require_peft
def test_lora(self):
with tempfile.TemporaryDirectory() as tmp_dir:
peft_config = LoraConfig(
2 changes: 1 addition & 1 deletion tests/test_cpo_trainer.py
@@ -93,7 +93,6 @@ def test_cpo_trainer(self, name, loss_type, config_name):
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@require_peft
@parameterized.expand(
[
("standard_preference",),
@@ -102,6 +101,7 @@ def test_cpo_trainer(self, name, loss_type, config_name):
("conversational_implicit_prompt_preference",),
]
)
@require_peft
def test_cpo_trainer_with_lora(self, config_name):
from peft import LoraConfig

1 change: 0 additions & 1 deletion tests/test_data_utils.py
@@ -93,7 +93,6 @@ class ApplyChatTemplateTester(unittest.TestCase):
"trl-internal-testing/tiny-random-gemma-2-9b-it",
"trl-internal-testing/tiny-random-Mistral-7B-Instruct-v0.1",
"trl-internal-testing/tiny-random-Mistral-7B-Instruct-v0.2",
"trl-internal-testing/tiny-random-Mistral-7B-Instruct-v0.3",
]

conversational_examples = [
8 changes: 7 additions & 1 deletion tests/test_dpo_trainer.py
@@ -20,26 +20,31 @@
import torch
from datasets import Dataset, features, load_dataset
from parameterized import parameterized
from PIL import Image
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizerBase,
is_vision_available,
)
from transformers.testing_utils import (
require_bitsandbytes,
require_peft,
require_torch_gpu_if_bnb_not_multi_backend_enabled,
require_vision,
)

from trl import DPOConfig, DPOTrainer, FDivergenceType

from .testing_utils import require_no_wandb


if is_vision_available():
from PIL import Image


class TestTokenizeRow(unittest.TestCase):
def setUp(self):
# Set up the mock tokenizer with specific behaviors
@@ -1049,6 +1054,7 @@ def test_dpo_loss_js_div_f(self):
self.assertTrue(torch.isfinite(losses).cpu().numpy().all())


@require_vision
class DPOVisionTrainerTester(unittest.TestCase):
@parameterized.expand(
[
11 changes: 7 additions & 4 deletions tests/test_judges.py
@@ -14,15 +14,18 @@

import unittest

from trl import HfPairwiseJudge, PairRMJudge, RandomPairwiseJudge, RandomRankJudge, is_llmblender_available
from trl import HfPairwiseJudge, PairRMJudge, RandomPairwiseJudge, RandomRankJudge, is_llm_blender_available

from .testing_utils import require_llm_blender


class TestJudges(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Initialize once to download the model. This ensures it’s downloaded before running tests, preventing issues
# where concurrent tests attempt to load the model while it’s still downloading.
PairRMJudge()
if is_llm_blender_available():
PairRMJudge()

def _get_prompts_and_completions(self):
prompts = ["The capital of France is", "The biggest planet in the solar system is"]
@@ -53,7 +56,7 @@ def test_hugging_face_judge(self):
self.assertTrue(all(isinstance(rank, int) for rank in ranks))
self.assertEqual(ranks, [0, 1])

@unittest.skipIf(not is_llmblender_available(), "llm-blender is not available")
@require_llm_blender
def test_pair_rm_judge(self):
judge = PairRMJudge()
prompts, completions = self._get_prompts_and_completions()
@@ -62,7 +65,7 @@ def test_pair_rm_judge(self):
self.assertTrue(all(isinstance(rank, int) for rank in ranks))
self.assertEqual(ranks, [0, 1])

@unittest.skipIf(not is_llmblender_available(), "llm-blender is not available")
@require_llm_blender
def test_pair_rm_judge_return_scores(self):
judge = PairRMJudge()
prompts, completions = self._get_prompts_and_completions()
7 changes: 5 additions & 2 deletions tests/test_nash_md_trainer.py
@@ -20,7 +20,9 @@
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import NashMDConfig, NashMDTrainer, PairRMJudge, is_llmblender_available
from trl import NashMDConfig, NashMDTrainer, PairRMJudge

from .testing_utils import require_llm_blender


if is_peft_available():
@@ -125,6 +127,7 @@ def test_training_with_peft_and_ref_model(self):
# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

@require_peft
def test_training_with_peft_model_and_peft_config(self):
model_lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
model = get_peft_model(self.model, model_lora_config)
@@ -156,8 +159,8 @@ def test_training_with_peft_model_and_peft_config(self):
# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

@unittest.skipIf(not is_llmblender_available(), "llm-blender is not available")
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
@require_llm_blender
def test_nash_md_trainer_judge_training(self, config_name):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = NashMDConfig(
4 changes: 2 additions & 2 deletions tests/test_online_dpo_trainer.py
@@ -20,7 +20,7 @@
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import OnlineDPOConfig, OnlineDPOTrainer, RandomPairwiseJudge, is_llmblender_available
from trl import OnlineDPOConfig, OnlineDPOTrainer, RandomPairwiseJudge, is_llm_blender_available
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


@@ -210,7 +210,7 @@ def test_training_with_peft_model_and_peft_config(self):
# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

@unittest.skipIf(not is_llmblender_available(), "llm-blender is not available")
@unittest.skipIf(not is_llm_blender_available(), "llm-blender is not available")
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
def test_training_with_judge(self, config_name):
with tempfile.TemporaryDirectory() as tmp_dir:
2 changes: 1 addition & 1 deletion tests/test_orpo_trainer.py
@@ -88,7 +88,6 @@ def test_orpo_trainer(self, name, config_name):
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@require_peft
@parameterized.expand(
[
("standard_preference",),
@@ -97,6 +96,7 @@ def test_orpo_trainer(self, name, config_name):
("conversational_implicit_prompt_preference",),
]
)
@require_peft
def test_orpo_trainer_with_lora(self, config_name):
from peft import LoraConfig

3 changes: 3 additions & 0 deletions tests/test_trainers_args.py
@@ -42,8 +42,11 @@
XPOTrainer,
)

from .testing_utils import require_sklearn


class TrainerArgTester(unittest.TestCase):
@require_sklearn
def test_bco(self):
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train")
5 changes: 3 additions & 2 deletions tests/test_xpo_trainer.py
@@ -20,7 +20,7 @@
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import RandomPairwiseJudge, XPOConfig, XPOTrainer, is_llmblender_available
from trl import RandomPairwiseJudge, XPOConfig, XPOTrainer, is_llm_blender_available


if is_peft_available():
@@ -125,6 +125,7 @@ def test_training_with_peft_and_ref_model(self):
# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

@require_peft
def test_training_with_peft_model_and_peft_config(self):
model_lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
model = get_peft_model(self.model, model_lora_config)
@@ -156,7 +157,7 @@ def test_training_with_peft_model_and_peft_config(self):
# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

@unittest.skipIf(not is_llmblender_available(), "llm-blender is not available")
@unittest.skipIf(not is_llm_blender_available(), "llm-blender is not available")
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
def test_xpo_trainer_judge_training(self, config_name):
with tempfile.TemporaryDirectory() as tmp_dir:
17 changes: 12 additions & 5 deletions tests/testing_utils.py
@@ -13,9 +13,9 @@
# limitations under the License.
import unittest

from transformers import is_wandb_available
from transformers import is_sklearn_available, is_wandb_available

from trl import is_diffusers_available, is_liger_kernel_available
from trl import is_diffusers_available, is_llm_blender_available


def require_diffusers(test_case):
@@ -32,8 +32,15 @@ def require_no_wandb(test_case):
return unittest.skipUnless(not is_wandb_available(), "test requires no wandb")(test_case)


def require_liger_kernel(test_case):
def require_sklearn(test_case):
"""
Decorator marking a test that requires liger_kernel. Skips the test if liger_kernel is not available.
Decorator marking a test that requires sklearn. Skips the test if sklearn is not available.
"""
return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case)
return unittest.skipUnless(is_sklearn_available(), "test requires sklearn")(test_case)


def require_llm_blender(test_case):
"""
Decorator marking a test that requires llm-blender. Skips the test if llm-blender is not available.
"""
return unittest.skipUnless(is_llm_blender_available(), "test requires llm-blender")(test_case)
10 changes: 2 additions & 8 deletions trl/__init__.py
@@ -36,8 +36,7 @@
"import_utils": [
"is_deepspeed_available",
"is_diffusers_available",
"is_liger_kernel_available",
"is_llmblender_available",
"is_llm_blender_available",
],
"models": [
"SUPPORTED_ARCHITECTURES",
@@ -129,12 +128,7 @@
)
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import (
is_deepspeed_available,
is_diffusers_available,
is_liger_kernel_available,
is_llmblender_available,
)
from .import_utils import is_deepspeed_available, is_diffusers_available, is_llm_blender_available
from .models import (
SUPPORTED_ARCHITECTURES,
AutoModelForCausalLMWithValueHead,
(diffs for the remaining 6 changed files not shown)
