From 3c7724d67f30710fd4f41a83c20dc5ceb3ba6d42 Mon Sep 17 00:00:00 2001
From: Tibor Reiss
Date: Fri, 29 Nov 2024 11:27:11 +0100
Subject: [PATCH] Pass in text_kwargs

---
 .../models/oneformer/processing_oneformer.py | 25 ++++++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/oneformer/processing_oneformer.py b/src/transformers/models/oneformer/processing_oneformer.py
index d30091c69242b0..96badf910e8e40 100644
--- a/src/transformers/models/oneformer/processing_oneformer.py
+++ b/src/transformers/models/oneformer/processing_oneformer.py
@@ -50,7 +50,12 @@ class OneFormerImagesKwargs(ImagesKwargs):
 class OneFormerProcessorKwargs(ProcessingKwargs, total=False):
     images_kwargs: OneFormerImagesKwargs
     _defaults = {
-        "text_kwargs": {"max_seq_length": 77, "task_seq_length": 77},
+        "text_kwargs": {
+            "max_seq_length": 77,
+            "task_seq_length": 77,
+            "padding": "max_length",
+            "truncation": True,
+        },
     }
 
 
@@ -88,8 +93,20 @@ def __init__(
 
         super().__init__(image_processor, tokenizer)
 
-    def _preprocess_text(self, text_list: PreTokenizedInput, max_length: int = 77):
-        tokens = self.tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True)
+    def _preprocess_text(
+        self,
+        text_list: PreTokenizedInput,
+        max_length: Optional[int] = None,
+        text_kwargs: Optional[dict] = None,
+    ):
+        if text_kwargs is None:
+            text_kwargs = {}
+        tokens = self.tokenizer(
+            text_list,
+            max_length=max_length if max_length is not None else text_kwargs.get("max_length", 77),
+            padding=text_kwargs.get("padding", "max_length"),
+            truncation=text_kwargs.get("truncation", True),
+        )
 
         attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"]
 
@@ -184,6 +201,7 @@ def __call__(
             encoded_inputs["task_inputs"] = self._preprocess_text(
                 task_token_inputs,
                 max_length=output_kwargs["text_kwargs"]["task_seq_length"],
+                text_kwargs=output_kwargs["text_kwargs"],
             )
 
         if hasattr(encoded_inputs, "text_inputs"):
@@ -191,6 +209,7 @@ def __call__(
                     self._preprocess_text(
                         texts,
                         max_length=output_kwargs["text_kwargs"]["max_seq_length"],
+                        text_kwargs=output_kwargs["text_kwargs"],
                     ).unsqueeze(0)
                     for texts in encoded_inputs.text_inputs
                 ]
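
Usage note (not part of the patch): a minimal sketch of how the forwarded
text kwargs could be exercised, assuming the patched processor is installed
and using the public "shi-labs/oneformer_ade20k_swin_tiny" checkpoint; the
blank dummy image and the chosen override values are purely illustrative.

    import numpy as np
    from PIL import Image
    from transformers import OneFormerProcessor

    processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")

    # A blank RGB image stands in for real input.
    image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))

    # "padding" and "truncation" now live in the text_kwargs defaults instead
    # of being hard-coded inside _preprocess_text, so callers should be able
    # to override them at call time and have them reach the tokenizer.
    inputs = processor(
        images=image,
        task_inputs=["semantic"],
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )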