[WIP] Improve multimodal processors - rely less on kwargs #28711

Draft · wants to merge 36 commits into main from improve_multimodal_processors

Changes from all commits (36 commits)
42ecf48
expand kwargs from align
molbap Jan 25, 2024
ccb2147
remove kwargs from altclip processor
molbap Jan 25, 2024
f999e0c
add explicit args for donut processor
molbap Jan 25, 2024
8fb3a6b
add explicit call to current processor for in context manager
molbap Jan 25, 2024
a90c766
format
molbap Jan 25, 2024
49cb6cc
remove unused kwargs
molbap Jan 25, 2024
3ac1c7e
move conditions for encodings
molbap Jan 25, 2024
7a819fd
improve flow over text/image
molbap Jan 25, 2024
9cc38b7
[breaking] pass explicit args to bridgetower
molbap Jan 25, 2024
ff6a950
Merge branch 'main' into improve_multimodal_processors
molbap Jan 26, 2024
7db64a0
add default kwargs for BC
molbap Jan 26, 2024
41674d9
fix bridgetower
molbap Jan 26, 2024
618a687
debug bridgetower image proc
molbap Jan 26, 2024
f39cdc1
format
molbap Jan 26, 2024
9a6f97d
move kwargs message to info level
molbap Jan 26, 2024
380f82f
add debug messages
molbap Jan 26, 2024
75f15d3
fix arguments not being passed in bridgetower
molbap Jan 26, 2024
3df5faa
keep backwards compat for processing + modify testing args dict
molbap Feb 1, 2024
5ad0694
Merge branch 'main' into improve_multimodal_processors
molbap Feb 1, 2024
69e5a2d
fix quality
molbap Feb 1, 2024
68c2f40
log kwargs mismatch to info level
molbap Feb 1, 2024
e1e4084
fix quality
molbap Feb 1, 2024
bfa81e5
Merge branch 'main' into improve_multimodal_processors
molbap Feb 15, 2024
4b557b0
address comments
molbap Feb 15, 2024
b7fc377
fix typo
molbap Feb 15, 2024
270bb9e
fix expected tests for bridgetower
molbap Feb 16, 2024
94a1b75
fix conflicts
molbap Feb 26, 2024
6603bf0
Merge branch 'main' into improve_multimodal_processors
molbap Feb 26, 2024
004c961
fix valid processor keys
molbap Feb 26, 2024
c2e49f5
remove unused arg list
molbap Feb 26, 2024
79958b5
quality
molbap Feb 26, 2024
a36f524
Merge branch 'main' into improve_multimodal_processors
molbap Apr 17, 2024
3238dd3
skeleton draft - uniform processor call
molbap Apr 18, 2024
3afde22
fix quality
molbap Apr 18, 2024
eb99e29
add broken wav2vec audio processing
molbap Apr 25, 2024
c6afd63
Merge branch 'main' into improve_multimodal_processors
molbap Apr 25, 2024
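
Taken together, the commits apply one pattern: each multimodal processor's __call__ stops forwarding an opaque **kwargs to its tokenizer and image processor and instead declares every supported argument explicitly. The sketch below illustrates that pattern with hypothetical stub components, not the actual transformers classes; the real signatures appear in the diffs that follow.

from typing import Optional


# Hypothetical stubs standing in for a tokenizer and an image processor,
# used only to keep the sketch self-contained.
def tokenize(text, padding=False, max_length=None, return_tensors=None, **kwargs):
    return {"input_ids": [[0]]}


def process_images(images, do_resize=None, return_tensors=None, **kwargs):
    return {"pixel_values": [[0.0]]}


# Before: every option rides along in **kwargs, so a misspelled key
# (e.g. paddding=True) is forwarded and silently ignored.
def call_before(text=None, images=None, **kwargs):
    return tokenize(text, **kwargs), process_images(images, **kwargs)


# After: explicit, typed arguments routed to the right component; the same
# typo now raises TypeError at the call site instead of vanishing.
def call_after(
    text=None,
    images=None,
    padding: bool = False,
    max_length: Optional[int] = None,
    do_resize: Optional[bool] = None,
    return_tensors: Optional[str] = None,
):
    encoding = tokenize(text, padding=padding, max_length=max_length, return_tensors=return_tensors)
    features = process_images(images, do_resize=do_resize, return_tensors=return_tensors)
    return encoding, features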
88 changes: 82 additions & 6 deletions src/transformers/models/align/processing_align.py
@@ -17,8 +17,11 @@
"""


from typing import Dict, List, Optional, Union

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import PaddingStrategy, TensorType


class AlignProcessor(ProcessorMixin):
@@ -42,11 +45,49 @@ class AlignProcessor(ProcessorMixin):
def __init__(self, image_processor, tokenizer):
super().__init__(image_processor, tokenizer)

def __call__(self, text=None, images=None, padding="max_length", max_length=64, return_tensors=None, **kwargs):
def __call__(
self,
text=None,
images=None,
do_crop_margin: bool = None,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: "PILImageResampling" = None, # noqa: F821
do_thumbnail: bool = None,
do_align_long_axis: bool = None,
do_pad: bool = None,
do_rescale: bool = None,
rescale_factor: Union[int, float] = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
input_data_format: Optional[Union[str, "ChannelDimension"]] = None, # noqa: F821
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
text_pair_target: Optional[
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = "max_length",
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = 64,
stride: int = 0,
is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
):
"""
Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text`
and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` arguments to
EfficientNetImageProcessor's [`~EfficientNetImageProcessor.__call__`] if `images` is not `None`. Please refer
to the docstring of the above two methods for more information.

@@ -86,11 +127,46 @@ def __call__(self, text=None, images=None, padding="max_length", max_length=64,

if text is not None:
encoding = self.tokenizer(
text, padding=padding, max_length=max_length, return_tensors=return_tensors, **kwargs
text,
text_pair=text_pair,
text_target=text_target,
text_pair_target=text_pair_target,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
)

if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
image_features = self.image_processor(
images,
do_crop_margin=do_crop_margin,
do_resize=do_resize,
size=size,
resample=resample,
do_thumbnail=do_thumbnail,
do_align_long_axis=do_align_long_axis,
do_pad=do_pad,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
return_tensors=return_tensors,
data_format=data_format,
input_data_format=input_data_format,
)

if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
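With the expanded signature above, callers set tokenizer and image-processor options directly, and a misspelled option fails loudly instead of disappearing into **kwargs. A usage sketch, assuming the kakaobrain/align-base checkpoint and a local cat.png:

from PIL import Image
from transformers import AlignProcessor

processor = AlignProcessor.from_pretrained("kakaobrain/align-base")
image = Image.open("cat.png")  # any local RGB image

inputs = processor(
    text="a photo of a cat",
    images=image,
    padding="max_length",  # ALIGN keeps "max_length" as its default
    max_length=64,
    return_tensors="pt",
)
print(sorted(inputs.keys()))  # expected: attention_mask, input_ids, pixel_values, token_type_ids
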
94 changes: 85 additions & 9 deletions src/transformers/models/altclip/processing_altclip.py
@@ -16,9 +16,11 @@
Image/Text processor class for AltCLIP
"""
import warnings
from typing import Dict, List, Optional, Union

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import PaddingStrategy, TensorType


class AltCLIPProcessor(ProcessorMixin):
@@ -34,22 +36,21 @@ class AltCLIPProcessor(ProcessorMixin):
The image processor is a required input.
tokenizer ([`XLMRobertaTokenizerFast`], *optional*):
The tokenizer is a required input.
feature_extractor ([`CLIPFeatureExtractor`], *optional*):
The feature extractor is a deprecated input.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "CLIPImageProcessor"
tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")

def __init__(self, image_processor=None, tokenizer=None, **kwargs):
feature_extractor = None
if "feature_extractor" in kwargs:
def __init__(self, image_processor=None, tokenizer=None, feature_extractor=None):
if "feature_extractor":
warnings.warn(
"The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
" instead.",
FutureWarning,
)
feature_extractor = kwargs.pop("feature_extractor")

image_processor = image_processor if image_processor is not None else feature_extractor
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
@@ -58,7 +59,45 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):

super().__init__(image_processor, tokenizer)

def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
def __call__(
self,
text=None,
images=None,
do_crop_margin: bool = None,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: "PILImageResampling" = None, # noqa: F821
do_thumbnail: bool = None,
do_align_long_axis: bool = None,
do_pad: bool = None,
do_rescale: bool = None,
rescale_factor: Union[int, float] = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
input_data_format: Optional[Union[str, "ChannelDimension"]] = None, # noqa: F821
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
text_pair_target: Optional[
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
):
"""
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
and `kwargs` arguments to XLMRobertaTokenizerFast's [`~XLMRobertaTokenizerFast.__call__`] if `text` is not
@@ -97,10 +136,47 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
raise ValueError("You have to specify either text or images. Both cannot be none.")

if text is not None:
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
encoding = self.tokenizer(
text,
text_pair=text_pair,
text_target=text_target,
text_pair_target=text_pair_target,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
)

if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
image_features = self.image_processor(
images,
do_crop_margin=do_crop_margin,
do_resize=do_resize,
size=size,
resample=resample,
do_thumbnail=do_thumbnail,
do_align_long_axis=do_align_long_axis,
do_pad=do_pad,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
return_tensors=return_tensors,
data_format=data_format,
input_data_format=input_data_format,
)

if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
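A distilled, self-contained sketch of the deprecation-and-fallback pattern in the __init__ above (a plain function instead of the processor class):

import warnings


def resolve_image_processor(image_processor=None, feature_extractor=None):
    # Warn only when the caller actually passed the deprecated argument.
    if feature_extractor is not None:
        warnings.warn(
            "The `feature_extractor` argument is deprecated and will be removed in v5,"
            " use `image_processor` instead.",
            FutureWarning,
        )
    # Fall back to the deprecated value so old call sites keep working.
    if image_processor is None:
        image_processor = feature_extractor
    if image_processor is None:
        raise ValueError("You need to specify an `image_processor`.")
    return image_processor
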
39 changes: 9 additions & 30 deletions src/transformers/models/blip/processing_blip.py
@@ -76,31 +76,7 @@ def __call__(
if images is None and text is None:
raise ValueError("You have to specify either images or text.")

# Get only text
if images is None:
self.current_processor = self.tokenizer
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
return text_encoding

# add pixel_values
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
text_encoding = None

if text is not None:
text_encoding = self.tokenizer(
@@ -121,13 +97,16 @@
return_tensors=return_tensors,
**kwargs,
)
else:
text_encoding = None

if text_encoding is not None:
encoding_image_processor.update(text_encoding)
# add pixel_values encoding. If we also have text_encoding, update image encoding and return it.
# else, return the text encoding.
if images is not None:
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
if text_encoding is not None:
encoding_image_processor.update(text_encoding)
return encoding_image_processor

return encoding_image_processor
return text_encoding

def batch_decode(self, *args, **kwargs):
"""
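The rewrite replaces the early text-only return with a single merge point at the end. A distilled sketch of the three return paths, with plain dicts standing in for BatchEncoding and the real tokenizer/image processor:

def call(text=None, images=None):
    if images is None and text is None:
        raise ValueError("You have to specify either images or text.")

    # Stub encoding standing in for the tokenizer output.
    text_encoding = {"input_ids": [[101]]} if text is not None else None

    if images is not None:
        # Stub encoding standing in for the image processor output.
        encoding_image_processor = {"pixel_values": [[0.0]]}
        if text_encoding is not None:
            encoding_image_processor.update(text_encoding)  # merge text keys in
        return encoding_image_processor  # image-only or merged text+image

    return text_encoding  # text-only path


assert set(call(text="hi")) == {"input_ids"}
assert set(call(images=object())) == {"pixel_values"}
assert set(call(text="hi", images=object())) == {"input_ids", "pixel_values"}
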
39 changes: 9 additions & 30 deletions src/transformers/models/blip_2/processing_blip_2.py
@@ -78,31 +78,7 @@ def __call__(
if images is None and text is None:
raise ValueError("You have to specify either images or text.")

# Get only text
if images is None:
self.current_processor = self.tokenizer
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
return text_encoding

# add pixel_values
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
text_encoding = None

if text is not None:
text_encoding = self.tokenizer(
@@ -123,13 +99,16 @@
return_tensors=return_tensors,
**kwargs,
)
else:
text_encoding = None

if text_encoding is not None:
encoding_image_processor.update(text_encoding)
# add pixel_values encoding. If we also have text_encoding, update image encoding and return it.
# else, return the text encoding.
if images is not None:
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
if text_encoding is not None:
encoding_image_processor.update(text_encoding)
return encoding_image_processor

return encoding_image_processor
return text_encoding

# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
def batch_decode(self, *args, **kwargs):