Address review comments
yonigozlan committed Oct 25, 2024
1 parent e8150f2 commit b9eb103
Showing 3 changed files with 20 additions and 31 deletions.
6 changes: 2 additions & 4 deletions src/transformers/models/donut/processing_donut.py
@@ -114,10 +114,8 @@ def __call__(

         if images is not None:
             inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
-        if text is not None and images is None:
-            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
-        elif text is not None:
-            if not legacy:
+        if text is not None:
+            if not legacy and images is not None:
                 output_kwargs["text_kwargs"].setdefault("add_special_tokens", False)
             encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])

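The refactor collapses the separate text-only and text-plus-image branches into a single `if text is not None` block; in the non-legacy path, special tokens are suppressed only when an image accompanies the text. A minimal standalone sketch of the resulting control flow (the `image_processor` and `tokenizer` parameters are plain callables here to keep the example self-contained):

def process(image_processor, tokenizer, images=None, text=None, legacy=True):
    # Images are preprocessed first, when given.
    inputs = image_processor(images) if images is not None else None
    encodings = None
    if text is not None:
        text_kwargs = {}
        # Non-legacy behavior: text alongside an image acts as a decoder
        # prompt, so special tokens stay off unless the caller opts in.
        if not legacy and images is not None:
            text_kwargs.setdefault("add_special_tokens", False)
        encodings = tokenizer(text, **text_kwargs)
    return inputs, encodings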
2 changes: 1 addition & 1 deletion src/transformers/models/fuyu/processing_fuyu.py
@@ -694,7 +694,7 @@ def post_process_image_text_to_text(self, generated_outputs):
         Returns:
             `List[str]`: The decoded text output.
         """
-        beginning_of_answer = self.tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
+        beginning_of_answer = self.tokenizer.convert_tokens_to_ids(BEGINNING_OF_ANSWER_STRING)
         # get boa index for each outputted sequence tensor
         # start all generated sequences from the beginning of the answer token, pad to have consistent length
         unpadded_output_sequences = [
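`convert_tokens_to_ids` is the safer lookup: unlike a raw `vocab[...]` subscript, it also resolves tokens registered as added tokens and falls back to the unknown-token id instead of raising `KeyError`. Note that it is a method, so it takes parentheses rather than brackets. A short sketch (the checkpoint name is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("adept/fuyu-8b")  # illustrative checkpoint

BEGINNING_OF_ANSWER_STRING = "<0x04>"  # the <boa> token defined in processing_fuyu.py
boa_id = tokenizer.convert_tokens_to_ids(BEGINNING_OF_ANSWER_STRING)
# tokenizer.vocab[BEGINNING_OF_ANSWER_STRING] only works for tokens present in
# the base vocabulary and raises KeyError otherwise.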
43 changes: 17 additions & 26 deletions src/transformers/pipelines/image_text_to_text.py
@@ -64,15 +64,6 @@ def __init__(self, messages: Dict, images: Union[str, List[str], "Image.Image", List["Image.Image"]]):
         self.images = images


-class ImageText:
-    """This class is intended to just be used internally in this pipeline and not exposed to users. We used this class
-    as the base pipeline does not support multiple inputs, so we need to convert multiple inputs to a single input."""
-
-    def __init__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], text: Union[str, List[str]]):
-        self.images = images
-        self.text = text
-
-
 def retrieve_images_in_messages(
     messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]
 ):
@@ -364,7 +355,7 @@ def __call__(
         if nested_images:
             results = []
             for image_group, text_single in zip(images, text):
-                results.extend(super().__call__(ImageText(image_group, text_single), **kwargs))
+                results.extend(super().__call__({"images": image_group, "text": text_single}, **kwargs))
             return results

         # otherwise, we can flatten the images and text as we have a 1:1 relationship
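With the `ImageText` wrapper removed, each unit of work handed to the base `Pipeline.__call__` is now a plain dict with `images` and `text` keys, which `preprocess` unpacks by key. A standalone sketch of the pairing done in the nested-images path above (toy file names, not pipeline code):

def pair_nested(images, texts):
    # One payload per (image group, prompt) pair, mirroring the zip above.
    return [{"images": group, "text": prompt} for group, prompt in zip(images, texts)]

pairs = pair_nested([["a.png", "b.png"], ["c.png"]], ["Compare these.", "Caption this."])
# [{'images': ['a.png', 'b.png'], 'text': 'Compare these.'},
#  {'images': ['c.png'], 'text': 'Caption this.'}]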
@@ -376,10 +367,10 @@
         results = []
         while batching_index < len(images):
             batch_results = super().__call__(
-                ImageText(
-                    images[batching_index : batching_index + batch_size],
-                    text[batching_index : batching_index + batch_size],
-                ),
+                {
+                    "images": images[batching_index : batching_index + batch_size],
+                    "text": text[batching_index : batching_index + batch_size],
+                },
                 **kwargs,
             )
             results.extend(batch_results)
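The loop slices images and text in lockstep and forwards one dict per batch. The same pattern in isolation (a sketch, not the pipeline code itself):

def iter_batches(images, texts, batch_size):
    # Yield {"images": ..., "text": ...} payloads, batch_size items at a time.
    index = 0
    while index < len(images):
        yield {
            "images": images[index : index + batch_size],
            "text": texts[index : index + batch_size],
        }
        index += batch_size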
@@ -397,16 +388,6 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
             text = inputs
             inputs_text = inputs
         else:
-            # We have an ImageText or Chat inputs
-            images = inputs.images
-            if len(images) > 0:
-                if not isinstance(images, (list, tuple)):
-                    images = load_image(images, timeout=timeout)
-                else:
-                    images = [load_image(image, timeout=timeout) for image in images]
-            else:
-                images = None
-
             if isinstance(inputs, Chat):
                 # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
                 # because very few models support multiple separate, consecutive assistant messages
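For reference, the prefill behavior the comment describes can be sketched with `apply_chat_template` and its `continue_final_message` flag (the checkpoint name is illustrative; any chat-capable tokenizer works):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceM4/idefics2-8b")  # illustrative
messages = [
    {"role": "user", "content": "What is in this image?"},
    {"role": "assistant", "content": "The image shows"},  # trailing turn to continue
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, continue_final_message=True)
# The rendered prompt ends mid-assistant-turn, so generation completes
# "The image shows..." instead of opening a new message.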
@@ -419,9 +400,19 @@
                     return_tensors=self.framework,
                 )
                 inputs_text = inputs
+                images = inputs.images
             else:
-                text = inputs.text
-                inputs_text = inputs.text
+                text = inputs["text"]
+                inputs_text = inputs["text"]
+                images = inputs["images"]

+            if len(images) > 0:
+                if not isinstance(images, (list, tuple)):
+                    images = load_image(images, timeout=timeout)
+                else:
+                    images = [load_image(image, timeout=timeout) for image in images]
+            else:
+                images = None
+
         # if batched text inputs, we set padding to True unless specified otherwise
         if isinstance(text, (list, tuple)) and len(text) > 1:
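The image-loading block now runs once, after both the Chat and dict branches have set `images`. In isolation (a standalone sketch; `load_image` is the helper from `transformers.image_utils`):

from typing import Optional
from transformers.image_utils import load_image

def normalize_images(images, timeout: Optional[float] = None):
    # Return a single PIL image, a list of PIL images, or None when empty.
    if len(images) > 0:
        if not isinstance(images, (list, tuple)):
            return load_image(images, timeout=timeout)
        return [load_image(image, timeout=timeout) for image in images]
    return None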
