From 6bc0fbcfa7acb6ac4937e7456a76c2f7975fefec Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 10 Jan 2025 12:23:00 +0100 Subject: [PATCH] [WIP] Emu3: add model (#33770) * model can convert to HF and be loaded back * nit * works in single batch generation but hallucinates * use the image tokens * add image generation * now it works * add tests * update * add modulare but it doesn't work for porting docstring :( * skip some tests * add slow tests * modular removed the import? * guess this works * update * update * fix copies * fix test * fix copies * update * docs * fix tests * last fix tests? * pls * repo consistency * more style * style * remove file * address comments * tiny bits * update after the new modular * fix tests * add one more cond in check attributes * decompose down/up/mid blocks * allow static cache generation in VLMs * nit * fix copies * Update docs/source/en/model_doc/emu3.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/model_doc/emu3.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/model_doc/emu3.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/model_doc/emu3.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/model_doc/emu3.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/model_doc/emu3.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/model_doc/emu3.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/model_doc/emu3.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fix VAE upsampling * Update src/transformers/models/emu3/modular_emu3.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * address comments * state overwritten stuff explicitly * fix copies * add the flag for flex attn --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/emu3.md | 179 ++ docs/source/en/perf_infer_gpu_one.md | 2 + src/transformers/__init__.py | 30 + src/transformers/generation/utils.py | 7 +- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 3 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 1 + .../models/chameleon/processing_chameleon.py | 1 + src/transformers/models/emu3/__init__.py | 29 + .../models/emu3/configuration_emu3.py | 327 +++ .../models/emu3/convert_emu3_weights_to_hf.py | 448 ++++ .../models/emu3/image_processing_emu3.py | 552 +++++ src/transformers/models/emu3/modeling_emu3.py | 1949 +++++++++++++++++ src/transformers/models/emu3/modular_emu3.py | 1270 +++++++++++ .../models/emu3/processing_emu3.py | 217 ++ src/transformers/utils/dummy_pt_objects.py | 35 + .../utils/dummy_vision_objects.py | 7 + tests/generation/test_utils.py | 18 +- tests/models/emu3/__init__.py | 0 tests/models/emu3/test_modeling_emu3.py | 550 +++++ tests/models/emu3/test_processor_emu3.py | 85 + tests/test_modeling_common.py | 2 +- utils/check_config_attributes.py | 4 + utils/check_repo.py | 4 + 28 files changed, 5722 insertions(+), 5 deletions(-) create mode 100644 docs/source/en/model_doc/emu3.md create mode 100644 src/transformers/models/emu3/__init__.py create mode 100644 src/transformers/models/emu3/configuration_emu3.py create mode 100644 src/transformers/models/emu3/convert_emu3_weights_to_hf.py create mode 100644 src/transformers/models/emu3/image_processing_emu3.py create mode 100644 src/transformers/models/emu3/modeling_emu3.py create mode 100644 src/transformers/models/emu3/modular_emu3.py create mode 100644 src/transformers/models/emu3/processing_emu3.py create mode 100644 tests/models/emu3/__init__.py create mode 100644 tests/models/emu3/test_modeling_emu3.py create mode 100644 tests/models/emu3/test_processor_emu3.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e9b0d465ab800f..529b113cf1e578 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -860,6 +860,8 @@ title: DePlot - local: model_doc/donut title: Donut + - local: model_doc/emu3 + title: Emu3 - local: model_doc/flava title: FLAVA - local: model_doc/git diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 84db43f825eba4..d66f4b031ac86b 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -137,6 +137,7 @@ Flax), PyTorch, and/or TensorFlow. | [EfficientFormer](model_doc/efficientformer) | ✅ | ✅ | ❌ | | [EfficientNet](model_doc/efficientnet) | ✅ | ❌ | ❌ | | [ELECTRA](model_doc/electra) | ✅ | ✅ | ✅ | +| [Emu3](model_doc/emu3) | ✅ | ❌ | ❌ | | [EnCodec](model_doc/encodec) | ✅ | ❌ | ❌ | | [Encoder decoder](model_doc/encoder-decoder) | ✅ | ✅ | ✅ | | [ERNIE](model_doc/ernie) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/emu3.md b/docs/source/en/model_doc/emu3.md new file mode 100644 index 00000000000000..0b3220c073fb65 --- /dev/null +++ b/docs/source/en/model_doc/emu3.md @@ -0,0 +1,179 @@ + + +# Emu3 + +## Overview + +The Emu3 model was proposed in [Emu3: Next-Token Prediction is All You Need](https://arxiv.org/abs/2409.18869) by Xinlong Wang, Xiaosong Zhang, Zhengxiong Luo, Quan Sun, Yufeng Cui, Jinsheng Wang, Fan Zhang, Yueze Wang, Zhen Li, Qiying Yu, Yingli Zhao, Yulong Ao, Xuebin Min, Tao Li, Boya Wu, Bo Zhao, Bowen Zhang, Liangdong Wang, Guang Liu, Zheqi He, Xi Yang, Jingjing Liu, Yonghua Lin, Tiejun Huang, Zhongyuan Wang. + +Emu3 is a multimodal LLM that uses vector quantization to tokenize images into discrete tokens. Discretized image tokens are later fused with text token ids for image and text generation. The model can additionally generate images by predicting image token ids. + + +The abstract from the paper is the following: + +*While next-token prediction is considered a promising path towards artificial general intelligence, it has struggled to excel in multimodal tasks, which are still dominated by diffusion models (e.g., Stable Diffusion) and compositional approaches (e.g., CLIP combined with LLMs). In this paper, we introduce Emu3, a new suite of state-of-the-art multimodal models trained solely with next-token prediction. By tokenizing images, text, and videos into a discrete space, we train a single transformer from scratch on a mixture of multimodal sequences. Emu3 outperforms several well-established task-specific models in both generation and perception tasks, surpassing flagship models such as SDXL and LLaVA-1.6, while eliminating the need for diffusion or compositional architectures. Emu3 is also capable of generating high-fidelity video via predicting the next token in a video sequence. We simplify complex multimodal model designs by converging on a singular focus: tokens, unlocking great potential for scaling both during training and inference. Our results demonstrate that next-token prediction is a promising path towards building general multimodal intelligence beyond language. We open-source key techniques and models to support further research in this direction.* + +Tips: + +- We advise users to set `processor.tokenizer.padding_side = "left"` before batched generation as it leads to more accurate results. + +- Note that the model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts. + +- Emu3 has two different checkpoints for image-generation and text-generation, make sure to use the correct checkpoint when loading the model. To generate an image, it is advised to use `prefix_constraints` so that the generated tokens are sampled only from possible image tokens. See more below for usage examples. + +> [!TIP] +> Emu3 implementation in Transformers uses a special image token to indicate where to merge image embeddings. The special image token isn't new and uses one of the reserved tokens: `<|extra_0|>`. You have to add `` to your prompt in the place where the image should be embedded for correct generation. + + +This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay). +The original code can be found [here](https://github.com/baaivision/Emu3). + + +## Usage example + +### Text generation inference + +Here's how to load the model and perform inference in half-precision (`torch.bfloat16`) to generate textual output from text or text and image inputs: + +```python +from transformers import Emu3Processor, Emu3ForConditionalGeneration +import torch +from PIL import Image +import requests + +processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf") +model = Emu3ForConditionalGeneration.from_pretrained("Emu3-community/Emu3-Chat-hf", torch_dtype=torch.bfloat16, device_map="cuda") + +# prepare image and text prompt +url = 'http://images.cocodataset.org/val2017/000000039769.jpg' +image = Image.open(requests.get(url, stream=True).raw) +prompt = "What do you see in this image?" + +inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, dtype=torch.bfloat16) + +# autoregressively complete prompt +output = model.generate(**inputs, max_new_tokens=50) +print(processor.decode(output[0], skip_special_tokens=True)) +``` + +### Image generation inference + +Emu3 can also generate images from textual input. Here is how you can do it: + +```python +processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Gen-hf") +model = Emu3ForConditionalGeneration.from_pretrained("Emu3-community/Emu3-Gen-hf", torch_dtype="bfloat16", device_map="auto", attn_implementation="flash_attention_2") + + +inputs = processor( + text=["a portrait of young girl. masterpiece, film grained, best quality.", "a dog running under the rain"], + padding=True, + return_tensors="pt", + return_for_image_generation=True, +) +inputs = inputs.to(device="cuda:0", dtype=torch.bfloat16) + +neg_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry." +neg_inputs = processor(text=[neg_prompt] * 2, return_tensors="pt").to(device="cuda:0") + +image_sizes = inputs.pop("image_sizes") +HEIGHT, WIDTH = image_sizes[0] +VISUAL_TOKENS = model.vocabulary_mapping.image_tokens + +def prefix_allowed_tokens_fn(batch_id, input_ids): + height, width = HEIGHT, WIDTH + visual_tokens = VISUAL_TOKENS + image_wrapper_token_id = torch.tensor([processor.tokenizer.image_wrapper_token_id], device=model.device) + eoi_token_id = torch.tensor([processor.tokenizer.eoi_token_id], device=model.device) + eos_token_id = torch.tensor([processor.tokenizer.eos_token_id], device=model.device) + pad_token_id = torch.tensor([processor.tokenizer.pad_token_id], device=model.device) + eof_token_id = torch.tensor([processor.tokenizer.eof_token_id], device=model.device) + eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0] + + position = torch.nonzero(input_ids == image_wrapper_token_id, as_tuple=True)[0][0] + offset = input_ids.shape[0] - position + if offset % (width + 1) == 0: + return (eol_token_id, ) + elif offset == (width + 1) * height + 1: + return (eof_token_id, ) + elif offset == (width + 1) * height + 2: + return (eoi_token_id, ) + elif offset == (width + 1) * height + 3: + return (eos_token_id, ) + elif offset > (width + 1) * height + 3: + return (pad_token_id, ) + else: + return visual_tokens + + +out = model.generate( + **inputs, + max_new_tokens=50_000, # make sure to have enough tokens for one image + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + return_dict_in_generate=True, + negative_prompt_ids=neg_inputs.input_ids, # indicate for Classifier-Free Guidance + negative_prompt_attention_mask=neg_inputs.attention_mask, +) + +image = model.decode_image_tokens(out.sequences[:, inputs.input_ids.shape[1]: ], height=HEIGHT, width=WIDTH) +images = processor.postprocess(list(image.float()), return_tensors="PIL.Image.Image") # internally we convert to np but it's not supported in bf16 precision +for i, image in enumerate(images['pixel_values']): + image.save(f"result{i}.png") + +``` + + +## Emu3Config + +[[autodoc]] Emu3Config + +## Emu3VQVAEConfig + +[[autodoc]] Emu3VQVAEConfig + +## Emu3TextConfig + +[[autodoc]] Emu3TextConfig + +## Emu3Processor + +[[autodoc]] Emu3Processor + +## Emu3ImageProcessor + +[[autodoc]] Emu3ImageProcessor + - preprocess + +## Emu3VQVAE + +[[autodoc]] Emu3VQVAE + - forward + +## Emu3TextModel + +[[autodoc]] Emu3TextModel + - forward + +## Emu3ForCausalLM + +[[autodoc]] Emu3ForCausalLM + - forward + +## Emu3ForConditionalGeneration + +[[autodoc]] Emu3ForConditionalGeneration + - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 8cfddb45dba329..3e6d764617de2d 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -49,6 +49,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel) * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) +* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) * [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) @@ -245,6 +246,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) * [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel) +* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) * [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d72b55cfe37bf5..de34fea356dbfc 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -428,6 +428,12 @@ "ElectraConfig", "ElectraTokenizer", ], + "models.emu3": [ + "Emu3Config", + "Emu3Processor", + "Emu3TextConfig", + "Emu3VQVAEConfig", + ], "models.encodec": [ "EncodecConfig", "EncodecFeatureExtractor", @@ -1222,6 +1228,7 @@ _import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"]) _import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"]) _import_structure["models.efficientnet"].append("EfficientNetImageProcessor") + _import_structure["models.emu3"].append("Emu3ImageProcessor") _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"]) _import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"]) _import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"]) @@ -2243,6 +2250,15 @@ "load_tf_weights_in_electra", ] ) + _import_structure["models.emu3"].extend( + [ + "Emu3ForCausalLM", + "Emu3ForConditionalGeneration", + "Emu3PreTrainedModel", + "Emu3TextModel", + "Emu3VQVAE", + ] + ) _import_structure["models.encodec"].extend( [ "EncodecModel", @@ -5440,6 +5456,12 @@ ElectraConfig, ElectraTokenizer, ) + from .models.emu3 import ( + Emu3Config, + Emu3Processor, + Emu3TextConfig, + Emu3VQVAEConfig, + ) from .models.encodec import ( EncodecConfig, EncodecFeatureExtractor, @@ -6270,6 +6292,7 @@ from .models.donut import DonutFeatureExtractor, DonutImageProcessor from .models.dpt import DPTFeatureExtractor, DPTImageProcessor from .models.efficientnet import EfficientNetImageProcessor + from .models.emu3 import Emu3ImageProcessor from .models.flava import ( FlavaFeatureExtractor, FlavaImageProcessor, @@ -7139,6 +7162,13 @@ ElectraPreTrainedModel, load_tf_weights_in_electra, ) + from .models.emu3 import ( + Emu3ForCausalLM, + Emu3ForConditionalGeneration, + Emu3PreTrainedModel, + Emu3TextModel, + Emu3VQVAE, + ) from .models.encodec import ( EncodecModel, EncodecPreTrainedModel, diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 05627e23de11ff..18cbab60057621 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1634,17 +1634,18 @@ def _get_cache( cache_dtype = self.get_output_embeddings().weight.dtype def get_layer_device_map(execution_device_map: Optional[dict] = None): + num_hidden_layers = self.config.get_text_config().num_hidden_layers if execution_device_map is None: return None elif len(execution_device_map) == 1 and "" in execution_device_map: - return {idx: execution_device_map[""] for idx in range(self.config.num_hidden_layers)} + return {idx: execution_device_map[""] for idx in range(num_hidden_layers)} layer_device_map = {} for layer in execution_device_map: - for idx in range(self.config.num_hidden_layers): + for idx in range(num_hidden_layers): if f".{idx}." in f"{layer}.": layer_device_map[idx] = execution_device_map[layer] break - for idx in range(self.config.num_hidden_layers): + for idx in range(num_hidden_layers): if idx not in layer_device_map: raise RuntimeError(f"layer {idx} has not been mapped to a device.") return layer_device_map diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index b91e020c1b739b..7db328f87af1fb 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -86,6 +86,7 @@ dpt, efficientnet, electra, + emu3, encodec, encoder_decoder, ernie, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index febdf5ae271ca0..985fe59582d875 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -103,6 +103,7 @@ ("efficientformer", "EfficientFormerConfig"), ("efficientnet", "EfficientNetConfig"), ("electra", "ElectraConfig"), + ("emu3", "Emu3Config"), ("encodec", "EncodecConfig"), ("encoder-decoder", "EncoderDecoderConfig"), ("ernie", "ErnieConfig"), @@ -420,6 +421,7 @@ ("efficientformer", "EfficientFormer"), ("efficientnet", "EfficientNet"), ("electra", "ELECTRA"), + ("emu3", "Emu3"), ("encodec", "EnCodec"), ("encoder-decoder", "Encoder decoder"), ("ernie", "ERNIE"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4aef84405b8ac4..bf54a6ce97857b 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -499,6 +499,7 @@ ("dbrx", "DbrxForCausalLM"), ("diffllama", "DiffLlamaForCausalLM"), ("electra", "ElectraForCausalLM"), + ("emu3", "Emu3ForCausalLM"), ("ernie", "ErnieForCausalLM"), ("falcon", "FalconForCausalLM"), ("falcon_mamba", "FalconMambaForCausalLM"), @@ -800,6 +801,7 @@ ("blip", "BlipForConditionalGeneration"), ("blip-2", "Blip2ForConditionalGeneration"), ("chameleon", "ChameleonForConditionalGeneration"), + ("emu3", "Emu3ForConditionalGeneration"), ("fuyu", "FuyuForCausalLM"), ("git", "GitForCausalLM"), ("idefics", "IdeficsForVisionText2Text"), @@ -1428,6 +1430,7 @@ ("deberta-v2", "DebertaV2Model"), ("distilbert", "DistilBertModel"), ("electra", "ElectraModel"), + ("emu3", "Emu3TextModel"), ("flaubert", "FlaubertModel"), ("ibert", "IBertModel"), ("longformer", "LongformerModel"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 8df4fefeee4615..cf52e73f568aba 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -59,6 +59,7 @@ ("clipseg", "CLIPSegProcessor"), ("clvp", "ClvpProcessor"), ("colpali", "ColPaliProcessor"), + ("emu3", "Emu3Processor"), ("flava", "FlavaProcessor"), ("fuyu", "FuyuProcessor"), ("git", "GitProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 51b8b590d931a0..2e26cea97139d4 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -186,6 +186,7 @@ ), ), ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)), + ("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)), ("esm", ("EsmTokenizer", None)), diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index 9f4bc2904c861c..99da53c6c612ff 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -62,6 +62,7 @@ class ChameleonProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + valid_kwargs = ["image_seq_length", "image_token"] image_processor_class = "ChameleonImageProcessor" def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = ""): diff --git a/src/transformers/models/emu3/__init__.py b/src/transformers/models/emu3/__init__.py new file mode 100644 index 00000000000000..d8555f58d18664 --- /dev/null +++ b/src/transformers/models/emu3/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_emu3 import * + from .image_processing_emu3 import * + from .modeling_emu3 import * + from .processing_emu3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/emu3/configuration_emu3.py b/src/transformers/models/emu3/configuration_emu3.py new file mode 100644 index 00000000000000..5b5abedf4016d5 --- /dev/null +++ b/src/transformers/models/emu3/configuration_emu3.py @@ -0,0 +1,327 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class Emu3VQVAEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Emu3VQVAE`]. It is used to instantiate an VQ-VAE + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a configuration to the VQ model presented in Emu3 paper. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + codebook_size (`int`, *optional*, defaults to 32768): + Codebook size of the VQ model. + embed_dim (`int`, *optional*, defaults to 4): + Dimension of the quantized vector in codebook. + latent_channels (`int`, *optional*, defaults to 4): + Dimension of the output channel of encoder and the input channel of decoder + double_latent (`bool`, *optional*, defaults to `False`): + Whether double the output dim of the encoder. + in_channels (`int`, *optional*, defaults to 3): + Input channel of encoder. + out_channels (`int`, *optional*, defaults to 3): + Output channel of decoder. + temporal_downsample_factor (`int`, *optional*, defaults to 4): + Temporal downsample factor. + base_channels (`int`, *optional*, defaults to 256): + Basic channel number of the intermediate blocks. + channel_multiplier (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`): + Channel scaling factor of the intermediate blocks. + num_res_blocks (`int`, *optional*, defaults to 2): + Residual block number in each stage. + attn_resolutions (`List[int]`, *optional*, defaults to `[3]`): + Stage indices to apply attention. + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations in the attention layer. + num_attention_heads (`int`, *optional*, defaults to 1): + Number of attention heads for each attention layer. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Emu3VQVAE, Emu3VQVAEConfig + + >>> # Initializing a video VQ model of Emu3 configuration + >>> configuration = Emu3VQVAEConfig() + + >>> # Initializing a model from the Emu3 VQ model style configuration + >>> model = Emu3VQVAE(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "emu3_vqgan" + base_config_key = "vq_config" + + def __init__( + self, + codebook_size: int = 32768, + embed_dim: int = 4, + latent_channels: int = 4, + double_latent: bool = False, + in_channels: int = 3, + out_channels: int = 3, + temporal_downsample_factor: int = 4, + base_channels: int = 256, + channel_multiplier: List[int] = [1, 2, 2, 4], + num_res_blocks: int = 2, + attn_resolutions: List[int] = [3], + hidden_size: int = 1024, + num_attention_heads: int = 1, + attention_dropout: float = 0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.codebook_size = codebook_size + self.embed_dim = embed_dim + self.latent_channels = latent_channels + self.double_latent = double_latent + self.in_channels = in_channels + self.out_channels = out_channels + self.temporal_downsample_factor = temporal_downsample_factor + self.base_channels = base_channels + self.channel_multiplier = channel_multiplier + self.num_res_blocks = num_res_blocks + self.attn_resolutions = attn_resolutions + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.attention_dropout = attention_dropout + + +class Emu3TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Emu3TextModel`]. It is used to instantiate a + emu3 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the + [Emu3-community/Emu3-Chat-hf](https://huggingface.co/Emu3-community/Emu3-Chat-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 184622): + Vocabulary size of the Emu3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Emu3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 9216): + The maximum sequence length that this model might ever be used with. Emu supports up to 9216 tokens, + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 151643): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 151849): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 151850): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + + ```python + >>> from transformers import Emu3Model, Emu3Config + + >>> # Initializing a Emu3-community/Emu3-Chat-hf style configuration + >>> configuration = Emu3Config() + + >>> # Initializing a model from the Emu3-community/Emu3-Chat-hf style configuration + >>> model = Emu3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "emu3_text_model" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size: int = 184622, + hidden_size: int = 4096, + intermediate_size: int = 14336, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: Optional[int] = 8, + hidden_act: str = "silu", + max_position_embeddings: int = 9216, + rms_norm_eps: float = 1e-5, + use_cache: bool = True, + pad_token_id: int = 151643, + bos_token_id: int = 151849, + eos_token_id: int = 151850, + tie_word_embeddings: bool = False, + rope_theta: float = 1000000.0, + rope_scaling: Optional = None, + mlp_bias=False, + attention_bias=False, + attention_dropout: float = 0.1, + initializer_range: float = 0.02, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.mlp_bias = mlp_bias + self.attention_bias = attention_bias + self.initializer_range = initializer_range + rope_config_validation(self) + + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Emu3Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Emu3Model`]. It is used to instantiate a + emu3 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the + [Emu3-community/Emu3-Chat-hf](https://huggingface.co/Emu3-community/Emu3-Chat-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vq_config (`Union[Dict, Emu3VQVAEConfig]`, *optional*): + Emu3VQVAEConfig instance containing the configuration for the VQ-VAE model. + text_config (`Union[Dict, Emu3TextConfig]``, *optional*): + Emu3TextConfig instance containing the configuration for the language model. + vocabulary_map (`dict`, *optional*): + A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs. + """ + + model_type = "emu3" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = {"text_config": Emu3TextConfig, "vq_config": Emu3VQVAEConfig} + + def __init__( + self, + vq_config: Union[Dict, Emu3VQVAEConfig] = None, + text_config: Union[Dict, Emu3TextConfig] = None, + vocabulary_map: Dict[int, int] = None, + **kwargs, + ): + if vq_config is None: + vq_config = Emu3VQVAEConfig() + elif isinstance(vq_config, dict): + vq_config = Emu3VQVAEConfig(**vq_config) + + if text_config is None: + text_config = Emu3TextConfig() + elif isinstance(text_config, dict): + text_config = Emu3TextConfig(**text_config) + + self.vq_config = vq_config + self.text_config = text_config + self.vocabulary_map = vocabulary_map + + super().__init__(**kwargs) + + +__all__ = ["Emu3Config", "Emu3TextConfig", "Emu3VQVAEConfig"] diff --git a/src/transformers/models/emu3/convert_emu3_weights_to_hf.py b/src/transformers/models/emu3/convert_emu3_weights_to_hf.py new file mode 100644 index 00000000000000..8ac8db7e429031 --- /dev/null +++ b/src/transformers/models/emu3/convert_emu3_weights_to_hf.py @@ -0,0 +1,448 @@ +# Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import re +from typing import Dict, Optional + +import requests +import torch +from accelerate import init_empty_weights +from PIL import Image + +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoTokenizer, + Emu3Config, + Emu3ForConditionalGeneration, + Emu3ImageProcessor, + Emu3Processor, + Emu3TextConfig, + GenerationConfig, +) +from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + + +""" +Sample usage: + +``` +python src/transformers/models/emu3/convert_emu3_weights_to_hf.py \ + --vq_model_id BAAI/Emu3-VisionTokenizer --llm_model_id BAAI/Emu3-Chat --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import Emu3ForConditionalGeneration, Emu3Processor + +model = Emu3ForConditionalGeneration.from_pretrained("/output/path") +processor = Emu3Processor.from_pretrained("/output/path") +``` + +""" + + +byte_encoder = bytes_to_unicode() +CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}" + + +# Tiktoken to HF conversion, thanks for Xenova +def token_bytes_to_string(b): + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + + +# Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960 +def bpe(mergeable_ranks: Dict[bytes, int], token: bytes, max_rank: Optional[int] = None): + parts = [bytes([b]) for b in token] + while True: + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = mergeable_ranks.get(pair[0] + pair[1]) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + if min_rank is None or (max_rank is not None and min_rank >= max_rank): + break + assert min_idx is not None + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :] + return parts + + +def generate_vocab_and_merges(encoder): + mergeable_ranks = encoder._mergeable_ranks + + merges = [] + vocab = {} + for token, rank in mergeable_ranks.items(): + vocab[token_bytes_to_string(token)] = rank + + if len(token) == 1: + continue + merged = tuple(bpe(mergeable_ranks, token, max_rank=rank)) + assert len(merged) == 2 + merges.append(" ".join(map(token_bytes_to_string, merged))) + + # Also add special tokens + vocab.update(encoder._special_tokens) + return vocab, merges + + +def convert_tiktoken(tokenizer, output_dir): + encoder = tokenizer.tokenizer + vocab, merges = generate_vocab_and_merges(encoder) + added_tokens = [ + { + "id": id, + "content": content, + "single_word": False, + "lstrip": False, + "rstrip": False, + "normalized": False, + "special": True, + } + for content, id in encoder._special_tokens.items() + if content != "<|extra_0|>" + ] + + # https://huggingface.co/Xenova/gpt2/raw/main/tokenizer_config.json + tokenizer_config_template = { + "add_prefix_space": False, + "bos_token": "<|extra_203|>", + "clean_up_tokenization_spaces": False, + "eos_token": "<|extra_204|>", + "pad_token": "<|endoftext|>", + } + tokenizer_config_template.update({"tokenizer_class": "GPT2Tokenizer"}) + tokenizer_config_template = dict(sorted(tokenizer_config_template.items(), key=lambda x: x[0])) + + # add placeholder image token by taking one of the reserved tokens + reserved_token_id = vocab["<|extra_0|>"] + vocab[""] = reserved_token_id + del vocab["<|extra_0|>"] + added_tokens.append( + { + "id": reserved_token_id, + "content": "", + "single_word": False, + "lstrip": False, + "rstrip": False, + "normalized": False, + "special": True, + } + ) + + os.makedirs(output_dir, exist_ok=True) + + pre_tokenizer = { + "type": "ByteLevel", + "add_prefix_space": False, + "trim_offsets": True, + "use_regex": True, + } + + # https://huggingface.co/Xenova/gpt2/raw/main/tokenizer.json + tokenizer_template = { + "version": "1.0", + "truncation": None, + "padding": None, + "added_tokens": added_tokens, + "normalizer": None, + "pre_tokenizer": pre_tokenizer, + "post_processor": None, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": True, + "trim_offsets": True, + "use_regex": True, + }, + "model": { + "type": "BPE", + "dropout": None, + "unk_token": None, + "continuing_subword_prefix": "", + "end_of_word_suffix": "", + "fuse_unk": False, + "byte_fallback": False, + "vocab": vocab, + "merges": merges, + }, + } + + # Save to files + with open(os.path.join(output_dir, "vocab.json"), "w", encoding="utf-8") as fp: + json.dump(vocab, fp, indent=2, ensure_ascii=False) + + with open(os.path.join(output_dir, "tokenizer.json"), "w", encoding="utf-8") as fp: + json.dump(tokenizer_template, fp, indent=2, ensure_ascii=False) + + with open(os.path.join(output_dir, "tokenizer_config.json"), "w", encoding="utf-8") as fp: + json.dump(tokenizer_config_template, fp, indent=2, ensure_ascii=False) + + with open(os.path.join(output_dir, "special_tokens_map.json"), "w", encoding="utf-8") as fp: + json.dump( + { + "bos_token": "<|extra_203|>", + "eos_token": "<|extra_204|>", + "pad_token": "<|endoftext|>", + }, + fp, + indent=2, + ensure_ascii=False, + ) + + with open(os.path.join(output_dir, "merges.txt"), "w", encoding="utf-8") as fp: + fp.write("#version: 0.2\n") + fp.write("\n".join(merges)) + + +KEYS_TO_MODIFY_MAPPING = { + "^encoder": "model.vqmodel.encoder", + "^decoder": "model.vqmodel.decoder", + "^post_quant_conv": "model.vqmodel.post_quant_conv", + "^quant_conv": "model.vqmodel.quant_conv", + "^quantize": "model.vqmodel.quantize", + "^model": "text_model.model", + r"lm_head\.weight": "text_model.lm_head.weight", + r"^text_model\.model\.vqmodel": "vqmodel", + # rename QKV proj for the VQ-VAE model because we use SiglipAttention + r"\.q\.": ".q_proj.", + r"\.k\.": ".k_proj.", + r"\.v\.": ".v_proj.", + r"\.proj_out\.": ".out_proj.", + # move the attention norms outside of attention modules + r"mid\.attn_1\.norm\.": "mid.attn_norm.", + r"attn\.0\.norm\.": "attn_norms.0.", + r"attn\.1\.norm\.": "attn_norms.1.", + r"attn\.2\.norm\.": "attn_norms.2.", + r"attn\.3\.norm\.": "attn_norms.3.", + # isolate down/mid/up into separate classes for readability + r"\.down\.": ".down_block.down.", + r"\.up\.": ".up_block.up.", + r"\.mid\.": ".middle_block.", +} + + +def convert_state_dict_to_hf(old_state_dict, new_state_dict): + for key, value in old_state_dict.items(): + # convert conv layers in attn to linear + if ( + any(key.endswith(name) for name in ["q.weight", "k.weight", "v.weight", "proj_out.weight"]) + and value.ndim == 4 + ): + value = value.squeeze() + + for old_pattern, new_pattern in KEYS_TO_MODIFY_MAPPING.items(): + key = re.sub(old_pattern, new_pattern, key) + + new_state_dict[key] = value + return new_state_dict + + +def convert_model(vq_model_id, llm_model_id, output_dir, hub_model_id=None, test_inference=False): + os.makedirs(output_dir, exist_ok=True) + + # Convert and save processor + tokenizer_tiktoken = AutoTokenizer.from_pretrained(llm_model_id, trust_remote_code=True) + convert_tiktoken(tokenizer_tiktoken, output_dir) + extra_special_tokens = extra_special_tokens = { + "image_token": "", + "boi_token": "<|image start|>", + "eoi_token": "<|image end|>", + "image_wrapper_token": "<|image token|>", + "eof_token": "<|extra_201|>", + } + tokenizer_converted = AutoTokenizer.from_pretrained(output_dir, extra_special_tokens=extra_special_tokens) + tokenizer_converted.padding_side = "left" + + image_processor = Emu3ImageProcessor.from_pretrained(vq_model_id) + processor = Emu3Processor(image_processor, tokenizer_converted, chat_template=CHAT_TEMPLATE) + processor.save_pretrained(output_dir) + + # load models + model_llm = AutoModelForCausalLM.from_pretrained( + llm_model_id, + trust_remote_code=True, + ) + model_vqgan = AutoModel.from_pretrained(vq_model_id, trust_remote_code=True) + with open(f"{output_dir}/tokenizer.json", "r") as file: + tokenizer_config = json.load(file) + vocabulary_map = tokenizer_config["model"]["vocab"] + + text_config = Emu3TextConfig( + max_position_embeddings=model_llm.config.max_position_embeddings, + rope_scaling={"rope_type": "default"}, + ) + config = Emu3Config(text_config=text_config, vocabulary_map=vocabulary_map) + + with init_empty_weights(): + model = Emu3ForConditionalGeneration(config=config) + model.generation_config = GenerationConfig( + do_sample=True, + top_k=2048, + max_new_tokens=50_000, + pad_token_id=processor.tokenizer.pad_token_id, + eos_token_id=processor.tokenizer.eos_token_id, + ) + + state_dict = {} + state_dict = convert_state_dict_to_hf(model_llm.state_dict(), state_dict) + state_dict = convert_state_dict_to_hf(model_vqgan.state_dict(), state_dict) + + model.load_state_dict(state_dict, assign=True, strict=True) + model.save_pretrained(output_dir, safe_serialization=True) + + if hub_model_id is not None: + model.push_to_hub(hub_model_id) + processor.push_to_hub(hub_model_id) + + if test_inference and llm_model_id.endswith("Chat"): + # Short inference on a few examples to check if generation makes sense + print("Loading the checkpoint in a Emu3 model...") + print("*" * 100) + model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto") + processor = Emu3Processor.from_pretrained(output_dir) + + conversation = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Please tell me about this art work and its artist."}, + {"type": "image"}, + ], + }, + ] + prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + + image = Image.open( + requests.get( + "https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True + ).raw + ) + inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.bfloat16) + length = inputs.input_ids.shape[1] + + out = model.generate(**inputs, max_new_tokens=40, do_sample=False) + generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0] + + print(f"Generation for single-image: {generated_text}") + print("*" * 100) + elif test_inference and llm_model_id.endswith("Gen"): + processor = Emu3Processor.from_pretrained(output_dir) + model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto") + + inputs = processor( + text=[ + "a portrait of young girl. masterpiece, film grained, best quality.", + "a dog running under the rain", + ], + padding=True, + return_tensors="pt", + return_for_image_generation=True, + ) + inputs = inputs.to(device="cuda:0", dtype=torch.bfloat16) + + neg_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry." + neg_inputs = processor(text=[neg_prompt] * 2, return_tensors="pt").to(device="cuda:0") + + image_sizes = inputs.pop("image_sizes") + HEIGHT, WIDTH = image_sizes[0] + VISUAL_TOKENS = model.vocabulary_mapping.image_tokens + + def prefix_allowed_tokens_fn(batch_id, input_ids): + height, width = HEIGHT, WIDTH + visual_tokens = VISUAL_TOKENS + image_token_id = processor.tokenizer.encode("<|image token|>", return_tensors="pt")[0].to(model.device) + eoi_token_id = processor.tokenizer.encode("<|image end|>", return_tensors="pt")[0] + eos_token_id = processor.tokenizer.encode("<|extra_204|>", return_tensors="pt")[0] + pad_token_id = processor.tokenizer.encode("<|endoftext|>", return_tensors="pt")[0] + eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0] + eof_token_id = processor.tokenizer.encode("<|extra_201|>", return_tensors="pt")[0] + + position = torch.nonzero(input_ids == image_token_id, as_tuple=True)[0][0] + offset = input_ids.shape[0] - position + if offset % (width + 1) == 0: + return (eol_token_id,) + elif offset == (width + 1) * height + 1: + return (eof_token_id,) + elif offset == (width + 1) * height + 2: + return (eoi_token_id,) + elif offset == (width + 1) * height + 3: + return (eos_token_id,) + elif offset > (width + 1) * height + 3: + return (pad_token_id,) + else: + return visual_tokens + + out = model.generate( + **inputs, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + negative_prompt_ids=neg_inputs.input_ids, + negative_prompt_attention_mask=neg_inputs.attention_mask, + ) + + image = model.decode_image_tokens(out[:, inputs.input_ids.shape[1] :], height=HEIGHT, width=WIDTH) + images = processor.postprocess( + list(image.float()), return_tensors="PIL.Image.Image" + ) # internally we convert to np but it's not supported in bf16 precision + for i, image in enumerate(images["pixel_values"]): + image.save(f"result_{i}.png") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--vq_model_id", + help="Model ID of Emu3 VQ-VAE on the hub", + default="BAAI/Emu3-VisionTokenizer", + ) + parser.add_argument( + "--llm_model_id", + help="Model ID of Emu3 bacbone LLM on the hub", + default="BAAI/Emu3-Chat", + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model", + ) + parser.add_argument( + "--hub_model_id", + help="Model ID in the hub where to push the model.", + ) + parser.add_argument( + "--test_inference", + action="store_true", + help="Whether to load the model for generation to test it's converted correctly.", + ) + args = parser.parse_args() + convert_model( + vq_model_id=args.vq_model_id, + llm_model_id=args.llm_model_id, + output_dir=args.output_dir, + hub_model_id=args.hub_model_id, + test_inference=args.test_inference, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/emu3/image_processing_emu3.py b/src/transformers/models/emu3/image_processing_emu3.py new file mode 100644 index 00000000000000..f28bc501ba169c --- /dev/null +++ b/src/transformers/models/emu3/image_processing_emu3.py @@ -0,0 +1,552 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, Iterable, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + VideoInput, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + is_valid_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +if is_vision_available(): + from PIL import Image + +logger = logging.get_logger(__name__) + + +def make_batched_images(images) -> List[List[ImageInput]]: + """ + Accepts images in list or nested list format, and makes a list of images for preprocessing. + + Args: + images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): + The input image. + + Returns: + list: A list of images. + """ + if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): + return [img for img_list in images for img in img_list] + + elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + return images + + elif is_valid_image(images): + return [images] + + raise ValueError(f"Could not make batched images from {images}") + + +def smart_resize( + height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280 +): + """Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + + """ + if height < factor or width < factor: + raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") + elif max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +class Emu3ImageProcessor(BaseImageProcessor): + r""" + Constructs a Emu3 image processor that dynamically resizes images based on the original images. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + min_pixels (`int`, *optional*, defaults to `512 * 512`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `1024 * 1024`): + The max pixels of the image to resize the image. + spatial_factor (`int`, *optional*, defaults to 8): + The spatial downsample factor the image will be downsampled in feature extracting phase + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + do_pad: bool = True, + min_pixels: int = 512 * 512, + max_pixels: int = 1024 * 1024, + spatial_factor: int = 8, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.spatial_factor = spatial_factor + self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} + self.do_convert_rgb = do_convert_rgb + + def _preprocess( + self, + images: Union[ImageInput, VideoInput], + do_resize: bool = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`. + vision_info (`List[Dict]`, *optional*): + Optional list of dictionaries containing additional information about vision inputs. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + images = make_list_of_images(images) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + height, width = get_image_size(images[0], channel_dim=input_data_format) + resized_height, resized_width = height, width + processed_images = [] + for image in images: + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.spatial_factor, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = resize( + image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + images = np.array(processed_images) + return images + + def _pad_for_batching( + self, + pixel_values: List[np.ndarray], + image_sizes: List[List[int]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. + + Args: + pixel_values (`List[np.ndarray]`): + An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`) + image_sizes (`List[List[int]]`): + A list of sizes for each image in `pixel_values` in (height, width) format. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + List[`np.ndarray`]: The padded images. + """ + + max_shape = ( + max([size[0] for size in image_sizes]), + max([size[1] for size in image_sizes]), + ) + pixel_values = [ + pad( + image, + padding=((0, max_shape[0] - size[0]), (0, max_shape[1] - size[1])), + data_format=data_format, + input_data_format=input_data_format, + ) + for image, size in zip(pixel_values, image_sizes) + ] + return pixel_values + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + do_pad: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pad = do_pad if do_pad is not None else self.do_pad + + if images is not None: + images = make_batched_images(images) + + if images is not None and not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + pixel_values = [] + for image in images: + image = self._preprocess( + image, + do_resize=do_resize, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + ) + pixel_values.extend(image) + + image_sizes = [image.shape[-2:] for image in pixel_values] + if do_pad: + pixel_values = self._pad_for_batching(pixel_values, image_sizes) + pixel_values = np.array(pixel_values) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_sizes": image_sizes}, tensor_type=return_tensors + ) + + def postprocess( + self, + images: ImageInput, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Union[str, TensorType] = "PIL.Image.Image", + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Postprocess an image or batch of images tensor. Postprocess is the reverse process of preprocess. + The parameters should be same as in preprocess. + Args: + images (`ImageInput`): + Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = 1.0 / self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_list_of_images(images) + if isinstance(images[0], Image.Image): + return images if len(images) > 1 else images[0] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + pixel_values = [] + for image in images: + image = to_numpy_array(image) + if do_normalize: + image = self.unnormalize( + image=image, image_mean=image_mean, image_std=image_std, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + image = image.clip(0, 255).astype(np.uint8) + + if do_normalize and do_rescale and return_tensors == "PIL.Image.Image": + image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format) + pixel_values.append(Image.fromarray(image)) + else: + pixel_values.extend(image) + + data = {"pixel_values": pixel_values} + return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None + + return BatchFeature(data=data, tensor_type=return_tensors) + + def unnormalize( + self, + image: np.array, + image_mean: Union[float, Iterable[float]], + image_std: Union[float, Iterable[float]], + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.array: + """ + Unnormalizes `image` using the mean and standard deviation specified by `mean` and `std`. + image = (image * image_std) + image_mean + Args: + image (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`): + Batch of pixel values to postprocess. + image_mean (`float` or `Iterable[float]`): + The mean to use for unnormalization. + image_std (`float` or `Iterable[float]`): + The standard deviation to use for unnormalization. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + num_channels = 3 + + if isinstance(image_mean, Iterable): + if len(image_mean) != num_channels: + raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(image_mean)}") + else: + image_mean = [image_mean] * num_channels + + if isinstance(image_std, Iterable): + if len(image_std) != num_channels: + raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(image_std)}") + else: + image_std = [image_std] * num_channels + + rev_image_mean = tuple(-mean / std for mean, std in zip(image_mean, image_std)) + rev_image_std = tuple(1 / std for std in image_std) + image = self.normalize( + image=image, mean=rev_image_mean, std=rev_image_std, input_data_format=input_data_format + ) + return image + + +__all__ = ["Emu3ImageProcessor"] diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py new file mode 100644 index 00000000000000..1ee883aa406d64 --- /dev/null +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -0,0 +1,1949 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/emu3/modular_emu3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_emu3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import cached_property +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig + + +logger = logging.get_logger(__name__) + + +_CONFIG_FOR_DOC = "Emu3Config" + + +class Emu3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Emu3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Emu3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Emu3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Emu3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Emu3DecoderLayer(nn.Module): + def __init__(self, config: Emu3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Emu3Attention(config=config, layer_idx=layer_idx) + + self.mlp = Emu3MLP(config) + self.input_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.dropout = nn.Dropout(config.attention_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.dropout(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Emu3VQVAEVectorQuantizer(nn.Module): + """ + A module for vector quantization using learned embedding vectors. + + This module implements the quantization process similar to te one described in + the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous + input vectors into discrete codebook vectors, which are learned during training. + Current implementation improves over previous ones by avoiding costly matrix multiplications + and allowing for post-hoc remapping of indices. + """ + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + self.embedding = nn.Embedding(config.codebook_size, config.embed_dim) + self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size) + + def forward(self, hidden_state: torch.Tensor): + batch_size, temporal, channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous() + hidden_state_flattened = hidden_state.view(-1, channels) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + embedding_sum = torch.sum(self.embedding.weight**2, dim=1) + + # "bd,dn->bn", + distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1)) + distances = hidden_state_sum + embedding_sum - distances + + min_encoding_indices = torch.argmin(distances, dim=1) + min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width) + return min_encoding_indices + + +class Emu3VQVAEEncoderConvDownsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, hidden_states): + # no asymmetric padding in torch conv, must do it ourselves + hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAEEncoderConvUpsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_states): + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAEConv3d(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: Tuple[int], + stride: Tuple[int], + ): + super().__init__() + + padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])] + self.padding = () + for pad_size in padding_sizes[::-1]: + self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2) + self.padding += (2, 0) + + self.conv = nn.Conv3d( + in_channel, + out_channel, + kernel_size, + stride=stride, + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = F.pad(hidden_states, self.padding) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAESpatialNorm(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm( + num_channels=out_channels, + num_groups=32, + eps=1e-6, + affine=True, + ) + + self.conv_y = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.conv_b = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest") + hidden_states = self.norm_layer(hidden_states) + hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states) + return hidden_states + + +class Emu3VQVAETemporalUpsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + batch_size, channels, temporal, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal) + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous() + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalDownsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(4, 3, 3), + stride=(2, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.BatchNorm3d(in_channels) + self.conv1 = Emu3VQVAEConv3d( + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + self.norm2 = nn.BatchNorm3d(out_channels) + self.conv2 = Emu3VQVAEConv3d( + out_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + quant_channels: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.quant_channels = quant_channels + + if quant_channels is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels) + self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels) + + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None): + norm_args = () if self.quant_channels is None else (quant_channels,) + + residual = hidden_states + hidden_states = self.norm1(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEAttentionBlock(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class Emu3VQVAEGroupNorm(nn.GroupNorm): + """ + Same as the torch GroupNorm with the only difference that this ones accepts + an optional kwarg `quant_states` which is not used. This class makes it easier to + use SpatialNorm or GroupNorm without conditionals + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, input, quant_states=None): + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + +class Emu3VQVAEMiddleBlock(nn.Module): + def __init__(self, config, in_channels, quant_channels=None): + super().__init__() + + self.block_1 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + self.attn_1 = Emu3VQVAEAttentionBlock(config) + if quant_channels is None: + self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels) + + self.block_2 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor = None): + hidden_states = self.block_1(hidden_states, quant_states) + residual = hidden_states + hidden_states = self.attn_norm(hidden_states, quant_states) + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = self.attn_1(hidden_states)[0] + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + hidden_states = self.block_2(hidden_states, quant_states) + return hidden_states + + +class Emu3VQVAEDownBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + channel_multiplier = config.channel_multiplier + + in_channel_multiplier = (1,) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if config.attn_resolutions is not None and i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)) + + down = nn.Module() + down.block = block + down.attn = attn + down.attn_norms = attn_norms + if i_level != self.num_resolutions - 1: + down.downsample = Emu3VQVAEEncoderConvDownsample(block_in) + self.down.append(down) + + def forward(self, hidden_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.down): + for i_block in range(self.num_res_blocks): + hidden_states = blocks.block[i_block](hidden_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + + if i_level != self.num_resolutions - 1: + hidden_states = blocks.downsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEUpBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_out = config.base_channels * config.channel_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + quant_channels=quant_channels, + ) + ) + block_in = block_out + if i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in)) + + up = nn.Module() + up.block = block + up.attn = attn + up.attn_norms = attn_norms + if i_level != 0: + up.upsample = Emu3VQVAEEncoderConvUpsample(block_in) + + self.up.insert(0, up) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.up[::-1]): + for i_block in range(self.num_res_blocks + 1): + hidden_states = blocks.block[i_block](hidden_states, quant_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + if i_level != len(self.up) - 1: + hidden_states = blocks.upsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + base_channels = config.base_channels + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + out_channels = 2 * latent_channels if double_latent else latent_channels + block_in = base_channels * channel_multiplier[-1] + + self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) + self.down_block = Emu3VQVAEDownBlock(config) + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + temporal_down_blocks = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + self.time_res_stack = nn.ModuleList() + + for i in range(temporal_down_blocks): + conv = Emu3VQVAETemporalDownsample(out_channels, out_channels) + self.time_conv.append(conv) + + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + ) + self.time_res_stack.append(time_res_conv) + + def forward(self, pixel_values: torch.LongTensor): + temporal_dim = pixel_values.shape[1] + pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:]) + + # downsampling & middle + hidden_states = self.conv_in(pixel_values) + hidden_states = self.down_block(hidden_states) + hidden_states = self.middle_block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:]) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for conv in self.time_conv: + hidden_states = conv(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + + for layer in self.time_res_stack: + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + return hidden_states + + +class Emu3VQVAEDecoder(nn.Module): + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + self.time_res_stack = nn.ModuleList() + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=config.latent_channels, out_channels=config.latent_channels + ) + self.time_res_stack.append(time_res_conv) + + temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + for i in range(temp_upsample_block_num): + conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels) + self.time_conv.append(conv) + + self.conv_in = nn.Conv2d( + config.latent_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + ) + + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels) + self.up_block = Emu3VQVAEUpBlock(config) + + block_in = config.base_channels * config.channel_multiplier[0] + self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in) + self.conv_out = nn.Conv2d( + block_in, + config.out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0) + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for layer in self.time_res_stack: + hidden_quant_states = layer(hidden_quant_states) + + for layer in self.time_conv: + hidden_quant_states = layer(hidden_quant_states) + hidden_quant_states *= torch.sigmoid(hidden_quant_states) + + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0) + hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:]) + quant_states = quant_states.reshape(-1, *quant_states.shape[2:]) + + hidden_states = self.conv_in(hidden_states) + + # middle & upsampling + hidden_states = self.middle_block(hidden_states, quant_states) + hidden_states = self.up_block(hidden_states, quant_states) + + hidden_states = self.norm_out(hidden_states, quant_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +EMU3_VQ_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Emu3VQVAEConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens. + This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131). + """, + EMU3_VQ_START_DOCSTRING, +) +class Emu3VQVAE(PreTrainedModel): + config_class = Emu3VQVAEConfig + base_model_prefix = "emuvideovq" + main_input_name = "pixel_values" + _no_split_modules = [ + "Emu3VQVAETemporalResnetBlock", + "Emu3VQVAEAttentionBlock", + "Emu3VQVAEResnetBlock", + "Emu3VQVAEVectorQuantizer", + ] + + def _init_weights(self, module): + if isinstance(module, (nn.Conv2d, nn.Conv3d)): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__(config) + + self.config = config + + self.encoder = Emu3VQVAEEncoder(config) + self.decoder = Emu3VQVAEDecoder(config) + self.quantize = Emu3VQVAEVectorQuantizer(config) + self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1) + + self.quant_conv = Emu3VQVAEConv3d( + config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.post_quant_conv = Emu3VQVAEConv3d( + config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1) + self.eval() # Emu3's VQ model is frozen + + self.post_init() + + def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor): + is_image = pixel_values.ndim == 4 + if is_image: + temporal = self.config.temporal_downsample_factor + batch_size, channels, height, width = pixel_values.shape + pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1) + else: + batch_size, temporal, channels, height, width = pixel_values.shape + + hidden_states = self.encoder(pixel_values) + + # b t c h w -> b c t h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + hidden_states = self.quant_conv(hidden_states) + + # b c t h w -> b t c h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + codes = self.quantize(hidden_states) + + image_tokens = codes.squeeze(1) if is_image else codes + + image_tokens = [ + single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)] + for single_image, size in zip(image_tokens, image_sizes) + ] + + return image_tokens + + def decode(self, hidden_states: torch.Tensor): + is_image = hidden_states.ndim == 3 + if is_image: + hidden_states = hidden_states.unsqueeze(1) + + batch_size, temporal, height, width = hidden_states.shape + quant = self.quantize.embedding(hidden_states.flatten()) + + channels = quant.shape[-1] + quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous() + post_quant = self.post_quant_conv(quant) + + quant = quant.permute(0, 2, 1, 3, 4) + post_quant = post_quant.permute(0, 2, 1, 3, 4) + + video = self.decoder(post_quant, quant) + video = video.reshape( + batch_size, + temporal * self.config.temporal_downsample_factor, + self.config.out_channels, + height * self.spatial_scale_factor, + width * self.spatial_scale_factor, + ) + return video[:, 0] if is_image else video + + +class Emu3ImageVocabularyMapping: + """ + A class for mapping discrete image tokens from VQGAN to BPE tokens. + """ + + def __init__(self, vocab_map): + self.vocab_map = vocab_map + self.eol_token_id = vocab_map.get("<|extra_200|>") + self.image_token_id = vocab_map.get("") + + @cached_property + def image_tokens(self): + return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")]) + + @cached_property + def image_tokens_str(self): + return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")]) + + @cached_property + def img2bpe(self): + return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str} + + @cached_property + def bpe2img(self): + return {v: k for k, v in self.img2bpe.items()} + + @cached_property + def bpe2img_mapping_tensor(self): + mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int) + for k, v in self.bpe2img.items(): + mapping[k] = v + return mapping + + @cached_property + def img2bpe_mapping_tensor(self): + mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) + for k, v in self.img2bpe.items(): + mapping[k] = v + return mapping + + def convert_img2bpe(self, img_batch: List[torch.Tensor]) -> torch.Tensor: + device = img_batch.device + eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id + img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] + img_tokens = torch.cat([img_tokens, eol_row], dim=-1) + return img_tokens.to(device) + + def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor: + device = img_batch.device + img_batch = img_batch[..., :-1] # remove last row of EOL tokens + img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")] + return img_tokens.to(device) + + +EMU3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Emu3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare emu3 Model outputting raw hidden-states without any specific head on top.", + EMU3_START_DOCSTRING, +) +class Emu3PreTrainedModel(PreTrainedModel): + config_class = Emu3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [ + "Emu3DecoderLayer", + ] + _skip_keys_device_placement = ["past_key_values", "causal_mask"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_cache_class = True + _supports_static_cache = True + _supports_param_buffer_assignment = False + _supports_flex_attn = True + + def _init_weights(self, module): + std = self.config.get_text_config().initializer_range + if isinstance(module, Emu3VQVAE): + module.apply(module._init_weights) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Emu3RotaryEmbedding(nn.Module): + def __init__(self, config: Emu3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +EMU3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Emu3Text Model outputting raw hidden-states without any specific head on top.", + EMU3_START_DOCSTRING, +) +class Emu3TextModel(Emu3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Emu3TextDecoderLayer`] + + Args: + config: Emu3TextConfig + """ + + def __init__(self, config: Emu3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Emu3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +EMU3_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Has to be an instance of [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + The model will output the same cache type that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + config_class = Emu3TextConfig + + def __init__(self, config): + super().__init__(config) + self.model = Emu3TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForCausalLM.from_pretrained("Emu3-community/Emu3-Chat-hf", torch_dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf") + + >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): + def __init__(self, config): + super().__init__(config) + self.text_model = Emu3ForCausalLM._from_config(config.text_config) + self.vqmodel = Emu3VQVAE(config.vq_config) + self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. + """ + image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes) + bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list] + bpe_tokens = torch.cat(bpe_tokens_list) + return bpe_tokens + + @torch.no_grad + def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int): + """ + Decodes generated image tokens from language model to continuous pixel values + with VQGAN module via upsampling. + + Args: + image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`): + The tensors corresponding to the input images. + height (`int`): + Height of the generated image before upsampling. + width (`int`): + Width of the generated image before upsampling. + """ + sequences = image_tokens[:, :-3].view(-1, height, width + 1) + image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences) + image = self.vqmodel.decode(image_tokens) + return image + + @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForConditionalGeneration.from_pretrained("Emu3-community/Emu3-Chat-hf", torch_dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf") + + >>> conversation = [ + ... { + ... "role": "system", + ... "content": [ + ... {"type": "text", "text": "You are a helpful assistant."}, + ... ], + ... }, + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image"}, + ... {"type": "text", "text": "Please describe the image."}, + ... ], + ... }, + ... ] + + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw) + + >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None: + image_tokens = self.get_image_tokens(pixel_values, image_sizes) + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + return outputs + + +__all__ = ["Emu3ForConditionalGeneration", "Emu3ForCausalLM", "Emu3TextModel", "Emu3PreTrainedModel", "Emu3VQVAE"] diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py new file mode 100644 index 00000000000000..e9b80d5cbb4deb --- /dev/null +++ b/src/transformers/models/emu3/modular_emu3.py @@ -0,0 +1,1270 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import cached_property +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_outputs import ( + CausalLMOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, + replace_return_docstrings, +) +from ..chameleon.modeling_chameleon import ( + ChameleonPreTrainedModel, + ChameleonVQVAEEncoderConvDownsample, +) +from ..llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, +) +from ..siglip.modeling_siglip import SiglipAttention +from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig + + +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +_CONFIG_FOR_DOC = "Emu3Config" +_CHECKPOINT_FOR_DOC = "Emu3-community/Emu3-Chat-hf" + +logger = logging.get_logger(__name__) + + +# Has extra dropout which no other model in the library has +class Emu3DecoderLayer(LlamaDecoderLayer): + def __init__(self, config: Emu3Config, layer_idx: int): + super().__init__(config, layer_idx) + self.dropout = nn.Dropout(config.attention_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.dropout(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Emu3VQVAEVectorQuantizer(nn.Module): + """ + A module for vector quantization using learned embedding vectors. + + This module implements the quantization process similar to te one described in + the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous + input vectors into discrete codebook vectors, which are learned during training. + Current implementation improves over previous ones by avoiding costly matrix multiplications + and allowing for post-hoc remapping of indices. + """ + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + self.embedding = nn.Embedding(config.codebook_size, config.embed_dim) + self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size) + + def forward(self, hidden_state: torch.Tensor): + batch_size, temporal, channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous() + hidden_state_flattened = hidden_state.view(-1, channels) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + embedding_sum = torch.sum(self.embedding.weight**2, dim=1) + + # "bd,dn->bn", + distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1)) + distances = hidden_state_sum + embedding_sum - distances + + min_encoding_indices = torch.argmin(distances, dim=1) + min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width) + return min_encoding_indices + + +class Emu3VQVAEEncoderConvDownsample(ChameleonVQVAEEncoderConvDownsample): + pass + + +class Emu3VQVAEEncoderConvUpsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_states): + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAEConv3d(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: Tuple[int], + stride: Tuple[int], + ): + super().__init__() + + padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])] + self.padding = () + for pad_size in padding_sizes[::-1]: + self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2) + self.padding += (2, 0) + + self.conv = nn.Conv3d( + in_channel, + out_channel, + kernel_size, + stride=stride, + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = F.pad(hidden_states, self.padding) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAESpatialNorm(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm( + num_channels=out_channels, + num_groups=32, + eps=1e-6, + affine=True, + ) + + self.conv_y = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.conv_b = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest") + hidden_states = self.norm_layer(hidden_states) + hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states) + return hidden_states + + +class Emu3VQVAETemporalUpsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + batch_size, channels, temporal, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal) + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous() + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalDownsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(4, 3, 3), + stride=(2, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.BatchNorm3d(in_channels) + self.conv1 = Emu3VQVAEConv3d( + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + self.norm2 = nn.BatchNorm3d(out_channels) + self.conv2 = Emu3VQVAEConv3d( + out_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + quant_channels: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.quant_channels = quant_channels + + if quant_channels is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels) + self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels) + + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None): + norm_args = () if self.quant_channels is None else (quant_channels,) + + residual = hidden_states + hidden_states = self.norm1(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEAttentionBlock(SiglipAttention): + pass + + +class Emu3VQVAEGroupNorm(nn.GroupNorm): + """ + Same as the torch GroupNorm with the only difference that this ones accepts + an optional kwarg `quant_states` which is not used. This class makes it easier to + use SpatialNorm or GroupNorm without conditionals + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, input, quant_states=None): + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + +class Emu3VQVAEMiddleBlock(nn.Module): + def __init__(self, config, in_channels, quant_channels=None): + super().__init__() + + self.block_1 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + self.attn_1 = Emu3VQVAEAttentionBlock(config) + if quant_channels is None: + self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels) + + self.block_2 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor = None): + hidden_states = self.block_1(hidden_states, quant_states) + residual = hidden_states + hidden_states = self.attn_norm(hidden_states, quant_states) + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = self.attn_1(hidden_states)[0] + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + hidden_states = self.block_2(hidden_states, quant_states) + return hidden_states + + +class Emu3VQVAEDownBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + channel_multiplier = config.channel_multiplier + + in_channel_multiplier = (1,) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if config.attn_resolutions is not None and i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)) + + down = nn.Module() + down.block = block + down.attn = attn + down.attn_norms = attn_norms + if i_level != self.num_resolutions - 1: + down.downsample = Emu3VQVAEEncoderConvDownsample(block_in) + self.down.append(down) + + def forward(self, hidden_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.down): + for i_block in range(self.num_res_blocks): + hidden_states = blocks.block[i_block](hidden_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + + if i_level != self.num_resolutions - 1: + hidden_states = blocks.downsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEUpBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_out = config.base_channels * config.channel_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + quant_channels=quant_channels, + ) + ) + block_in = block_out + if i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in)) + + up = nn.Module() + up.block = block + up.attn = attn + up.attn_norms = attn_norms + if i_level != 0: + up.upsample = Emu3VQVAEEncoderConvUpsample(block_in) + + self.up.insert(0, up) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.up[::-1]): + for i_block in range(self.num_res_blocks + 1): + hidden_states = blocks.block[i_block](hidden_states, quant_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + if i_level != len(self.up) - 1: + hidden_states = blocks.upsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + base_channels = config.base_channels + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + out_channels = 2 * latent_channels if double_latent else latent_channels + block_in = base_channels * channel_multiplier[-1] + + self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) + self.down_block = Emu3VQVAEDownBlock(config) + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + temporal_down_blocks = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + self.time_res_stack = nn.ModuleList() + + for i in range(temporal_down_blocks): + conv = Emu3VQVAETemporalDownsample(out_channels, out_channels) + self.time_conv.append(conv) + + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + ) + self.time_res_stack.append(time_res_conv) + + def forward(self, pixel_values: torch.LongTensor): + temporal_dim = pixel_values.shape[1] + pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:]) + + # downsampling & middle + hidden_states = self.conv_in(pixel_values) + hidden_states = self.down_block(hidden_states) + hidden_states = self.middle_block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:]) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for conv in self.time_conv: + hidden_states = conv(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + + for layer in self.time_res_stack: + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + return hidden_states + + +class Emu3VQVAEDecoder(nn.Module): + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + self.time_res_stack = nn.ModuleList() + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=config.latent_channels, out_channels=config.latent_channels + ) + self.time_res_stack.append(time_res_conv) + + temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + for i in range(temp_upsample_block_num): + conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels) + self.time_conv.append(conv) + + self.conv_in = nn.Conv2d( + config.latent_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + ) + + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels) + self.up_block = Emu3VQVAEUpBlock(config) + + block_in = config.base_channels * config.channel_multiplier[0] + self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in) + self.conv_out = nn.Conv2d( + block_in, + config.out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0) + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for layer in self.time_res_stack: + hidden_quant_states = layer(hidden_quant_states) + + for layer in self.time_conv: + hidden_quant_states = layer(hidden_quant_states) + hidden_quant_states *= torch.sigmoid(hidden_quant_states) + + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0) + hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:]) + quant_states = quant_states.reshape(-1, *quant_states.shape[2:]) + + hidden_states = self.conv_in(hidden_states) + + # middle & upsampling + hidden_states = self.middle_block(hidden_states, quant_states) + hidden_states = self.up_block(hidden_states, quant_states) + + hidden_states = self.norm_out(hidden_states, quant_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +EMU3_VQ_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Emu3VQVAEConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens. + This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131). + """, + EMU3_VQ_START_DOCSTRING, +) +class Emu3VQVAE(PreTrainedModel): + config_class = Emu3VQVAEConfig + base_model_prefix = "emuvideovq" + main_input_name = "pixel_values" + _no_split_modules = [ + "Emu3VQVAETemporalResnetBlock", + "Emu3VQVAEAttentionBlock", + "Emu3VQVAEResnetBlock", + "Emu3VQVAEVectorQuantizer", + ] + + def _init_weights(self, module): + if isinstance(module, (nn.Conv2d, nn.Conv3d)): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__(config) + + self.config = config + + self.encoder = Emu3VQVAEEncoder(config) + self.decoder = Emu3VQVAEDecoder(config) + self.quantize = Emu3VQVAEVectorQuantizer(config) + self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1) + + self.quant_conv = Emu3VQVAEConv3d( + config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.post_quant_conv = Emu3VQVAEConv3d( + config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1) + self.eval() # Emu3's VQ model is frozen + + self.post_init() + + def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor): + is_image = pixel_values.ndim == 4 + if is_image: + temporal = self.config.temporal_downsample_factor + batch_size, channels, height, width = pixel_values.shape + pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1) + else: + batch_size, temporal, channels, height, width = pixel_values.shape + + hidden_states = self.encoder(pixel_values) + + # b t c h w -> b c t h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + hidden_states = self.quant_conv(hidden_states) + + # b c t h w -> b t c h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + codes = self.quantize(hidden_states) + + image_tokens = codes.squeeze(1) if is_image else codes + + image_tokens = [ + single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)] + for single_image, size in zip(image_tokens, image_sizes) + ] + + return image_tokens + + def decode(self, hidden_states: torch.Tensor): + is_image = hidden_states.ndim == 3 + if is_image: + hidden_states = hidden_states.unsqueeze(1) + + batch_size, temporal, height, width = hidden_states.shape + quant = self.quantize.embedding(hidden_states.flatten()) + + channels = quant.shape[-1] + quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous() + post_quant = self.post_quant_conv(quant) + + quant = quant.permute(0, 2, 1, 3, 4) + post_quant = post_quant.permute(0, 2, 1, 3, 4) + + video = self.decoder(post_quant, quant) + video = video.reshape( + batch_size, + temporal * self.config.temporal_downsample_factor, + self.config.out_channels, + height * self.spatial_scale_factor, + width * self.spatial_scale_factor, + ) + return video[:, 0] if is_image else video + + +class Emu3ImageVocabularyMapping: + """ + A class for mapping discrete image tokens from VQGAN to BPE tokens. + """ + + def __init__(self, vocab_map): + self.vocab_map = vocab_map + self.eol_token_id = vocab_map.get("<|extra_200|>") + self.image_token_id = vocab_map.get("") + + @cached_property + def image_tokens(self): + return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")]) + + @cached_property + def image_tokens_str(self): + return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")]) + + @cached_property + def img2bpe(self): + return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str} + + @cached_property + def bpe2img(self): + return {v: k for k, v in self.img2bpe.items()} + + @cached_property + def bpe2img_mapping_tensor(self): + mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int) + for k, v in self.bpe2img.items(): + mapping[k] = v + return mapping + + @cached_property + def img2bpe_mapping_tensor(self): + mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) + for k, v in self.img2bpe.items(): + mapping[k] = v + return mapping + + def convert_img2bpe(self, img_batch: List[torch.Tensor]) -> torch.Tensor: + device = img_batch.device + eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id + img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] + img_tokens = torch.cat([img_tokens, eol_row], dim=-1) + return img_tokens.to(device) + + def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor: + device = img_batch.device + img_batch = img_batch[..., :-1] # remove last row of EOL tokens + img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")] + return img_tokens.to(device) + + +class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE): + _no_split_modules = [ + "Emu3DecoderLayer", + ] + _supports_flex_attn = True + + def _init_weights(self, module): + std = self.config.get_text_config().initializer_range + if isinstance(module, Emu3VQVAE): + module.apply(module._init_weights) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +EMU3_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Has to be an instance of [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + The model will output the same cache type that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +EMU3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Has to be an instance of [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +class Emu3TextModel(LlamaModel, Emu3PreTrainedModel): + def __init__(self, config: Emu3Config): + super().__init__(config) + self.layers = nn.ModuleList( + [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + +class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin): + config_class = Emu3TextConfig + + def __init__(self, config): + super().__init__(config) + self.model = Emu3TextModel(config) + + @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") + def forward(**super_kwargs): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForCausalLM.from_pretrained("Emu3-community/Emu3-Chat-hf", torch_dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf") + + >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + super().forward() + + +class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): + def __init__(self, config): + super().__init__(config) + self.text_model = Emu3ForCausalLM._from_config(config.text_config) + self.vqmodel = Emu3VQVAE(config.vq_config) + self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. + """ + image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes) + bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list] + bpe_tokens = torch.cat(bpe_tokens_list) + return bpe_tokens + + @torch.no_grad + def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int): + """ + Decodes generated image tokens from language model to continuous pixel values + with VQGAN module via upsampling. + + Args: + image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`): + The tensors corresponding to the input images. + height (`int`): + Height of the generated image before upsampling. + width (`int`): + Width of the generated image before upsampling. + """ + sequences = image_tokens[:, :-3].view(-1, height, width + 1) + image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences) + image = self.vqmodel.decode(image_tokens) + return image + + @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + image_sizes: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForConditionalGeneration.from_pretrained("Emu3-community/Emu3-Chat-hf", torch_dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf") + + >>> conversation = [ + ... { + ... "role": "system", + ... "content": [ + ... {"type": "text", "text": "You are a helpful assistant."}, + ... ], + ... }, + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image"}, + ... {"type": "text", "text": "Please describe the image."}, + ... ], + ... }, + ... ] + + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw) + + >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None: + image_tokens = self.get_image_tokens(pixel_values, image_sizes) + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + return outputs + + +__all__ = [ + "Emu3ForConditionalGeneration", + "Emu3ForCausalLM", + "Emu3TextModel", + "Emu3PreTrainedModel", + "Emu3VQVAE", +] diff --git a/src/transformers/models/emu3/processing_emu3.py b/src/transformers/models/emu3/processing_emu3.py new file mode 100644 index 00000000000000..2c536f5f24636f --- /dev/null +++ b/src/transformers/models/emu3/processing_emu3.py @@ -0,0 +1,217 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class Emu3TextKwargs(TextKwargs, total=False): + return_for_image_generation: bool + + +class Emu3ImagesKwargs(ImagesKwargs, total=False): + ratio: str + image_area: int + + +class Emu3ProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: Emu3TextKwargs + images_kwargs: Emu3ImagesKwargs + _defaults = { + "text_kwargs": { + "return_for_image_generation": False, + }, + "images_kwargs": { + "ratio": "1:1", + "image_area": 518400, + }, + } + + +class Emu3Processor(ProcessorMixin): + r""" + Constructs a Emu3 processor which wraps a Emu3 image processor and a GPT2 tokenizer into a single + processor. + + [`Emu3Processor`] offers all the functionalities of [`Emu3ImageProcessor`] and [`GPT2TokenizerFast`]. + See the [`~Emu3Processor.__call__`] and [`~Emu3Processor.decode`] for more information. + + Args: + image_processor ([`Emu3ImageProcessor`]): + The image processor is a required input. + tokenizer ([`Emu3TokenizerFast`]): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast") + image_processor_class = "Emu3ImageProcessor" + + def __init__( + self, + image_processor, + tokenizer, + chat_template=None, + **kwargs, + ): + self.image_token = tokenizer.image_token # image_token as placeholder to be replaced by vq-vae tokens + self.image_start_token = tokenizer.boi_token # "<|image start|>" fixed tokens for start and end of image + self.image_end_token = tokenizer.eoi_token # "<|image end|>" + self.fake_token_around_image = tokenizer.image_wrapper_token # "<|image token|>" every image starts with it + self.eof_token = tokenizer.eof_token # "<|extra_201|>" + self.bos_token = tokenizer.bos_token + self.downsample_ratio = 8 + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[Emu3ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Emu3TokenizerFast's [`~Emu3TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + # check if images and text inputs are reversed for BC + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise TypeError("Invalid input text. Please provide a string, or a list of strings") + + output_kwargs = self._merge_kwargs( + Emu3ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + return_for_image_generation = output_kwargs["text_kwargs"].pop("return_for_image_generation", False) + ratio = output_kwargs["images_kwargs"].pop("ratio", None) + image_area = output_kwargs["images_kwargs"].pop("image_area", None) + + if return_for_image_generation and images is not None: + raise ValueError("You should not provide `images` when `return_for_image_generation=True`") + + if not return_for_image_generation and text is None and images is None: + raise ValueError("You must provide either text or images when `return_for_image_generation=False`") + + image_features = {} + image_start_tokens = f"{self.image_start_token}" + image_end_tokens = f"{self.eof_token}{self.image_end_token}" + + # generate text from image + text input, so we add placeholders for image tokens + if not return_for_image_generation and images is not None: + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + image_sizes = iter(image_features.image_sizes) + + prompt_strings = [] + for sample in text: + while self.image_token in sample: + image_size = next(image_sizes) + height, width = image_size + height = height // self.downsample_ratio + width = width // self.downsample_ratio + image_seq_length = height * (width + 1) # +1 for extra row when converting to BPE in modeling code + + image_placeholder = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}{'' * image_seq_length}{image_end_tokens}" + sample = sample.replace(self.image_token, image_placeholder, 1) + sample = f"{self.bos_token}{sample}" # add BOS because PT tokenizer doesn't add it + prompt_strings.append(sample) + text = [sample.replace("", self.image_token) for sample in prompt_strings] + + # generate image from text input, so we add begin-of-image tokens from where image generation starts + elif return_for_image_generation: + height, width = self.calculate_generate_size(ratio, image_area, self.downsample_ratio) + image_prompt = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}" + text = [f"{self.bos_token}{sample}{image_prompt}" for sample in text] + image_features["image_sizes"] = [[height, width]] * len(text) + + # else just generate from text-only input, and we do no special treatment for text + data = self.tokenizer(text, **output_kwargs["text_kwargs"]) + data.update(**image_features) + + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"]) + + def calculate_generate_size(self, ratio, image_area, spatial_factor): + width, height = map(int, ratio.split(":")) + current_area = width * height + target_ratio = (image_area / current_area) ** 0.5 + + token_height = int(round(height * target_ratio / spatial_factor)) + token_width = int(round(width * target_ratio / spatial_factor)) + return token_height, token_width + + def postprocess(self, images: ImageInput, **kwargs): + return self.image_processor.postprocess(images, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Emu3TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Emu3TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["Emu3Processor"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 8c05d6093f521d..6a9cd232eb35dd 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3933,6 +3933,41 @@ def load_tf_weights_in_electra(*args, **kwargs): requires_backends(load_tf_weights_in_electra, ["torch"]) +class Emu3ForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Emu3ForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Emu3PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Emu3TextModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Emu3VQVAE(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class EncodecModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index e9e87a4b3d44b7..2fcc7f172054f9 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -233,6 +233,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class Emu3ImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class FlavaFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c19e0cc4fbd795..ad2aeacbfebd15 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1626,7 +1626,7 @@ def test_generate_from_inputs_embeds(self, _, num_beams): # checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images` pixel_values_is_mutually_exclusive = any( model_name in model_class.__name__.lower() - for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma"] + for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"] ) if pixel_values_is_mutually_exclusive: inputs_dict.pop("pixel_values", None) @@ -1700,6 +1700,18 @@ def test_generate_from_inputs_embeds_with_static_cache(self): if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): self.skipTest(reason="This model does not support `inputs_embeds` in generation") + # Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the + # exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the + # checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images` + pixel_values_is_mutually_exclusive = any( + model_name in model_class.__name__.lower() + for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"] + ) + if pixel_values_is_mutually_exclusive: + inputs_dict.pop("pixel_values", None) + inputs_dict.pop("pixel_values_videos", None) + inputs_dict.pop("pixel_values_images", None) + input_ids = inputs_dict.pop("input_ids") model.config.use_cache = True @@ -1941,6 +1953,10 @@ def test_generate_with_static_cache(self): for dtype in (torch.float32, torch.float16): model = model_class(config).to(torch_device).to(dtype).eval() + inputs_dict = { + k: v.to(dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v + for k, v in inputs_dict.items() + } set_model_for_less_flaky_test(model) generation_kwargs = { diff --git a/tests/models/emu3/__init__.py b/tests/models/emu3/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py new file mode 100644 index 00000000000000..d1c4501c5e8bd0 --- /dev/null +++ b/tests/models/emu3/test_modeling_emu3.py @@ -0,0 +1,550 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch emu3 model.""" + +import unittest + +import numpy as np +import requests +from huggingface_hub import hf_hub_download +from parameterized import parameterized + +from transformers import Emu3Config, Emu3TextConfig, is_torch_available, is_vision_available, set_seed +from transformers.testing_utils import ( + require_bitsandbytes, + require_torch, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_vision_available(): + from PIL import Image + +if is_torch_available(): + import torch + + from transformers import ( + Emu3ForCausalLM, + Emu3ForConditionalGeneration, + Emu3Processor, + Emu3TextModel, + ) + + +class Emu3Text2TextModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=False, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + intermediate_size=37, + max_position_embeddings=512, + initializer_range=0.02, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = input_ids.ne(1).to(torch_device) + + config = self.get_config() + + return config, input_ids, attention_mask + + def get_config(self): + return Emu3TextConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + intermediate_size=self.intermediate_size, + max_position_embeddings=self.max_position_embeddings, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask} + return config, inputs_dict + + +@require_torch +class Emu3Text2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Emu3ForCausalLM,) if is_torch_available() else () + all_generative_model_classes = (Emu3ForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "text-generation": Emu3ForCausalLM, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False + + def setUp(self): + self.model_tester = Emu3Text2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=Emu3TextConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + @parameterized.expand([("linear",), ("dynamic",)]) + def test_model_rope_scaling(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = Emu3TextModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = Emu3TextModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + + @unittest.skip("Doesn't work, tensors are not almost same") # TODO raushan fixme + def test_custom_4d_attention_mask(self): + pass + + @unittest.skip("Fails with unknown error only on end-to-end compile") # TODO raushan fixme + def test_generate_compile_1_end_to_end(self): + pass + + +class Emu3Vision2TextModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=False, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + intermediate_size=37, + max_position_embeddings=512, + initializer_range=0.02, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + image_token_id=3, + image_size=30, + codebook_size=20, + temporal_downsample_factor=1, + base_channels=32, + vq_channel_multiplier=[1, 1], + image_seq_length=100, + vq_img_token_start_id=3, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.image_token_id = image_token_id + self.image_size = image_size + self.codebook_size = codebook_size + self.temporal_downsample_factor = temporal_downsample_factor + self.vq_channel_multiplier = vq_channel_multiplier + self.vq_img_token_start_id = vq_img_token_start_id + self.base_channels = base_channels + self.seq_length = seq_length + image_seq_length + self.image_seq_length = image_seq_length + + def prepare_config_and_inputs(self): + config = self.get_config() + + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size) + attention_mask = input_ids.ne(1).to(torch_device) + input_ids[input_ids == self.image_token_id] = self.pad_token_id + input_ids[:, : self.image_seq_length] = self.image_token_id + + pixel_values = floats_tensor( + [ + self.batch_size, + 3, + self.image_size, + self.image_size, + ] + ) + image_sizes = [[self.image_size, self.image_size]] * self.batch_size + image_sizes = torch.tensor(image_sizes, device=torch_device, dtype=torch.int64) + + return config, input_ids, attention_mask, pixel_values, image_sizes + + def get_config(self): + # create dummy vocab map for image2bpe mapping if it needs remapping + # we assume that vocab size is big enough to account for `codebook_size` amount of + # image tokens somewhere at the beginning of total vocab size + + vocab_map = {i: chr(i) for i in range(self.vocab_size)} + start = self.vq_img_token_start_id + end = self.vq_img_token_start_id + self.codebook_size + for i in range(start, end): + # dummy str for each token, anything that fits pattern "<|visual token XXXXXX|>" + vocab_map[i] = f"<|visual token{i:06d}|>" + + # add tokens that have to be in the vocab, we'll retrieve their ids later in modeling code + vocab_map[self.image_token_id] = "" + vocab_map[self.image_token_id + 1] = "<|extra_200|>" + vocab_map = {v: k for k, v in vocab_map.items()} + + text_config = Emu3TextConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + intermediate_size=self.intermediate_size, + max_position_embeddings=self.max_position_embeddings, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + ) + + vq_config = { + "codebook_size": self.codebook_size, + "temporal_downsample_factor": self.temporal_downsample_factor, + "base_channels": self.base_channels, + "channel_multiplier": self.vq_channel_multiplier, + "hidden_size": self.base_channels, + } + return Emu3Config(text_config=text_config, vq_config=vq_config, vocabulary_map=vocab_map) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + pixel_values, + image_sizes, + ) = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_sizes": image_sizes, + } + return config, inputs_dict + + +@require_torch +class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Emu3ForConditionalGeneration,) if is_torch_available() else () + all_generative_model_classes = (Emu3ForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {} + test_headmasking = False + test_pruning = False + fx_compatible = False + + def setUp(self): + self.model_tester = Emu3Vision2TextModelTester(self) + self.config_tester = ConfigTester( + self, config_class=Emu3Config, has_text_modality=False, common_properties=["vocabulary_map"] + ) + + def test_config(self): + self.config_tester.run_common_tests() + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + wte = model.get_input_embeddings() + inputs["inputs_embeds"] = wte(input_ids) + + with torch.no_grad(): + model(**inputs) + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + # while some other models require pixel_values to be present + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + inputs_embeds = model.get_input_embeddings()(input_ids) + + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + self.assertTrue(torch.allclose(out_embeds, out_ids)) + + @unittest.skip( + "Emu3 has a VQ module that uses `weight.data` directly in forward which prevent offloding on that module" + ) + def test_disk_offload_safetensors(self): + pass + + @unittest.skip( + "Emu3 has a VQ module that uses `weight.data` directly in forward which prevent offloding on that module" + ) + def test_disk_offload_bin(self): + pass + + @unittest.skip( + "Emu3 has a VQ module that uses `weight.data` directly in forward which prevent offloding on that module" + ) + def test_cpu_offload(self): + pass + + @unittest.skip("Doesn't work, tensors are not almost same") # TODO raushan fixme + def test_custom_4d_attention_mask(self): + pass + + @unittest.skip("VQ-VAE module doesn't initialize weights properly") + def test_initialization(self): + pass + + @unittest.skip("End-to-end compilation is not supported due to dynamic control in `prepare_inputs_for_generation`") + def test_generate_compile_1_end_to_end(self): + pass + + +@require_torch +class Emu3IntegrationTest(unittest.TestCase): + @slow + @require_bitsandbytes + def test_model_generation(self): + model = Emu3ForConditionalGeneration.from_pretrained( + "Emu3-community/Emu3-Chat-hf", load_in_4bit=True, device_map="auto" + ) + processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf") + + image = Image.open( + requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw + ) + prompt = "USER: Describe what do you see here and tell me about the history behind it? ASSISTANT:" + + inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16) + + # greedy generation outputs + EXPECTED_TEXT_COMPLETION = ['USER: 114*143Describe what do you see here and tell me about the history behind it? ASSISTANT: The image depicts the constellation of Ursa Minor, also known as the Little Bear. This constellation was one of the 24 modern constellations introduced by Charles Messier in 178'] # fmt: skip + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) + text = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_bitsandbytes + def test_model_generation_batched(self): + model = Emu3ForConditionalGeneration.from_pretrained( + "Emu3-community/Emu3-Chat-hf", load_in_4bit=True, device_map="auto" + ) + processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf") + processor.tokenizer.padding_side = "left" + + image = Image.open( + requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw + ) + image_2 = Image.open( + requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw + ) + prompts = [ + "USER: Describe what do you see here and tell me about the history behind it? ASSISTANT:", + "USER: What do you know about the constellation in this image? ASSISTANT:", + ] + + inputs = processor(images=[image, image_2], text=prompts, padding=True, return_tensors="pt").to( + model.device, torch.float16 + ) + + # greedy generation outputs + EXPECTED_TEXT_COMPLETION = [ + 'USER: 114*143Describe what do you see here and tell me about the history behind it? ASSISTANT: The image depicts the constellation of Ursa Minor, also known as the Little Bear. This constellation was one of the 24 modern constellations introduced by Charles Messier in 178', + 'USER: 75*125What do you know about the constellation in this image? ASSISTANT: The image shows a segment of a wire rope, characterized by its consistent pattern and regular twists, indicative of a high-quality, well-made rope. This type of detail suggests careful manufacturing processes and attention to' + ] # fmt: skip + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) + text = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_bitsandbytes + def test_model_generation_multi_image(self): + model = Emu3ForConditionalGeneration.from_pretrained( + "Emu3-community/Emu3-Chat-hf", load_in_4bit=True, device_map="auto" + ) + processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf") + + image = Image.open( + requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw + ) + image_2 = Image.open( + requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw + ) + prompt = "USER: What do these two images have in common? ASSISTANT:" + + inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.float16) + + # greedy generation outputs + EXPECTED_TEXT_COMPLETION = ['USER: 114*14375*125What do these two images have in common? ASSISTANT: The two images both depict a geometric shape - a triangle in the larger image and a line segment in the smaller image. They share a common feature of being created with a series of connected dots, which'] # fmt: skip + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) + text = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_bitsandbytes + def test_model_generate_images(self): + model = Emu3ForConditionalGeneration.from_pretrained( + "Emu3-community/Emu3-Gen-hf", load_in_4bit=True, device_map="auto" + ) + processor = Emu3Processor.from_pretrained("Emu3-community/Emu3-Chat-hf") + + inputs = processor( + text=["a portrait of young girl. masterpiece, film grained, best quality."], + padding=True, + return_tensors="pt", + return_for_image_generation=True, + ).to(model.device) + self.assertTrue(inputs.input_ids.shape[1] == 23) + + image_sizes = inputs.pop("image_sizes") + HEIGHT, WIDTH = image_sizes[0] + VISUAL_TOKENS = model.vocabulary_mapping.image_tokens + + def prefix_allowed_tokens_fn(batch_id, input_ids): + height, width = HEIGHT, WIDTH + visual_tokens = VISUAL_TOKENS + image_wrapper_token_id = torch.tensor([processor.tokenizer.image_wrapper_token_id], device=model.device) + eoi_token_id = torch.tensor([processor.tokenizer.eoi_token_id], device=model.device) + eos_token_id = torch.tensor([processor.tokenizer.eos_token_id], device=model.device) + pad_token_id = torch.tensor([processor.tokenizer.pad_token_id], device=model.device) + eof_token_id = torch.tensor([processor.tokenizer.eof_token_id], device=model.device) + eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0] + + position = torch.nonzero(input_ids == image_wrapper_token_id, as_tuple=True)[0][0] + offset = input_ids.shape[0] - position + if offset % (width + 1) == 0: + return (eol_token_id,) + elif offset == (width + 1) * height + 1: + return (eof_token_id,) + elif offset == (width + 1) * height + 2: + return (eoi_token_id,) + elif offset == (width + 1) * height + 3: + return (eos_token_id,) + elif offset > (width + 1) * height + 3: + return (pad_token_id,) + else: + return visual_tokens + + out = model.generate( + **inputs, + max_new_tokens=50_000, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + do_sample=False, + ) + self.assertTrue(out.shape[1] == 8216) + + image = model.decode_image_tokens(out[:, inputs.input_ids.shape[1] :], height=HEIGHT, width=WIDTH) + images = processor.postprocess(list(image.float()), return_tensors="np") + self.assertTrue(images["pixel_values"].shape == (3, 720, 720)) + self.assertTrue(isinstance(images["pixel_values"], np.ndarray)) + + filepath = hf_hub_download( + repo_id="raushan-testing-hf/images_test", + filename="emu3_generated_pixels.npy", + repo_type="dataset", + ) + original_pixels = np.load(filepath) + self.assertTrue(np.allclose(original_pixels, images["pixel_values"])) diff --git a/tests/models/emu3/test_processor_emu3.py b/tests/models/emu3/test_processor_emu3.py new file mode 100644 index 00000000000000..7bc77075b1a69b --- /dev/null +++ b/tests/models/emu3/test_processor_emu3.py @@ -0,0 +1,85 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch emu3 model.""" + +import tempfile +import unittest + +import numpy as np + +from transformers import Emu3Processor, GPT2TokenizerFast +from transformers.utils import is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from transformers import Emu3ImageProcessor + + +class Emu3ProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = Emu3Processor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = Emu3ImageProcessor() + extra_special_tokens = extra_special_tokens = { + "image_token": "", + "boi_token": "<|image start|>", + "eoi_token": "<|image end|>", + "image_wrapper_token": "<|image token|>", + "eof_token": "<|extra_201|>", + } + tokenizer = GPT2TokenizerFast.from_pretrained( + "openai-community/gpt2", extra_special_tokens=extra_special_tokens + ) + tokenizer.pad_token_id = 0 + tokenizer.sep_token_id = 1 + processor = self.processor_class( + image_processor=image_processor, tokenizer=tokenizer, chat_template="dummy_template" + ) + processor.save_pretrained(self.tmpdirname) + + def test_processor_for_generation(self): + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + + # we don't need an image as input because the model will generate one + input_str = "lower newer" + image_input = self.prepare_image_inputs() + inputs = processor(text=input_str, return_for_image_generation=True, return_tensors="pt") + self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask", "image_sizes"]) + self.assertEqual(inputs[self.text_input_name].shape[-1], 8) + + # when `return_for_image_generation` is set, we raise an error that image should not be provided + with self.assertRaises(ValueError): + inputs = processor( + text=input_str, images=image_input, return_for_image_generation=True, return_tensors="pt" + ) + + def test_processor_postprocess(self): + processor_components = self.prepare_components() + processor = self.processor_class(**processor_components) + + input_str = "lower newer" + orig_image_input = self.prepare_image_inputs() + orig_image = np.array(orig_image_input).transpose(2, 0, 1) + + inputs = processor(text=input_str, images=orig_image, do_resize=False, return_tensors="np") + normalized_image_input = inputs.pixel_values + unnormalized_images = processor.postprocess(normalized_image_input, return_tensors="np")["pixel_values"] + + # For an image where pixels go from 0 to 255 the diff can be 1 due to some numerical precision errors when scaling and unscaling + self.assertTrue(np.abs(orig_image - unnormalized_images).max() >= 1) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c57ef6ed05f166..922d5b82725218 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3877,7 +3877,7 @@ def test_sdpa_can_dispatch_non_composite_models(self): for name, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - raise ValueError("The eager model should not have SDPA attention layers") + raise ValueError(f"The eager model should not have SDPA attention layers but got {class_name}") @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 9be492476c1e2d..6262b1902aabc3 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -266,6 +266,10 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s f"config.{attribute}" in modeling_source or f'getattr(config, "{attribute}"' in modeling_source or f'getattr(self.config, "{attribute}"' in modeling_source + or ( + "TextConfig" in config_class.__name__ + and f"config.get_text_config().{attribute}" in modeling_source + ) ): attribute_used = True # Deal with multi-line cases diff --git a/utils/check_repo.py b/utils/check_repo.py index d20760bcf75eba..7f3e0c66d55ed0 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -139,6 +139,8 @@ "Qwen2VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2VLForConditionalGeneration. "MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests "MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests + "Emu3VQVAE", # Building part of bigger (tested) model + "Emu3TextModel", # Building part of bigger (tested) model ] ) @@ -333,6 +335,8 @@ "VitPoseForPoseEstimation", "CLIPTextModel", "MoshiForConditionalGeneration", # no auto class for speech-to-speech + "Emu3VQVAE", # no autoclass for VQ-VAE models + "Emu3TextModel", # Building part of bigger (tested) model ] # DO NOT edit this list!