Add post_process_image_text_to_text to idefics3, mllama, pixtral processor
yonigozlan committed Oct 2, 2024
1 parent f66775a commit 4137b24
Showing 4 changed files with 45 additions and 1 deletion.
14 changes: 14 additions & 0 deletions src/transformers/models/idefics3/processing_idefics3.py
@@ -331,6 +331,20 @@ def decode(self, *args, **kwargs):
         decode_output = self.tokenizer.decode(*args, **kwargs)
         return self._regex_to_remove_extra_special_tokens.sub("<image>", decode_output)
 
+    def post_process_image_text_to_text(self, generated_outputs):
+        """
+        Post-process the output of the model to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+
+        Returns:
+            `List[str]`: The decoded text.
+        """
+        return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
+
     @property
     def model_input_names(self):
         tokenizer_input_names = self.tokenizer.model_input_names
16 changes: 16 additions & 0 deletions src/transformers/models/mllama/processing_mllama.py
@@ -342,6 +342,22 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)
 
+    def post_process_image_text_to_text(self, generated_outputs):
+        """
+        Post-process the output of the model to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+
+        Returns:
+            `List[str]`: The decoded text.
+        """
+        return self.tokenizer.batch_decode(
+            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
     @property
     def model_input_names(self):
         tokenizer_input_names = self.tokenizer.model_input_names
2 changes: 1 addition & 1 deletion src/transformers/models/paligemma/processing_paligemma.py
@@ -100,7 +100,7 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
     if image_token in prompt:
         warnings.warn(
             f"The image token {image_token} is already present in the prompt. No need to manually add {image_token} in the prompt for this model."
-            f" Remove all {image_token} and adding ({image_token}) * image_seq_len at the start of the prompt."
+            f" Removing all {image_token} and adding ({image_token}) * image_seq_len * num_images at the start of the prompt."
         )
         prompt = prompt.replace(image_token, "")
     return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
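For context, the string built by build_string_from_input after this change can be sketched with toy values (the token strings, counts, and prompt below are illustrative and not taken from the PaliGemma tokenizer):

# Illustrative sketch of the return expression above, with placeholder values.
image_token, bos_token = "<image>", "<bos>"
image_seq_len, num_images = 3, 2
prompt = "caption en"
built = f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
# built == "<image><image><image><image><image><image><bos>caption en\n"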
14 changes: 14 additions & 0 deletions src/transformers/models/pixtral/processing_pixtral.py
@@ -274,6 +274,20 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)
 
+    def post_process_image_text_to_text(self, generated_outputs):
+        """
+        Post-process the output of the model to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+
+        Returns:
+            `List[str]`: The decoded text.
+        """
+        return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
+
     @property
     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
     def model_input_names(self):
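Taken together, the new method gives the idefics3, mllama, and pixtral processors a uniform way to turn the raw output of generate into plain strings. A minimal usage sketch (the checkpoint name, image URL, and prompt below are placeholders, not part of this commit):

import requests
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

checkpoint = "HuggingFaceM4/Idefics3-8B-Llama3"  # placeholder checkpoint name
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForVision2Seq.from_pretrained(checkpoint)

image = Image.open(requests.get("https://example.com/cat.png", stream=True).raw)  # placeholder URL
inputs = processor(images=image, text="<image> Describe this image.", return_tensors="pt")

generated_ids = model.generate(**inputs, max_new_tokens=64)
# New in this commit: batch-decode the generated ids, dropping special tokens.
decoded = processor.post_process_image_text_to_text(generated_ids)
print(decoded)  # List[str], one entry per item in the batch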
