Handle Trainer tokenizer kwarg deprecation with decorator (huggingface#33887)

* Handle deprecation with decorator

* Fix for seq2seq Trainer
qubvel authored Oct 2, 2024
1 parent ee71c98 commit 2f25ab9
Showing 2 changed files with 4 additions and 19 deletions.
19 changes: 2 additions & 17 deletions src/transformers/trainer.py
@@ -177,6 +177,7 @@
     logging,
     strtobool,
 )
+from .utils.deprecation import deprecate_kwarg
 from .utils.quantization_config import QuantizationMethod


@@ -326,11 +327,6 @@ class Trainer:
             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
             dataset prepending the dictionary key to the metric name.
-        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
-            The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
-            maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
-            interrupted training or reuse the fine-tuned model.
-            This is now deprecated.
         processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
             Processing class used to process the data. If provided, will be used to automatically process the inputs
             for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
@@ -385,14 +381,14 @@ class Trainer:
     # Those are used as methods of the Trainer in examples.
     from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state

+    @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True)
     def __init__(
         self,
         model: Union[PreTrainedModel, nn.Module] = None,
         args: TrainingArguments = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
-        tokenizer: Optional[PreTrainedTokenizerBase] = None,
         processing_class: Optional[
             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
         ] = None,
@@ -437,17 +433,6 @@ def __init__(
         # force device and distributed setup init explicitly
         args._setup_devices

-        if tokenizer is not None:
-            if processing_class is not None:
-                raise ValueError(
-                    "You cannot specify both `tokenizer` and `processing_class` at the same time. Please use `processing_class`."
-                )
-            warnings.warn(
-                "`tokenizer` is now deprecated and will be removed in v5, please use `processing_class` instead.",
-                FutureWarning,
-            )
-            processing_class = tokenizer
-
         if model is None:
             if model_init is not None:
                 self.model_init = model_init
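For context, the deleted block above is exactly the behavior the decorator now centralizes: warn on `tokenizer`, remap it to `processing_class`, and raise if both are given. The following is a minimal illustrative sketch of what a `deprecate_kwarg`-style decorator with that behavior can look like; the real helper lives in src/transformers/utils/deprecation.py and supports additional options, so the body below is an assumption, not the library's implementation.

import functools
import warnings


def deprecate_kwarg(old_name, new_name, version, raise_if_both_names=False):
    """Sketch: warn when `old_name` is passed and remap it to `new_name`."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                # Mirror of the removed inline check: both names is an error.
                if raise_if_both_names and new_name in kwargs:
                    raise ValueError(
                        f"You cannot specify both `{old_name}` and `{new_name}`. Please use `{new_name}`."
                    )
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in v{version}, "
                    f"please use `{new_name}` instead.",
                    FutureWarning,
                )
                # Remap the deprecated kwarg to its replacement.
                kwargs[new_name] = kwargs.pop(old_name)
            return func(*args, **kwargs)

        return wrapper

    return decorator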
4 changes: 2 additions & 2 deletions src/transformers/trainer_seq2seq.py
@@ -25,6 +25,7 @@
 from .integrations.deepspeed import is_deepspeed_zero3_enabled
 from .trainer import Trainer
 from .utils import logging
+from .utils.deprecation import deprecate_kwarg


 if TYPE_CHECKING:
@@ -43,14 +44,14 @@


 class Seq2SeqTrainer(Trainer):
+    @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True)
     def __init__(
         self,
         model: Union["PreTrainedModel", nn.Module] = None,
         args: "TrainingArguments" = None,
         data_collator: Optional["DataCollator"] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
-        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
         processing_class: Optional[
             Union["PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"]
         ] = None,
@@ -66,7 +67,6 @@ def __init__(
             data_collator=data_collator,
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
-            tokenizer=tokenizer,
             processing_class=processing_class,
             model_init=model_init,
             compute_metrics=compute_metrics,
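At the call site, the commit changes nothing for code already on `processing_class` and keeps `tokenizer` working until v5.0.0. The snippet below illustrates the three cases; `model`, `training_args`, and `tokenizer` are hypothetical objects assumed to be in scope.

# Preferred going forward:
trainer = Trainer(model=model, args=training_args, processing_class=tokenizer)

# Still accepted until v5.0.0, but emits a FutureWarning; the decorator
# remaps `tokenizer` to `processing_class` before __init__ runs:
trainer = Trainer(model=model, args=training_args, tokenizer=tokenizer)

# Raises ValueError, since the decorator is applied with raise_if_both_names=True:
trainer = Trainer(
    model=model, args=training_args, tokenizer=tokenizer, processing_class=tokenizer
)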
