Skip to content

Commit

Permalink
Pass in text_kwargs
Browse files (browse the repository at this point in the history)
  • Loading branch information
tibor-reiss committed Nov 29, 2024
1 parent cbddaf4 commit 3c7724d
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions src/transformers/models/oneformer/processing_oneformer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,7 +50,12 @@ class OneFormerImagesKwargs(ImagesKwargs):
class OneFormerProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: OneFormerImagesKwargs
_defaults = {
- "text_kwargs": {"max_seq_length": 77, "task_seq_length": 77},
+ "text_kwargs": {
+ "max_seq_length": 77,
+ "task_seq_length": 77,
+ "padding": "max_length",
+ "truncation": True,
+ },
}


Expand Down Expand Up @@ -88,8 +93,20 @@ def __init__(

super().__init__(image_processor, tokenizer)

- def _preprocess_text(self, text_list: PreTokenizedInput, max_length: int = 77):
- tokens = self.tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True)
+ def _preprocess_text(
+ self,
+ text_list: PreTokenizedInput,
+ max_length: Optional[int] = None,
+ text_kwargs: Optional[dict] = None,
+ ):
+ if text_kwargs is None:
+ text_kwargs = {}
+ tokens = self.tokenizer(
+ text_list,
+ max_length=max_length if max_length is not None else text_kwargs.get("max_length", 77),
+ padding=text_kwargs.get("padding", "max_length"),
+ truncation=text_kwargs.get("truncation", True),
+ )

attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"]

Expand Down Expand Up @@ -184,13 +201,15 @@ def __call__(
encoded_inputs["task_inputs"] = self._preprocess_text(
task_token_inputs,
max_length=output_kwargs["text_kwargs"]["task_seq_length"],
+ text_kwargs=output_kwargs["text_kwargs"],
)

if hasattr(encoded_inputs, "text_inputs"):
text_inputs = [
self._preprocess_text(
texts,
max_length=output_kwargs["text_kwargs"]["max_seq_length"],
+ text_kwargs=output_kwargs["text_kwargs"],
).unsqueeze(0)
for texts in encoded_inputs.text_inputs
]
Expand Down

0 comments on commit 3c7724d

Please sign in to comment.