diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 7c239e417b..8a81fe24aa 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -40,6 +40,6 @@ from .iterative_sft_trainer import IterativeSFTTrainer from .ppo_config import PPOConfig from .ppo_trainer import PPOTrainer +from .reward_config import RewardConfig from .reward_trainer import RewardTrainer, compute_accuracy from .sft_trainer import SFTTrainer -from .training_configs import RewardConfig diff --git a/trl/trainer/training_configs.py b/trl/trainer/reward_config.py similarity index 95% rename from trl/trainer/training_configs.py rename to trl/trainer/reward_config.py index 8341819c34..3d5fe4da32 100644 --- a/trl/trainer/training_configs.py +++ b/trl/trainer/reward_config.py @@ -1,6 +1,4 @@ -# coding=utf-8 -# coding=utf-8 -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass from typing import Optional diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index f2af50b634..4d32bdcb81 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -25,7 +25,7 @@ from transformers.trainer_utils import EvalPrediction from ..import_utils import is_peft_available -from .training_configs import RewardConfig +from .reward_config import RewardConfig from .utils import RewardDataCollatorWithPadding, compute_accuracy