From 015321e1350787779c4bcbf3b0ad63b53c64ee07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 11 Nov 2024 19:06:20 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=88=20Add=20`tokenizer`=20arg=20back?= =?UTF-8?q?=20and=20add=20deprecation=20guidelines=20(#2348)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add deprecation and backward compatibility guidelines * Update tokenizer argument in trainer classes * Add warning message for TRL Judges API --- CONTRIBUTING.md | 27 +++++++++++++++++++++++++++ docs/source/judges.mdx | 6 ++++++ trl/trainer/bco_trainer.py | 2 ++ trl/trainer/cpo_trainer.py | 2 ++ trl/trainer/dpo_trainer.py | 2 +- trl/trainer/gkd_trainer.py | 2 ++ trl/trainer/iterative_sft_trainer.py | 2 ++ trl/trainer/kto_trainer.py | 2 ++ trl/trainer/nash_md_trainer.py | 2 ++ trl/trainer/online_dpo_trainer.py | 2 ++ trl/trainer/orpo_trainer.py | 2 ++ trl/trainer/ppo_trainer.py | 2 ++ trl/trainer/reward_trainer.py | 2 ++ trl/trainer/rloo_trainer.py | 2 ++ trl/trainer/sft_trainer.py | 2 +- trl/trainer/xpo_trainer.py | 2 ++ 16 files changed, 59 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a831cb3538..3e27528c14 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -256,3 +256,30 @@ That's how `make test` is implemented (without the `pip install` line)! You can specify a smaller set of tests to test only the feature you're working on. + +### Deprecation and Backward Compatibility + +Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs. + +When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include: + +- **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement. +- **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition. + +Example: + + ```python + warnings.warn( + "The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. " + "Please use the `Trainer.bar` class instead.", + FutureWarning, + ) + ``` + +The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes: + +- **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next. + +- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning. + +These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs. diff --git a/docs/source/judges.mdx b/docs/source/judges.mdx index 3e1cda6ba8..329cc6c917 100644 --- a/docs/source/judges.mdx +++ b/docs/source/judges.mdx @@ -1,5 +1,11 @@ # Judges + + +TRL Judges is an experimental API which is subject to change at any time. + + + TRL provides judges to easily compare two completions. Make sure to have installed the required dependencies by running: diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index baf43a42bb..967373713d 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -48,6 +48,7 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available +from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template from ..models import PreTrainedModelWrapper, create_reference_model @@ -317,6 +318,7 @@ class BCOTrainer(Trainer): _tag_names = ["trl", "bco"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True) def __init__( self, model: Union[PreTrainedModel, nn.Module, str] = None, diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 5e74fdaceb..e5fcef01db 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -44,6 +44,7 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available, is_torch_fx_proxy +from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from .cpo_config import CPOConfig @@ -103,6 +104,7 @@ class CPOTrainer(Trainer): _tag_names = ["trl", "cpo"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index a3d7b1e974..21382db312 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -186,7 +186,7 @@ class DPOTrainer(Trainer): _tag_names = ["trl", "dpo"] - @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True) + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.16.0", raise_if_both_names=True) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index f44335d197..199a11bac6 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -37,6 +37,7 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalPrediction from transformers.utils import is_liger_kernel_available, is_peft_available +from transformers.utils.deprecation import deprecate_kwarg from ..models import PreTrainedModelWrapper from ..models.utils import unwrap_model_for_generation @@ -61,6 +62,7 @@ class GKDTrainer(SFTTrainer): _tag_names = ["trl", "gkd"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py index 6e81a3586f..582872fb47 100644 --- a/trl/trainer/iterative_sft_trainer.py +++ b/trl/trainer/iterative_sft_trainer.py @@ -33,6 +33,7 @@ ) from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available +from transformers.utils.deprecation import deprecate_kwarg from ..core import PPODecorators from .utils import generate_model_card @@ -80,6 +81,7 @@ class IterativeSFTTrainer(Trainer): _tag_names = ["trl", "iterative-sft"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True) def __init__( self, model: Optional[PreTrainedModel] = None, diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index e6991182a7..7be86d924e 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -47,6 +47,7 @@ ) from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available +from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset from ..models import PreTrainedModelWrapper, create_reference_model @@ -312,6 +313,7 @@ class KTOTrainer(Trainer): _tag_names = ["trl", "kto"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True) def __init__( self, model: Union[PreTrainedModel, nn.Module, str] = None, diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index c998174765..b41a426e08 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -33,6 +33,7 @@ from transformers.trainer_utils import EvalPrediction from transformers.training_args import OptimizerNames from transformers.utils import is_apex_available +from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import is_conversational, maybe_apply_chat_template from ..models.modeling_base import GeometricMixtureWrapper @@ -93,6 +94,7 @@ class NashMDTrainer(OnlineDPOTrainer): _tag_names = ["trl", "nash-md"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True) def __init__( self, model: Union[PreTrainedModel, nn.Module] = None, diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 790e546387..eb6a07db7e 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -44,6 +44,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker from transformers.training_args import OptimizerNames from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging +from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template from ..models import create_reference_model @@ -125,6 +126,7 @@ class OnlineDPOTrainer(Trainer): _tag_names = ["trl", "online-dpo"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True) def __init__( self, model: Union[PreTrainedModel, nn.Module], diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index ccbe254019..e4d3265dd5 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -48,6 +48,7 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available, is_torch_fx_proxy +from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from ..models import PreTrainedModelWrapper @@ -114,6 +115,7 @@ class ORPOTrainer(Trainer): _tag_names = ["trl", "orpo"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 872f83c135..8e2c514a54 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -44,6 +44,7 @@ from transformers.integrations import get_reporting_integration_callbacks from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback +from transformers.utils.deprecation import deprecate_kwarg from ..core import masked_mean, masked_whiten from ..models.utils import unwrap_model_for_generation @@ -90,6 +91,7 @@ def forward(self, **kwargs): class PPOTrainer(Trainer): _tag_names = ["trl", "ppo"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True) def __init__( self, config: PPOConfig, diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 0ebdee68b4..f3456cab41 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -39,6 +39,7 @@ from transformers.trainer_pt_utils import nested_detach from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available +from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template from .reward_config import RewardConfig @@ -80,6 +81,7 @@ def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizerBase") class RewardTrainer(Trainer): _tag_names = ["trl", "reward-trainer"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module]] = None, diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index e33899f5d9..2bce1add0f 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -44,6 +44,7 @@ from transformers.integrations import get_reporting_integration_callbacks from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback +from transformers.utils.deprecation import deprecate_kwarg from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( @@ -71,6 +72,7 @@ class RLOOTrainer(Trainer): _tag_names = ["trl", "rloo"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True) def __init__( self, config: RLOOConfig, diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 59f445bcc7..2dce57b12b 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -106,7 +106,7 @@ class SFTTrainer(Trainer): _tag_names = ["trl", "sft"] - @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True) + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.16.0", raise_if_both_names=True) def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index 4ec501c7f0..431f3d79f1 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -33,6 +33,7 @@ ) from transformers.trainer_utils import EvalPrediction from transformers.training_args import OptimizerNames +from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import is_conversational, maybe_apply_chat_template from ..models.utils import unwrap_model_for_generation @@ -92,6 +93,7 @@ class XPOTrainer(OnlineDPOTrainer): _tag_names = ["trl", "xpo"] + @deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True) def __init__( self, model: Union[PreTrainedModel, nn.Module] = None,