VLMs: major clean up 🧼 #34502

Open · wants to merge 20 commits into base: main
53 changes: 7 additions & 46 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -2186,29 +2186,14 @@ def forward(

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)

inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

# if the model already has "image_token_index" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_index", None) is not None:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else:
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
attention_mask = torch.cat(
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

if self.config.use_decoder_only_language_model:
outputs = self.language_model(
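
A standalone sketch (plain PyTorch, independent of the model code) of the masked_scatter merge that now replaces the concatenation path above; the token id 32 is an arbitrary stand-in for config.image_token_index:

import torch

# input_ids as the updated processor emits them: image placeholders first,
# then BOS and the prompt (here 2 placeholders in a sequence of 6)
image_token_index = 32
input_ids = torch.tensor([[32, 32, 0, 5, 9, 2]])

inputs_embeds = torch.randn(1, 6, 4)           # token embeddings (B, T, H)
language_model_inputs = torch.randn(1, 2, 4)   # projected Q-Former output

# mark every placeholder position, broadcast across the hidden dimension
special_image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)

# masked_scatter copies the image embeddings into exactly those positions
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
assert torch.equal(inputs_embeds[0, :2], language_model_inputs[0])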
@@ -2304,9 +2289,6 @@ def generate(
query_output = query_outputs.last_hidden_state

language_model_inputs = self.language_projection(query_output)
language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)

if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
@@ -2319,36 +2301,15 @@
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

# if the model already has "image_token_index" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_index", None) is not None:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
else:
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
attention_mask = torch.cat(
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
)

# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
# -1 is to account for the BOS prepended by `generate`.
# TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
if not self.language_model.config.is_encoder_decoder:
generate_kwargs["max_length"] = (
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
)
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()

inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
if not self.language_model.config.is_encoder_decoder:
inputs["input_ids"] = input_ids

outputs = self.language_model.generate(**inputs, **generate_kwargs)

return outputs
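
With the placeholders materialized in input_ids by the processor, the deleted max_length arithmetic (previously max_length = 20 + num_query_tokens - 1) needs no replacement; generate counts the image tokens natively. A hedged end-to-end sketch, assuming a checkpoint migrated per the gist referenced above so that image_token_index is set:

import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="a photo of", return_tensors="pt")

# input_ids already starts with num_query_tokens image placeholders plus BOS,
# so length arguments need no manual offset for the image embeds
out = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(out, skip_special_tokens=True))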


41 changes: 11 additions & 30 deletions src/transformers/models/blip_2/processing_blip_2.py
@@ -81,6 +81,10 @@ def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
self.image_token = tokenizer.image_token
self.num_query_tokens = num_query_tokens

# We'll add the BOS manually as it has to be after image tokens
tokenizer.add_bos_token = False
self.bos_token = tokenizer.bos_token

super().__init__(image_processor, tokenizer)

def __call__(
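
A string-level sketch of why add_bos_token is disabled here: BLIP-2 wants the image placeholders before BOS, so the processor rebuilds the string by hand. The "<image>" and "</s>" literals are stand-ins for the tokenizer's actual special tokens (OPT-based BLIP-2 tokenizers use "</s>" as BOS):

num_query_tokens = 3        # real checkpoints use 32; 3 keeps the demo short
image_tokens = "<image>" * num_query_tokens
bos_token = "</s>"          # stand-in for tokenizer.bos_token

prompt = "a photo of"
text = f"{image_tokens}{bos_token}{prompt}"
print(text)                 # <image><image><image></s>a photo of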
@@ -118,49 +122,26 @@ def __call__(
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

# BC for explicit return_tensors
if "return_tensors" in output_kwargs["common_kwargs"]:
return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
else:
return_tensors = None
return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
encoding = BatchFeature(tensor_type=return_tensors)
if text is not None:
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")

text_encoding = {}

return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
_text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
output_kwargs["text_kwargs"]["return_tensors"] = return_tensors

# if we know how many query tokens, expand text inside processor. We need this hacky manipulation
# because BLIP expects image tokens to be at the beginning even before BOS token
if self.num_query_tokens is not None:
image_tokens = self.image_token.content * self.num_query_tokens
image_token_encoding = self.tokenizer(
[image_tokens] * len(text), add_special_tokens=False, return_tensors=None
)
for k in _text_encoding:
text_encoding[k] = [
img_encoding + txt_encoding
for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k])
]
else:
text_encoding = _text_encoding
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
# We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token
image_tokens = self.image_token.content * self.num_query_tokens
text = [f"{image_tokens}{self.bos_token}{sample}" for sample in text]
text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])

# cast to desired return tensors type
encoding.update(BatchEncoding(text_encoding, tensor_type=return_tensors))

# add pixel_values encoding. If we also have text_encoding, update image encoding and return it.
# else, return the text encoding.

if images is not None:
image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
encoding.update(image_encoding)
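
A quick layout check of what the single tokenizer call now returns (hedged: assumes a migrated checkpoint whose tokenizer has the image token registered, and that text-only calls also prepend placeholders, as in this diff):

from transformers import Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
inputs = processor(text="a photo of", return_tensors="pt")

image_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token.content)
n = processor.num_query_tokens

# same layout the old list-concatenation produced:
# [n image ids] + [BOS] + [prompt ids]
assert (inputs["input_ids"][0, :n] == image_id).all()
assert inputs["input_ids"][0, n] == processor.tokenizer.bos_token_id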
48 changes: 4 additions & 44 deletions src/transformers/models/instructblip/modeling_instructblip.py
@@ -1454,29 +1454,13 @@ def forward(

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)

inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

# if the model already has "image_token_index" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_index", None) is not None:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
else:
logger.warning_once(
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
attention_mask = torch.cat(
[language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()

if self.config.use_decoder_only_language_model:
outputs = self.language_model(
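
The in-place boolean-mask assignment used here is equivalent to the masked_scatter form in the BLIP-2 hunk above; a standalone PyTorch demo of that equivalence (id 32 is again an arbitrary stand-in):

import torch

image_token_index = 32
input_ids = torch.tensor([[32, 32, 7, 1]])
inputs_embeds = torch.randn(1, 4, 8)
language_model_inputs = torch.randn(1, 2, 8)

mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)

a = inputs_embeds.clone()
a[mask] = language_model_inputs.flatten()   # indexed assignment, as in this diff

b = inputs_embeds.masked_scatter(mask, language_model_inputs)
assert torch.equal(a, b)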
@@ -1586,9 +1570,6 @@ def generate(
query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]

language_model_inputs = self.language_projection(query_output)
language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)

if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
@@ -1602,29 +1583,8 @@

inputs_embeds = self.get_input_embeddings()(input_ids)

# if the model already has "image_token_index" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_index", None) is not None:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
else:
logger.warning_once(
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
attention_mask = torch.cat(
[language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
)

# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
# -1 is to account for the BOS prepended by `generate`.
if not self.language_model.config.is_encoder_decoder:
generate_kwargs["max_length"] = (
generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
)
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()

inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
if not self.language_model.config.is_encoder_decoder:
37 changes: 8 additions & 29 deletions src/transformers/models/instructblip/processing_instructblip.py
@@ -24,7 +24,6 @@
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import (
AddedToken,
BatchEncoding,
PreTokenizedInput,
TextInput,
)
@@ -84,6 +83,10 @@ def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
else:
self.image_token = tokenizer.image_token
self.num_query_tokens = num_query_tokens

# We'll add the BOS manually as it has to be after image tokens
tokenizer.add_bos_token = False
self.bos_token = tokenizer.bos_token
super().__init__(image_processor, tokenizer, qformer_tokenizer)

def __call__(
@@ -125,34 +128,10 @@ def __call__(
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")

# we have to concatenate lists - so we keep track of return_tensors here
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
_text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
output_kwargs["text_kwargs"]["return_tensors"] = return_tensors
# if we know how many query tokens, expand text inside processor. We need this hacky manipulation
# because BLIP expects image tokens to be at the beginning even before BOS token
if self.num_query_tokens is not None and images is not None:
text_encoding = {}
image_tokens = self.image_token.content * self.num_query_tokens
image_token_encoding = self.tokenizer(
[image_tokens] * len(text), add_special_tokens=False, return_tensors=None
)
for k in _text_encoding:
text_encoding[k] = [
img_encoding + txt_encoding
for img_encoding, txt_encoding in zip(image_token_encoding[k], _text_encoding[k])
]
else:
text_encoding = _text_encoding
if images is not None:
logger.warning_once(
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
"Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
)

# cast to desired return tensors type after concatenating
text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
# We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token
image_tokens = self.image_token.content * self.num_query_tokens
text_with_images = [f"{image_tokens}{self.bos_token}{sample}" for sample in text]
text_encoding = self.tokenizer(text_with_images, **output_kwargs["text_kwargs"])

encoding.update(text_encoding)
qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"])
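
InstructBLIP keeps two parallel encodings: the language-model tokenizer sees the image-expanded text while the Q-Former tokenizer sees the raw prompt. A hedged usage sketch (checkpoint name assumed; the qformer_* key names follow the existing processor output):

import requests
from PIL import Image
from transformers import InstructBlipProcessor

processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="What is shown here?", return_tensors="pt")

# input_ids / attention_mask: expanded text for the language model
# qformer_input_ids / qformer_attention_mask: raw prompt for the Q-Former,
# which never sees the <image> placeholders
print(sorted(inputs.keys()))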