diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c4707d5f20a027..2cf35bf425793f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -844,6 +844,8 @@ title: FLAVA - local: model_doc/git title: GIT + - local: model_doc/got_ocr2 + title: GOT-OCR2 - local: model_doc/grounding-dino title: Grounding DINO - local: model_doc/groupvit diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 49c44874e320ef..030eaecd7f4074 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -155,6 +155,7 @@ Flax), PyTorch, and/or TensorFlow. | [GIT](model_doc/git) | ✅ | ❌ | ❌ | | [GLM](model_doc/glm) | ✅ | ❌ | ❌ | | [GLPN](model_doc/glpn) | ✅ | ❌ | ❌ | +| [GOT-OCR2](model_doc/got_ocr2) | ✅ | ❌ | ❌ | | [GPT Neo](model_doc/gpt_neo) | ✅ | ❌ | ✅ | | [GPT NeoX](model_doc/gpt_neox) | ✅ | ❌ | ❌ | | [GPT NeoX Japanese](model_doc/gpt_neox_japanese) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/got_ocr2.md b/docs/source/en/model_doc/got_ocr2.md new file mode 100644 index 00000000000000..1cf1427990dcb4 --- /dev/null +++ b/docs/source/en/model_doc/got_ocr2.md @@ -0,0 +1,268 @@ + + +# GOT-OCR2 + +## Overview + +The GOT-OCR2 model was proposed in [General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model](https://arxiv.org/abs/2409.01704) by Haoran Wei, Chenglong Liu, Jinyue Chen, Jia Wang, Lingyu Kong, Yanming Xu, Zheng Ge, Liang Zhao, Jianjian Sun, Yuang Peng, Chunrui Han, Xiangyu Zhang. + +The abstract from the paper is the following: + +*Traditional OCR systems (OCR-1.0) are increasingly unable to meet people’snusage due to the growing demand for intelligent processing of man-made opticalncharacters. In this paper, we collectively refer to all artificial optical signals (e.g., plain texts, math/molecular formulas, tables, charts, sheet music, and even geometric shapes) as "characters" and propose the General OCR Theory along with an excellent model, namely GOT, to promote the arrival of OCR-2.0. The GOT, with 580M parameters, is a unified, elegant, and end-to-end model, consisting of a high-compression encoder and a long-contexts decoder. As an OCR-2.0 model, GOT can handle all the above "characters" under various OCR tasks. On the input side, the model supports commonly used scene- and document-style images in slice and whole-page styles. On the output side, GOT can generate plain or formatted results (markdown/tikz/smiles/kern) via an easy prompt. Besides, the model enjoys interactive OCR features, i.e., region-level recognition guided by coordinates or colors. Furthermore, we also adapt dynamic resolution and multipage OCR technologies to GOT for better practicality. In experiments, we provide sufficient results to prove the superiority of our model.* + + + + GOT-OCR2 training stages. Taken from the original paper. + + +Tips: + +GOT-OCR2 works on a wide range of tasks, including plain document OCR, scene text OCR, formatted document OCR, and even OCR for tables, charts, mathematical formulas, geometric shapes, molecular formulas and sheet music. While this implementation of the model will only output plain text, the outputs can be further processed to render the desired format, with packages like `pdftex`, `mathpix`, `matplotlib`, `tikz`, `verovio` or `pyecharts`. +The model can also be used for interactive OCR, where the user can specify the region to be recognized by providing the coordinates or the color of the region's bounding box. + +This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan). +The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2.0). + +## Usage example + +### Plain text inference + +```python +>>> from transformers import AutoProcessor, AutoModelForImageTextToText + +>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda") +>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") + +>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg" +>>> inputs = processor(image, return_tensors="pt").to("cuda") + +>>> generate_ids = model.generate( +>>> **inputs, +>>> do_sample=False, +>>> tokenizer=processor.tokenizer, +>>> stop_strings="<|im_end|>", +>>> max_new_tokens=4096, +>>> ) + +>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True) +"R&D QUALITY IMPROVEMENT\nSUGGESTION/SOLUTION FORM\nName/Phone Ext. : (...)" +``` + +### Plain text inference batched + +```python +>>> from transformers import AutoProcessor, AutoModelForImageTextToText + +>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") +>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") + +>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" +>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg" + +>>> inputs = processor([image1, image2], return_tensors="pt") + +>>> generate_ids = model.generate( +>>> **inputs, +>>> do_sample=False, +>>> tokenizer=processor.tokenizer, +>>> stop_strings="<|im_end|>", +>>> max_new_tokens=4, +>>> ) + +>>> processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True) +["Reducing the number", "R&D QUALITY"] +``` + +### Formatted text inference + +GOT-OCR2 can also generate formatted text, such as markdown or LaTeX. Here is an example of how to generate formatted text: + +```python +>>> from transformers import AutoProcessor, AutoModelForImageTextToText + +>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda") +>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") + +>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/latex.png" +>>> inputs = processor(image, return_tensors="pt", format=True).to("cuda") + +>>> generate_ids = model.generate( +>>> **inputs, +>>> do_sample=False, +>>> tokenizer=processor.tokenizer, +>>> stop_strings="<|im_end|>", +>>> max_new_tokens=4096, +>>> ) + +>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True) +"\\author{\nHanwen Jiang* \\(\\quad\\) Arjun Karpur \\({ }^{\\dagger} \\quad\\) Bingyi Cao \\({ }^{\\dagger} \\quad\\) (...)" +``` + +### Inference on multiple pages + +Although it might be reasonable in most cases to use a “for loop” for multi-page processing, some text data with formatting across several pages make it necessary to process all pages at once. GOT introduces a multi-page OCR (without “for loop”) feature, where multiple pages can be processed by the model at once, whith the output being one continuous text. +Here is an example of how to process multiple pages at once: + + +```python +>>> from transformers import AutoProcessor, AutoModelForImageTextToText + +>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda") +>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") + +>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page1.png" +>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page2.png" +>>> inputs = processor([image1, image2], return_tensors="pt", format=True).to("cuda") + +>>> generate_ids = model.generate( +>>> **inputs, +>>> do_sample=False, +>>> tokenizer=processor.tokenizer, +>>> stop_strings="<|im_end|>", +>>> max_new_tokens=4096, +>>> ) + +>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True) +"\\title{\nGeneral OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model\n}\n\\author{\nHaoran Wei (...)" +``` + +### Inference on cropped patches + +GOT supports a 1024×1024 input resolution, which is sufficient for most OCR tasks, such as scene OCR or processing A4-sized PDF pages. However, certain scenarios, like horizontally stitched two-page PDFs commonly found in academic papers or images with unusual aspect ratios, can lead to accuracy issues when processed as a single image. To address this, GOT can dynamically crop an image into patches, process them all at once, and merge the results for better accuracy with such inputs. +Here is an example of how to process cropped patches: + +```python +>>> import torch +>>> from transformers import AutoProcessor, AutoModelForImageTextToText + +>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf", torch_dtype=torch.bfloat16).to("cuda") +>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") + +>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png" +>>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3).to("cuda") + +>>> generate_ids = model.generate( +>>> **inputs, +>>> do_sample=False, +>>> tokenizer=processor.tokenizer, +>>> stop_strings="<|im_end|>", +>>> max_new_tokens=4096, +>>> ) + +>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True) +"on developing architectural improvements to make learnable matching methods generalize.\nMotivated by the above observations, (...)" +``` + +### Inference on a specific region + +GOT supports interactive OCR, where the user can specify the region to be recognized by providing the coordinates or the color of the region's bounding box. Here is an example of how to process a specific region: + +```python +>>> from transformers import AutoProcessor, AutoModelForImageTextToText + +>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda") +>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") + +>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" +>>> inputs = processor(image, return_tensors="pt", color="green") # or box=[x1, y1, x2, y2] for coordinates (image pixels) +>>> inputs = inputs.to("cuda") + +>>> generate_ids = model.generate( +>>> **inputs, +>>> do_sample=False, +>>> tokenizer=processor.tokenizer, +>>> stop_strings="<|im_end|>", +>>> max_new_tokens=4096, +>>> ) + +>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True) +"You should keep in mind what features from the module should be used, especially \nwhen you’re planning to sell a template." +``` + +### Inference on general OCR data example: sheet music + +Although this implementation of the model will only output plain text, the outputs can be further processed to render the desired format, with packages like `pdftex`, `mathpix`, `matplotlib`, `tikz`, `verovio` or `pyecharts`. +Here is an example of how to process sheet music: + +```python +>>> from transformers import AutoProcessor, AutoModelForImageTextToText +>>> import verovio + +>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda") +>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") + +>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/sheet_music.png" +>>> inputs = processor(image, return_tensors="pt", format=True).to("cuda") + +>>> generate_ids = model.generate( +>>> **inputs, +>>> do_sample=False, +>>> tokenizer=processor.tokenizer, +>>> stop_strings="<|im_end|>", +>>> max_new_tokens=4096, +>>> ) + +>>> outputs = processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True) +>>> tk = verovio.toolkit() +>>> tk.loadData(outputs) +>>> tk.setOptions( +>>> { +>>> "pageWidth": 2100, +>>> "pageHeight": 800, +>>> "footer": "none", +>>> "barLineWidth": 0.5, +>>> "beamMaxSlope": 15, +>>> "staffLineWidth": 0.2, +>>> "spacingStaff": 6, +>>> } +>>> ) +>>> tk.getPageCount() +>>> svg = tk.renderToSVG() +>>> svg = svg.replace('overflow="inherit"', 'overflow="visible"') +>>> with open("output.svg", "w") as f: +>>> f.write(svg) +``` + + +## GotOcr2Config + +[[autodoc]] GotOcr2Config + +## GotOcr2VisionConfig + +[[autodoc]] GotOcr2VisionConfig + +## GotOcr2ImageProcessor + +[[autodoc]] GotOcr2ImageProcessor + +## GotOcr2Processor + +[[autodoc]] GotOcr2Processor + +## GotOcr2Model + +[[autodoc]] GotOcr2Model + - forward + +## GotOcr2ForConditionalGeneration + +[[autodoc]] GotOcr2ForConditionalGeneration + - forward + diff --git a/docs/source/en/model_doc/qwen2_vl.md b/docs/source/en/model_doc/qwen2_vl.md index 7c864b860bd8ea..f31c5ebc79765a 100644 --- a/docs/source/en/model_doc/qwen2_vl.md +++ b/docs/source/en/model_doc/qwen2_vl.md @@ -18,7 +18,7 @@ rendered properly in your Markdown viewer. ## Overview -The [Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/) model is a major update to [Qwen-VL](https://arxiv.org/pdf/2308.12966) from the Qwen team at Alibaba Research. +The [Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/) model is a major update to [Qwen-VL](https://arxiv.org/pdf/2308.12966) from the Qwen team at Alibaba Research. The abstract from the blog is the following: @@ -231,7 +231,7 @@ In case of limited GPU RAM, one can reduce the resolution as follows: ```python min_pixels = 256*28*28 -max_pixels = 1024*28*28 +max_pixels = 1024*28*28 processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels) ``` This ensures each image gets encoded using a number between 256-1024 tokens. The 28 comes from the fact that the model uses a patch size of 14 and a temporal patch size of 2 (14 x 2 = 28). @@ -245,7 +245,7 @@ conversation = [ { "role": "user", "content": [ - {"type": "image"}, + {"type": "image"}, {"type": "text", "text": "Hello, how are you?"} ] }, @@ -256,10 +256,10 @@ conversation = [ { "role": "user", "content": [ - {"type": "text", "text": "Can you describe these images and video?"}, - {"type": "image"}, - {"type": "image"}, - {"type": "video"}, + {"type": "text", "text": "Can you describe these images and video?"}, + {"type": "image"}, + {"type": "image"}, + {"type": "video"}, {"type": "text", "text": "These are from my vacation."} ] }, @@ -300,8 +300,8 @@ To load and run a model using Flash Attention-2, simply add `attn_implementation from transformers import Qwen2VLForConditionalGeneration model = Qwen2VLForConditionalGeneration.from_pretrained( - "Qwen/Qwen2-VL-7B-Instruct", - torch_dtype=torch.bfloat16, + "Qwen/Qwen2-VL-7B-Instruct", + torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", ) ``` diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 4d7852a66307e2..a38983e7f339f5 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -49,6 +49,7 @@ FlashAttention-2 is currently supported for the following architectures: * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) * [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) +* [GotOcr2](https://huggingface.co/docs/transformers/model_doc/got_ocr2#transformers.GotOcr2Model) * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel) @@ -239,6 +240,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) * [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) +* [GotOcr2](https://huggingface.co/docs/transformers/model_doc/got_ocr2#transformers.GotOcr2Model) * [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel) * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1eb34b48fda856..1d8cd3ece6faf3 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -463,6 +463,11 @@ ], "models.glm": ["GlmConfig"], "models.glpn": ["GLPNConfig"], + "models.got_ocr2": [ + "GotOcr2Config", + "GotOcr2Processor", + "GotOcr2VisionConfig", + ], "models.gpt2": [ "GPT2Config", "GPT2Tokenizer", @@ -1210,6 +1215,7 @@ _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"]) _import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"]) _import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"]) + _import_structure["models.got_ocr2"].extend(["GotOcr2ImageProcessor"]) _import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"]) _import_structure["models.idefics"].extend(["IdeficsImageProcessor"]) _import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"]) @@ -2354,6 +2360,13 @@ "GLPNPreTrainedModel", ] ) + _import_structure["models.got_ocr2"].extend( + [ + "GotOcr2ForConditionalGeneration", + "GotOcr2Model", + "GotOcr2PreTrainedModel", + ] + ) _import_structure["models.gpt2"].extend( [ "GPT2DoubleHeadsModel", @@ -5384,6 +5397,7 @@ ) from .models.glm import GlmConfig from .models.glpn import GLPNConfig + from .models.got_ocr2 import GotOcr2Config, GotOcr2Processor, GotOcr2VisionConfig from .models.gpt2 import ( GPT2Config, GPT2Tokenizer, @@ -6171,6 +6185,7 @@ ) from .models.fuyu import FuyuImageProcessor, FuyuProcessor from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor + from .models.got_ocr2 import GotOcr2ImageProcessor from .models.grounding_dino import GroundingDinoImageProcessor from .models.idefics import IdeficsImageProcessor from .models.idefics2 import Idefics2ImageProcessor @@ -7146,6 +7161,11 @@ GLPNModel, GLPNPreTrainedModel, ) + from .models.got_ocr2 import ( + GotOcr2ForConditionalGeneration, + GotOcr2Model, + GotOcr2PreTrainedModel, + ) from .models.gpt2 import ( GPT2DoubleHeadsModel, GPT2ForQuestionAnswering, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 2e3b48da96e966..1e5f91273829c9 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -101,6 +101,7 @@ git, glm, glpn, + got_ocr2, gpt2, gpt_bigcode, gpt_neo, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 1d9db837e8d27c..fc25dc040c8fa5 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -119,6 +119,7 @@ ("git", "GitConfig"), ("glm", "GlmConfig"), ("glpn", "GLPNConfig"), + ("got_ocr2", "GotOcr2Config"), ("gpt-sw3", "GPT2Config"), ("gpt2", "GPT2Config"), ("gpt_bigcode", "GPTBigCodeConfig"), @@ -429,6 +430,7 @@ ("git", "GIT"), ("glm", "GLM"), ("glpn", "GLPN"), + ("got_ocr2", "GOT-OCR2"), ("gpt-sw3", "GPT-Sw3"), ("gpt2", "OpenAI GPT-2"), ("gpt_bigcode", "GPTBigCode"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index db25591eaa3544..89d546946f5c3c 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -87,6 +87,7 @@ ("fuyu", ("FuyuImageProcessor",)), ("git", ("CLIPImageProcessor",)), ("glpn", ("GLPNImageProcessor",)), + ("got_ocr2", ("GotOcr2ImageProcessor",)), ("grounding-dino", ("GroundingDinoImageProcessor",)), ("groupvit", ("CLIPImageProcessor",)), ("hiera", ("BitImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index bec72a4e7b84ec..fea4a94f55df08 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -116,6 +116,7 @@ ("git", "GitModel"), ("glm", "GlmModel"), ("glpn", "GLPNModel"), + ("got_ocr2", "GotOcr2Model"), ("gpt-sw3", "GPT2Model"), ("gpt2", "GPT2Model"), ("gpt_bigcode", "GPTBigCodeModel"), @@ -497,6 +498,7 @@ ("gemma2", "Gemma2ForCausalLM"), ("git", "GitForCausalLM"), ("glm", "GlmForCausalLM"), + ("got_ocr2", "GotOcr2ForConditionalGeneration"), ("gpt-sw3", "GPT2LMHeadModel"), ("gpt2", "GPT2LMHeadModel"), ("gpt_bigcode", "GPTBigCodeForCausalLM"), @@ -783,6 +785,7 @@ ("chameleon", "ChameleonForConditionalGeneration"), ("fuyu", "FuyuForCausalLM"), ("git", "GitForCausalLM"), + ("got_ocr2", "GotOcr2ForConditionalGeneration"), ("idefics", "IdeficsForVisionText2Text"), ("idefics2", "Idefics2ForConditionalGeneration"), ("idefics3", "Idefics3ForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 3e475b1be211fa..5a19dabad06381 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -61,6 +61,7 @@ ("flava", "FlavaProcessor"), ("fuyu", "FuyuProcessor"), ("git", "GitProcessor"), + ("got_ocr2", "GotOcr2Processor"), ("grounding-dino", "GroundingDinoProcessor"), ("groupvit", "CLIPProcessor"), ("hubert", "Wav2Vec2Processor"), diff --git a/src/transformers/models/got_ocr2/__init__.py b/src/transformers/models/got_ocr2/__init__.py new file mode 100644 index 00000000000000..a219ba80efcba0 --- /dev/null +++ b/src/transformers/models/got_ocr2/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_got_ocr2 import * + from .image_processing_got_ocr2 import * + from .modeling_got_ocr2 import * + from .processing_got_ocr2 import * + + +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/got_ocr2/configuration_got_ocr2.py b/src/transformers/models/got_ocr2/configuration_got_ocr2.py new file mode 100644 index 00000000000000..5307a0e62ff530 --- /dev/null +++ b/src/transformers/models/got_ocr2/configuration_got_ocr2.py @@ -0,0 +1,296 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_got_ocr2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class GotOcr2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GotOcr2VisionModel`]. It is used to instantiate a GOT_OCR2 + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + defaults will yield a similar configuration to that of the SAM ViT-h + [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + output_channels (`int`, *optional*, defaults to 256): + Dimensionality of the output channels in the Patch Encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + image_size (`int`, *optional*, defaults to 1024): + Expected resolution. Target size of the resized input image. + patch_size (`int`, *optional*, defaults to 16): + Size of the patches to be extracted from the input image. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 1e-10): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to query, key, value projections. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of mlp hidden dim to embedding dim. + use_abs_pos (`bool`, *optional*, defaults to `True`): + Whether to use absolute position embedding. + use_rel_pos (`bool`, *optional*, defaults to `True`): + Whether to use relative position embedding. + window_size (`int`, *optional*, defaults to 14): + Window size for relative position. + global_attn_indexes (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`): + The indexes of the global attention layers. + mlp_dim (`int`, *optional*): + The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio * + hidden_size`. + """ + + def __init__( + self, + hidden_size=768, + output_channels=256, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-06, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + mlp_dim=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.output_channels = output_channels + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.mlp_dim = mlp_dim + + +class GotOcr2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GotOcr2Model`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the GotOcr2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GotOcr2Model`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 29568): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 80): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 80): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + vision_config (`Dict`, *optional*): + The config for the visual encoder initialization. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + + ```python + >>> from transformers import GotOcr2ForConditionalGeneration, GotOcr2Config + + >>> # Initializing a GotOcr2 style configuration + >>> configuration = GotOcr2Config() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = GotOcr2ForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "got_ocr2" + sub_configs = {"vision_config": GotOcr2VisionConfig} + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = GotOcr2VisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = GotOcr2VisionConfig() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations + # one can set it to "linear"/"dynamic" etc. to have scaled RoPE + # TODO: @raushan update config in the hub + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = ["GotOcr2VisionConfig", "GotOcr2Config"] diff --git a/src/transformers/models/got_ocr2/convert_got_ocr2_weights_to_hf.py b/src/transformers/models/got_ocr2/convert_got_ocr2_weights_to_hf.py new file mode 100644 index 00000000000000..4eca2ad6122b7c --- /dev/null +++ b/src/transformers/models/got_ocr2/convert_got_ocr2_weights_to_hf.py @@ -0,0 +1,292 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import gc +import glob +import os +from typing import List, Optional + +import regex as re +import torch +from huggingface_hub import snapshot_download +from safetensors import safe_open + +from transformers import ( + GotOcr2Config, + GotOcr2ForConditionalGeneration, + GotOcr2ImageProcessor, + GotOcr2Processor, + PreTrainedTokenizerFast, + is_vision_available, +) +from transformers.convert_slow_tokenizer import TikTokenConverter +from transformers.tokenization_utils import AddedToken + + +if is_vision_available(): + from transformers.image_utils import load_image + + +# fmt: off +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + # Vision encoder mapping + r"model.vision_tower_high.pos_embed": r"visual.pos_embed", + r"model.vision_tower_high.patch_embed.proj": r"visual.patch_embed.projection", + r"model.vision_tower_high.blocks.(\d+).norm": r"visual.layers.\1.layer_norm", + r"model.vision_tower_high.blocks.(\d+).attn": r"visual.layers.\1.attn", + r"model.vision_tower_high.blocks.(\d+).mlp": r"visual.layers.\1.mlp", + r"model.vision_tower_high.neck.0": r"visual.neck.conv1", + r"model.vision_tower_high.neck.1": r"visual.neck.layer_norm1", + r"model.vision_tower_high.neck.2": r"visual.neck.conv2", + r"model.vision_tower_high.neck.3": r"visual.neck.layer_norm2", + r"model.vision_tower_high.net_(\d+)": lambda m: f"visual_adapter.conv_upsampler{int(m.group(1)) - 1}", + r"model.mm_projector_vary" : r"visual_adapter.multimodal_projector", +} +# fmt: on + +CONTEXT_LENGTH = 8000 + + +def convert_old_keys_to_new_keys(state_dict_keys: dict = None): + """ + This function should be applied only once, on the concatenated keys to efficiently rename using + the key mappings. + """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + new_text = re.sub(pattern, replacement, new_text) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict + + +def load_original_state_dict(model_id): + directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"]) + + original_state_dict = {} + for path in glob.glob(f"{directory_path}/*"): + if path.endswith(".safetensors"): + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + original_state_dict[key] = f.get_tensor(key) + + return original_state_dict + + +def get_got_ocr2_config(): + config = GotOcr2Config( + vocab_size=151860, + hidden_size=1024, + intermediate_size=2816, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=16, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + tie_word_embeddings=True, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=32768, + max_window_layers=21, + attention_dropout=0.0, + rope_scaling=None, + image_token_id=151859, + ) + + return config + + +def write_model( + model_path, + input_base_path, + push_to_hub=False, +): + os.makedirs(model_path, exist_ok=True) + + config = get_got_ocr2_config() + config.architectures = ["GotOcr2ForConditionalGeneration"] + config.save_pretrained(model_path) + print("Model config saved successfully...") + + # ------------------------------------------------------------ + # Convert weights + # ------------------------------------------------------------ + + print(f"Fetching all parameters from the checkpoint at {input_base_path}...") + state_dict_old = load_original_state_dict(input_base_path) + print("Converting model...") + all_keys = list(state_dict_old.keys()) + new_keys = convert_old_keys_to_new_keys(all_keys) + state_dict = {} + for key in all_keys: + new_key = new_keys[key] + state_dict[new_key] = state_dict_old[key] + + del state_dict_old + gc.collect() + + print("Loading the checkpoint in a GotOcr2ForConditionalGeneration model.") + model = GotOcr2ForConditionalGeneration(config) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + model = model.to(torch.bfloat16) + print("model dtype:", model.dtype) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + + print("Saving the model.") + model.save_pretrained(model_path) + if push_to_hub: + model.push_to_hub("yonigozlan/GOT-OCR-2.0-hf", use_temp_dir=True) + del state_dict, model + + # Safety check: reload the converted model + gc.collect() + print("Reloading the model to check if it's saved correctly.") + model = GotOcr2ForConditionalGeneration.from_pretrained(model_path, device_map="auto") + processor = GotOcr2Processor.from_pretrained(model_path) + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg" + ) + + inputs = processor(image, return_tensors="pt", format=True).to(model.device, dtype=model.dtype) + generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4) + decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True) + expected_output = "\\title{\nR" + print("Decoded output:", decoded_output) + assert decoded_output == expected_output + print("Model reloaded successfully.") + del model + + +class GotOcr2Converter(TikTokenConverter): + def __init__( + self, + vocab_file, + special_tokens: List[str], + pattern: str, + model_max_length: int, + chat_template: Optional[str] = None, + **kwargs, + ): + super().__init__(vocab_file, pattern=pattern) + self.additional_special_tokens = special_tokens + tokenizer = self.converted() + if chat_template is not None: + kwargs["chat_template"] = chat_template + self.tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + model_input_names=["input_ids", "attention_mask"], + model_max_length=model_max_length, + **kwargs, + ) + + +def write_tokenizer(tokenizer_path: str, save_dir: str, push_to_hub: bool = False): + model_max_length = CONTEXT_LENGTH + pattern = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: W605 + # Special tokens + special_tokens = ( + ["<|endoftext|>", "<|im_start|>", "<|im_end|>"] + + [f"<|extra_{i}|>" for i in range(205)] + + [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + ] + ) + + pad_token = "<|endoftext|>" + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False, normalized=False, single_word=False) + + converter = GotOcr2Converter( + vocab_file=tokenizer_path, + pattern=pattern, + special_tokens=special_tokens, + model_max_length=model_max_length, + pad_token=pad_token, + bos_token="<|endoftext|>", + eos_token="<|endoftext|>", + clean_up_tokenization_spaces=True, + ) + tokenizer = converter.tokenizer + tokenizer.save_pretrained(save_dir) + + if push_to_hub: + tokenizer.push_to_hub("yonigozlan/GOT-OCR-2.0-hf", use_temp_dir=True) + + +def write_image_processor(save_dir: str, push_to_hub: bool = False): + image_processor = GotOcr2ImageProcessor( + do_resize=True, + size={"height": 1024, "width": 1024}, + do_rescale=True, + rescale_factor=1 / 255, + do_normalize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + ) + + image_processor.save_pretrained(save_dir) + if push_to_hub: + image_processor.push_to_hub("yonigozlan/GOT-OCR-2.0-hf", use_temp_dir=True) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + default="stepfun-ai/GOT-OCR2_0", + help="Location of LLaMA weights, which contains tokenizer.model and model folders", + ) + parser.add_argument( + "--output_dir", + default="GotOcr2", + help="Location to write HF model and tokenizer", + ) + + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + args = parser.parse_args() + write_tokenizer( + tokenizer_path="qwen.tiktoken", + save_dir=args.output_dir, + push_to_hub=args.push_to_hub, + ) + + write_image_processor( + save_dir=args.output_dir, + push_to_hub=args.push_to_hub, + ) + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + push_to_hub=args.push_to_hub, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/got_ocr2/image_processing_got_ocr2.py b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py new file mode 100644 index 00000000000000..2b3c2b63639879 --- /dev/null +++ b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py @@ -0,0 +1,435 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_got_ocr2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + _rescale_for_pil_conversion, + convert_to_rgb, + resize, + to_channel_dimension_format, + to_pil_image, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + is_vision_available, + logging, +) + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +def find_best_patches_grid( + original_image_size: dict, + target_patch_size: dict, + min_patches: int, + max_patches: int, +) -> Tuple[int, int]: + """ + Given a minimum and maximum number of patches, find the patches grid with the closest aspect ratio to the + original image aspect ratio. + In case of tie-breaking condition when two grids have the same aspect ratio difference, we favor the grids with + more patches, until the area covered by the patches is more than twice the target area, in order to avoid unnecessarily + excessive patching. + """ + # compute possible patches grids + target_patches_grids = { + (i, j) + for n in range(min_patches, max_patches + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_patches and i * j >= min_patches + } + target_patches_grids = sorted(target_patches_grids, key=lambda x: x[0] * x[1]) + + # find the grid with the best aspect ratio + best_ratio_diff = float("inf") + best_grid = (1, 1) + original_width, original_height = original_image_size["width"], original_image_size["height"] + aspect_ratio = original_width / original_height + area = original_width * original_height + for grid in target_patches_grids: + grid_aspect_ratio = grid[0] / grid[1] + ratio_diff = abs(aspect_ratio - grid_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_grid = grid + elif ratio_diff == best_ratio_diff: + # if the aspect ratio difference is the same, we favor the grid with more patches + # until the area covered by the patches is more than twice the original image area + if area > 0.5 * target_patch_size["width"] * target_patch_size["height"] * grid[0] * grid[1]: + best_grid = grid + + return best_grid + + +class GotOcr2ImageProcessor(BaseImageProcessor): + r""" + Constructs a GOT_OCR2 image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 384, "width": 384} + size = get_size_dict(size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + do_convert_rgb: bool = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Controls the size of the image after `resize`. The shortest edge of the image is resized to + `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image + is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest + edge equal to `int(size["shortest_edge"] * (1333 / 800))`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to normalize the image by if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to normalize the image by if `do_normalize` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + + return encoded_outputs + + def crop_image_to_patches( + self, + image: ImageInput, + min_patches: int, + max_patches: int, + use_thumbnail: bool = True, + patch_size: Union[Tuple, int, dict] = None, + return_numpy: bool = False, + data_format: ChannelDimension = None, + ): + """ + Crop the image to patches and return a list of cropped images. + The number of patches and their grid arrangement are determined by the original image size, + the target patch size and the minimum and maximum number of patches. + The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio. + """ + patch_size = patch_size if patch_size is not None else self.size + patch_size = get_size_dict(patch_size, default_to_square=True) + original_size = get_size_dict(image.size, height_width_order=False) + do_rescale = False + if not isinstance(image, PIL.Image.Image): + do_rescale = _rescale_for_pil_conversion(image) + image = to_pil_image(image, do_rescale=do_rescale) + + # find the closest aspect ratio to the target + target_patches_grid = find_best_patches_grid(original_size, patch_size, min_patches, max_patches) + + # calculate the target width and height + patch_size_width, patch_size_height = patch_size["width"], patch_size["height"] + target_width = patch_size_width * target_patches_grid[0] + target_height = patch_size_height * target_patches_grid[1] + num_blocks = target_patches_grid[0] * target_patches_grid[1] + + # resize the image so that each patch is of patch_size + resized_image = image.resize((target_width, target_height)) + + # split the image into patches + processed_images = [] + num_columns = target_patches_grid[0] + for i in range(num_blocks): + column = i % num_columns + row = i // num_columns + box = ( + column * patch_size_width, + row * patch_size_height, + (column + 1) * patch_size_width, + (row + 1) * patch_size_height, + ) + # split the image + patch_image = resized_image.crop(box) + processed_images.append(patch_image) + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((patch_size_width, patch_size_height)) + processed_images.append(thumbnail_img) + + if return_numpy: + processed_images_numpy = [] + for processed_image in processed_images: + processed_image = np.array(processed_image) + # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image + # so we need to add it back if necessary. + processed_image = ( + np.expand_dims(processed_image, axis=-1) if processed_image.ndim == 2 else processed_image + ) + # The image is always in channels last format after converting from a PIL image + if data_format is not None: + processed_image = to_channel_dimension_format( + processed_image, data_format, input_channel_dim=ChannelDimension.LAST + ) + # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to + # rescale it back to the original range. + processed_image = self.rescale(processed_image, 1 / 255) if do_rescale else processed_image + processed_images_numpy.append(processed_image) + processed_images = processed_images_numpy + + return processed_images + + +__all__ = ["GotOcr2ImageProcessor"] diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py new file mode 100644 index 00000000000000..8ab3805fd39ba2 --- /dev/null +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -0,0 +1,1757 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_got_ocr2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "GotOcr2Config" + + +class GotOcr2MLPBlock(nn.Module): + def __init__(self, config): + super().__init__() + mlp_dim = config.mlp_dim if config.mlp_dim is not None else int(config.hidden_size * config.mlp_ratio) + self.lin1 = nn.Linear(config.hidden_size, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +class GotOcr2VisionAdapter(nn.Module): + def __init__(self, language_hidden_size: int, vision_output_channels: int): + super().__init__() + self.conv_upsampler1 = nn.Conv2d( + vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.conv_upsampler2 = nn.Conv2d( + vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False + ) + self.multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size) + + def forward(self, vision_embeddings): + x = self.conv_upsampler1(vision_embeddings) + x = self.conv_upsampler2(x) + x = x.flatten(2).permute(0, 2, 1) + x = self.multimodal_projector(x) + return x + + +@dataclass +class GotOcr2VisionEncoderOutput(ModelOutput): + """ + Base class for got_ocr2 vision model's outputs that also contains image embeddings obtained by applying the projection + layer to the pooler_output. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class GotOcr2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +class GotOcr2LayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {self.data_format}") + self.normalized_shape = (normalized_shape,) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.data_format == "channels_last": + x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + input_dtype = x.dtype + x = x.float() + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = x.to(dtype=input_dtype) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class GotOcr2VisionAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size): + super().__init__() + input_size = ( + (config.image_size // config.patch_size, config.image_size // config.patch_size) + if window_size == 0 + else (window_size, window_size) + ) + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def add_decomposed_rel_pos( + self, + attn: torch.Tensor, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + attn (`torch.Tensor`): + attention map. + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + attn (`torch.Tensor`): + attention map with added relative positional embeddings. + """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width) + attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width) + return attn + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + attn_weights = self.add_decomposed_rel_pos( + attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + +class GotOcr2VisionLayer(nn.Module): + def __init__(self, config, window_size): + super().__init__() + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = GotOcr2VisionAttention(config, window_size) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = GotOcr2MLPBlock(config) + self.window_size = window_size + + def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Args: + Partition into non-overlapping windows with padding if needed. + hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window + size. + + Returns: + windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel]. + (pad_height, pad_width): padded height and width before partition + """ + batch_size, height, width, channel = hidden_states.shape + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h)) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = hidden_states.reshape( + batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel + ) + windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel) + return windows, (pad_height, pad_width) + + def window_unpartition( + self, windows: torch.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int] + ) -> torch.Tensor: + """ + Args: + Window unpartition into original sequences and removing padding. + hidden_states (tensor): + input tokens with [batch_size * num_windows, window_size, window_size, channel]. + window_size (int): + window size. + padding_shape (Tuple): + padded height and width (pad_height, pad_width). + original_shape (Tuple): original height and width (height, width) before padding. + + Returns: + hidden_states: unpartitioned sequences with [batch_size, height, width, channel]. + """ + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = windows.reshape( + batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1 + ) + hidden_states = ( + hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1) + ) + + hidden_states = hidden_states[:, :height, :width, :].contiguous() + return hidden_states + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + # Window partition + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + # Reverse window partition + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class GotOcr2VisionNeck(nn.Module): + def __init__(self, config: GotOcr2VisionConfig): + super().__init__() + self.config = config + + self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False) + self.layer_norm1 = GotOcr2LayerNorm(config.output_channels, data_format="channels_first") + self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False) + self.layer_norm2 = GotOcr2LayerNorm(config.output_channels, data_format="channels_first") + + def forward(self, hidden_states): + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + return hidden_states + + +class GotOcr2VisionEncoder(nn.Module): + def __init__(self, config: GotOcr2VisionConfig): + super().__init__() + self.config = config + self.image_size = config.image_size + + self.patch_embed = GotOcr2PatchEmbeddings(config) + + self.pos_embed = None + if config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros( + 1, + config.image_size // config.patch_size, + config.image_size // config.patch_size, + config.hidden_size, + ) + ) + + self.layers = nn.ModuleList() + for i in range(config.num_hidden_layers): + layer = GotOcr2VisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + ) + self.layers.append(layer) + + self.neck = GotOcr2VisionNeck(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.patch_embed + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, GotOcr2VisionEncoderOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + ) + else: + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.neck(hidden_states) + + if not return_dict: + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_self_attentions,) + return outputs + + return GotOcr2VisionEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +GOT_OCR2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GotOcr2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare GotOcr2 Model outputting raw hidden-states without any specific head on top.", + GOT_OCR2_START_DOCSTRING, +) +class GotOcr2PreTrainedModel(PreTrainedModel): + config_class = GotOcr2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GotOcr2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class GotOcr2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + GotOcr2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GotOcr2RotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[GotOcr2Config] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`GotOcr2RotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class GotOcr2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class GotOcr2Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: GotOcr2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = GotOcr2RotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class GotOcr2FlashAttention2(GotOcr2Attention): + """ + GotOcr2 flash attention module, following GotOcr2 attention module. This module inherits from `GotOcr2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class GotOcr2SdpaAttention(GotOcr2Attention): + """ + GotOcr2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `GotOcr2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from GotOcr2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "GotOcr2Model is using GotOcr2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +GOT_OCR2_ATTENTION_CLASSES = { + "eager": GotOcr2Attention, + "flash_attention_2": GotOcr2FlashAttention2, + "sdpa": GotOcr2SdpaAttention, +} + + +class GotOcr2DecoderLayer(nn.Module): + def __init__(self, config: GotOcr2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = GOT_OCR2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = GotOcr2MLP(config) + self.input_layernorm = GotOcr2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GotOcr2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +GOT_OCR2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare GotOcr2 Model outputting raw hidden-states without any specific head on top.", + GOT_OCR2_START_DOCSTRING, +) +class GotOcr2Model(GotOcr2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GotOcr2DecoderLayer`] + + Args: + config: GotOcr2Config + """ + + def __init__(self, config: GotOcr2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [GotOcr2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = GotOcr2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GotOcr2RotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: GotOcr2Config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`GotOcr2Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.visual = GotOcr2VisionEncoder(config.vision_config) + self.model = GotOcr2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.visual_adapter = GotOcr2VisionAdapter(config.hidden_size, config.vision_config.output_channels) + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + num_new_tokens=num_new_tokens, + ) + + return model_kwargs + + @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, GotOcr2ForConditionalGeneration, TextStreamer + + >>> model = GotOcr2ForConditionalGeneration.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda") + >>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") + + >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(image, return_tensors="pt", color="green").to("cuda") + + >>> # Generate + >>> streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True) + >>> generate_ids = model.generate(**inputs, do_sample=False, + tokenizer = processor.tokenizer, + stop_strings='<|im_end|>', + streamer=streamer, + max_new_tokens=4096,) + + >>> outputs = processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1]:]) + "You should keep in mind what features from the module should be used, especially + when you're planning to sell a template." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.to(inputs_embeds.dtype) + image_embeds = self.visual(pixel_values) + image_embeds = self.visual_adapter(image_embeds.last_hidden_state) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] * image_embeds.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if cache_position[0] != 0: + pixel_values = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "cache_position": cache_position, + } + ) + return model_inputs + + +__all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"] diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py new file mode 100644 index 00000000000000..342c5d01d51fba --- /dev/null +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -0,0 +1,923 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + +from transformers.models.blip.image_processing_blip import BlipImageProcessor +from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, Qwen2PreTrainedModel +from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig +from transformers.models.sam.modeling_sam import SamVisionEncoder +from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from transformers.tokenization_utils_base import ( + PreTokenizedInput, + TextInput, +) + +from ...activations import ACT2FN +from ...cache_utils import StaticCache +from ...configuration_utils import PretrainedConfig +from ...generation import GenerationMixin +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_transforms import ( + _rescale_for_pil_conversion, + to_channel_dimension_format, + to_pil_image, +) +from ...image_utils import ChannelDimension, ImageInput +from ...modeling_outputs import CausalLMOutputWithPast +from ...utils import ( + ModelOutput, + add_start_docstrings_to_model_forward, + is_vision_available, + logging, + replace_return_docstrings, +) + + +if is_vision_available(): + import PIL + + from ...image_utils import load_images + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "GotOcr2Config" + + +class GotOcr2VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GotOcr2VisionModel`]. It is used to instantiate a GOT_OCR2 + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + defaults will yield a similar configuration to that of the SAM ViT-h + [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + output_channels (`int`, *optional*, defaults to 256): + Dimensionality of the output channels in the Patch Encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + image_size (`int`, *optional*, defaults to 1024): + Expected resolution. Target size of the resized input image. + patch_size (`int`, *optional*, defaults to 16): + Size of the patches to be extracted from the input image. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 1e-10): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to query, key, value projections. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of mlp hidden dim to embedding dim. + use_abs_pos (`bool`, *optional*, defaults to `True`): + Whether to use absolute position embedding. + use_rel_pos (`bool`, *optional*, defaults to `True`): + Whether to use relative position embedding. + window_size (`int`, *optional*, defaults to 14): + Window size for relative position. + global_attn_indexes (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`): + The indexes of the global attention layers. + mlp_dim (`int`, *optional*): + The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio * + hidden_size`. + """ + + def __init__( + self, + hidden_size=768, + output_channels=256, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-06, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + mlp_dim=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.output_channels = output_channels + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.mlp_dim = mlp_dim + + +class GotOcr2Config(Qwen2VLConfig): + pass + + +class GotOcr2TextKwargs(TextKwargs, total=False): + format: Optional[bool] + + +class GotOcr2ImagesKwargs(ImagesKwargs, total=False): + box: Optional[Union[List, Tuple[float, float], Tuple[float, float, float, float]]] + color: Optional[str] + num_image_tokens: Optional[int] + multi_page: Optional[bool] + crop_to_patches: Optional[bool] + min_patches: Optional[int] + max_patches: Optional[int] + + +class GotOcr2ProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: GotOcr2TextKwargs + images_kwargs: GotOcr2ImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "format": False, + }, + "images_kwargs": { + "num_image_tokens": 256, + "multi_page": False, + "crop_to_patches": False, + "min_patches": 1, + "max_patches": 6, + }, + } + + +def load_box_annotation(box: Union[List, Tuple], image_size: Tuple[int, int]) -> List: + """ + Load the box annotation and convert it to the format [x1, y1, x2, y2] in the range [0, 1000].""" + width, height = image_size + if len(box) == 4: + box[0] = int(box[0] / width * 1000) + box[1] = int(box[1] / height * 1000) + box[2] = int(box[2] / width * 1000) + box[3] = int(box[3] / height * 1000) + else: + raise ValueError("Box must be a list or tuple of lists in the form [x1, y1, x2, y2].") + + return list(box) + + +def find_best_patches_grid( + original_image_size: dict, + target_patch_size: dict, + min_patches: int, + max_patches: int, +) -> Tuple[int, int]: + """ + Given a minimum and maximum number of patches, find the patches grid with the closest aspect ratio to the + original image aspect ratio. + In case of tie-breaking condition when two grids have the same aspect ratio difference, we favor the grids with + more patches, until the area covered by the patches is more than twice the target area, in order to avoid unnecessarily + excessive patching. + """ + # compute possible patches grids + target_patches_grids = { + (i, j) + for n in range(min_patches, max_patches + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_patches and i * j >= min_patches + } + target_patches_grids = sorted(target_patches_grids, key=lambda x: x[0] * x[1]) + + # find the grid with the best aspect ratio + best_ratio_diff = float("inf") + best_grid = (1, 1) + original_width, original_height = original_image_size["width"], original_image_size["height"] + aspect_ratio = original_width / original_height + area = original_width * original_height + for grid in target_patches_grids: + grid_aspect_ratio = grid[0] / grid[1] + ratio_diff = abs(aspect_ratio - grid_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_grid = grid + elif ratio_diff == best_ratio_diff: + # if the aspect ratio difference is the same, we favor the grid with more patches + # until the area covered by the patches is more than twice the original image area + if area > 0.5 * target_patch_size["width"] * target_patch_size["height"] * grid[0] * grid[1]: + best_grid = grid + + return best_grid + + +class GotOcr2ImageProcessor(BlipImageProcessor): + def crop_image_to_patches( + self, + image: ImageInput, + min_patches: int, + max_patches: int, + use_thumbnail: bool = True, + patch_size: Union[Tuple, int, dict] = None, + return_numpy: bool = False, + data_format: ChannelDimension = None, + ): + """ + Crop the image to patches and return a list of cropped images. + The number of patches and their grid arrangement are determined by the original image size, + the target patch size and the minimum and maximum number of patches. + The aspect ratio of the patches grid is chosen to be the closest to the original image aspect ratio. + """ + patch_size = patch_size if patch_size is not None else self.size + patch_size = get_size_dict(patch_size, default_to_square=True) + original_size = get_size_dict(image.size, height_width_order=False) + do_rescale = False + if not isinstance(image, PIL.Image.Image): + do_rescale = _rescale_for_pil_conversion(image) + image = to_pil_image(image, do_rescale=do_rescale) + + # find the closest aspect ratio to the target + target_patches_grid = find_best_patches_grid(original_size, patch_size, min_patches, max_patches) + + # calculate the target width and height + patch_size_width, patch_size_height = patch_size["width"], patch_size["height"] + target_width = patch_size_width * target_patches_grid[0] + target_height = patch_size_height * target_patches_grid[1] + num_blocks = target_patches_grid[0] * target_patches_grid[1] + + # resize the image so that each patch is of patch_size + resized_image = image.resize((target_width, target_height)) + + # split the image into patches + processed_images = [] + num_columns = target_patches_grid[0] + for i in range(num_blocks): + column = i % num_columns + row = i // num_columns + box = ( + column * patch_size_width, + row * patch_size_height, + (column + 1) * patch_size_width, + (row + 1) * patch_size_height, + ) + # split the image + patch_image = resized_image.crop(box) + processed_images.append(patch_image) + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((patch_size_width, patch_size_height)) + processed_images.append(thumbnail_img) + + if return_numpy: + processed_images_numpy = [] + for processed_image in processed_images: + processed_image = np.array(processed_image) + # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image + # so we need to add it back if necessary. + processed_image = ( + np.expand_dims(processed_image, axis=-1) if processed_image.ndim == 2 else processed_image + ) + # The image is always in channels last format after converting from a PIL image + if data_format is not None: + processed_image = to_channel_dimension_format( + processed_image, data_format, input_channel_dim=ChannelDimension.LAST + ) + # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to + # rescale it back to the original range. + processed_image = self.rescale(processed_image, 1 / 255) if do_rescale else processed_image + processed_images_numpy.append(processed_image) + processed_images = processed_images_numpy + + return processed_images + + +class GotOcr2Processor(ProcessorMixin): + r""" + Constructs a GotOcr2 processor which wraps a [`GotOcr2ImageProcessor`] and + [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and + tokenizer functionalities. See the [`~GotOcr2Processor.__call__`] and [`~GotOcr2Processor.decode`] for more information. + Args: + image_processor ([`GotOcr2ImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + image_processor_class = "GotOcr2ImageProcessor" + tokenizer_class = "PreTrainedTokenizerFast" + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + self.img_start_token = "" + self.img_end_token = "" + self.img_pad_token = "" + self.system_query = "system\nYou should follow the instructions carefully and explain your answers in detail." + + def _check_call_arguments(self, images, box, color, multi_page, crop_to_patches): + if images is None: + raise ValueError("Images are required to be passed to the processor.") + + if not isinstance(box, (list, tuple)): + raise ValueError("Box must be a list or tuple of lists in the form [x1, y1, x2, y2].") + + if multi_page or crop_to_patches: + if multi_page and crop_to_patches: + raise ValueError("Cannot set both `multi_page` and `crop_to_patches` to `True`.") + if box[0] is not None or color is not None: + raise ValueError("Cannot pass `box` or `color` with multi-page inference.") + + if box[0] is not None and color is not None: + raise ValueError("Both `box` and `color` cannot be set at the same time.") + + def _make_list_of_inputs(self, images, text, box, color, multi_page): + if not isinstance(images, (list, tuple)): + if multi_page: + logger.warning("Multi-page inference is enabled but only one image is passed.") + images = [images] + elif isinstance(images[0], (list, tuple)) and not multi_page: + raise ValueError("Nested images are only supported with `multi_page` set to `True`.") + elif not isinstance(images[0], (list, tuple)) and multi_page: + images = [images] + + if text is not None: + if not isinstance(text, (list, tuple)): + text = [text] + if len(text) != len(images): + raise ValueError("The number of `text` must match the number of images.") + + if not isinstance(box[0], (list, tuple)): + # Use the same box for all images + box = [box for _ in range(len(images))] + if not isinstance(color, (list, tuple)): + color = [color for _ in range(len(images))] + if len(box) != len(images): + raise ValueError("The number of `box` must match the number of images.") + if len(color) != len(images): + raise ValueError("The number of `color` must match the number of images.") + + return images, text, box, color + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[GotOcr2ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text if `text` + is not `None`, otherwise encode default OCR queries which depends on the `format`, `box`, `color`, `multi_page` and + `crop_to_patches` arguments. To prepare the vision inputs, this method forwards the `images` and `kwrags` arguments to + GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + format (`bool`, *optional*): + If set, will add the format token to the query, and the model will return the OCR result with formatting. + box (`List[float]`, `List[Tuple[float, float]]`, `List[Tuple[float, float, float, float]]`, *optional*): + The box annotation to be added to the query. If a list of floats or a tuple of floats is provided, it + will be interpreted as [x1, y1, x2, y2]. If a list of tuples is provided, each tuple should be in the + form (x1, y1, x2, y2). + color (`str`, *optional*): + The color annotation to be added to the query. The model will return the OCR result within the box with + the specified color. + multi_page (`bool`, *optional*): + If set, will enable multi-page inference. The model will return the OCR result across multiple pages. + crop_to_patches (`bool`, *optional*): + If set, will crop the image to patches. The model will return the OCR result upon the patch reference. + min_patches (`int`, *optional*): + The minimum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to + `True`. + max_patches (`int`, *optional*): + The maximum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to + `True`. + + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + output_kwargs = self._merge_kwargs( + GotOcr2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + format = output_kwargs["text_kwargs"].pop("format") + num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens") + box = output_kwargs["images_kwargs"].pop("box", [None]) + color = output_kwargs["images_kwargs"].pop("color", None) + multi_page = output_kwargs["images_kwargs"].pop("multi_page") + crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches") + min_patches = output_kwargs["images_kwargs"].pop("min_patches") + max_patches = output_kwargs["images_kwargs"].pop("max_patches") + + self._check_call_arguments(images, box, color, multi_page, crop_to_patches) + images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page) + + # Load images as we need to know the image size + images = load_images(images) + if text is None: + text = [] + for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)): + if crop_to_patches: + image_group = self.image_processor.crop_image_to_patches( + image_group, + patch_size=output_kwargs["images_kwargs"].get("size"), + min_patches=min_patches, + max_patches=max_patches, + ) + images[index] = image_group + num_images = len(image_group) if (multi_page or crop_to_patches) else 1 + if box_single[0] is not None: + box_single = load_box_annotation(box_single, image_group.size) + query = ( + f"{f'[{color_single}] ' if color_single is not None else ''}" + f"{str(box_single) if box_single[0] is not None else ''} " + "OCR" + f"{' with format' if format else ''}" + f"{' across multi pages' if multi_page else ''}" + f"{' upon the patch reference' if crop_to_patches else ''}" + ": " + ) + prompt = ( + "<|im_start|>" + + self.system_query + + "<|im_end|>" + + "<|im_start|>user\n" + + self.img_start_token + + self.img_pad_token * num_image_tokens * num_images + + self.img_end_token + + "\n" + + query + + "<|im_end|><|im_start|>assistant\n" + ) + text.append(prompt) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + if multi_page or crop_to_patches: + # flatten images + images = [image for image_group in images for image in image_group] + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +class GotOcr2MLPBlock(nn.Module): + def __init__(self, config): + super().__init__() + mlp_dim = config.mlp_dim if config.mlp_dim is not None else int(config.hidden_size * config.mlp_ratio) + self.lin1 = nn.Linear(config.hidden_size, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +class GotOcr2VisionAdapter(nn.Module): + def __init__(self, language_hidden_size: int, vision_output_channels: int): + super().__init__() + self.conv_upsampler1 = nn.Conv2d( + vision_output_channels, vision_output_channels * 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.conv_upsampler2 = nn.Conv2d( + vision_output_channels * 2, language_hidden_size, kernel_size=3, stride=2, padding=1, bias=False + ) + self.multimodal_projector = nn.Linear(language_hidden_size, language_hidden_size) + + def forward(self, vision_embeddings): + x = self.conv_upsampler1(vision_embeddings) + x = self.conv_upsampler2(x) + x = x.flatten(2).permute(0, 2, 1) + x = self.multimodal_projector(x) + return x + + +class GotOcr2VisionEncoder(SamVisionEncoder): + pass + + +class GotOcr2PreTrainedModel(Qwen2PreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class GotOcr2Model(Qwen2Model): + pass + + +GOT_OCR2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`GotOcr2ImageProcessor.__call__`] for details. [`GotOcr2Processor`] uses + [`GotOcr2ImageProcessor`] for processing images. +""" + + +class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.visual = GotOcr2VisionEncoder(config.vision_config) + self.model = GotOcr2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.visual_adapter = GotOcr2VisionAdapter(config.hidden_size, config.vision_config.output_channels) + + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + num_new_tokens=num_new_tokens, + ) + + return model_kwargs + + @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, GotOcr2ForConditionalGeneration, TextStreamer + + >>> model = GotOcr2ForConditionalGeneration.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda") + >>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") + + >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(image, return_tensors="pt", color="green").to("cuda") + + >>> # Generate + >>> streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True) + >>> generate_ids = model.generate(**inputs, do_sample=False, + tokenizer = processor.tokenizer, + stop_strings='<|im_end|>', + streamer=streamer, + max_new_tokens=4096,) + + >>> outputs = processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1]:]) + "You should keep in mind what features from the module should be used, especially + when you're planning to sell a template." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.to(inputs_embeds.dtype) + image_embeds = self.visual(pixel_values) + image_embeds = self.visual_adapter(image_embeds.last_hidden_state) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] * image_embeds.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if cache_position[0] != 0: + pixel_values = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "cache_position": cache_position, + } + ) + return model_inputs + + +__all__ = [ + "GotOcr2VisionConfig", + "GotOcr2Config", + "GotOcr2Processor", + "GotOcr2PreTrainedModel", + "GotOcr2Model", + "GotOcr2ForConditionalGeneration", + "GotOcr2ImageProcessor", +] diff --git a/src/transformers/models/got_ocr2/processing_got_ocr2.py b/src/transformers/models/got_ocr2/processing_got_ocr2.py new file mode 100644 index 00000000000000..57eae3acd62013 --- /dev/null +++ b/src/transformers/models/got_ocr2/processing_got_ocr2.py @@ -0,0 +1,302 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/got_ocr2/modular_got_ocr2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_got_ocr2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Tuple, Union + +from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...utils import is_vision_available, logging + + +if is_vision_available(): + from ...image_utils import load_images + +logger = logging.get_logger(__name__) + + +class GotOcr2TextKwargs(TextKwargs, total=False): + format: Optional[bool] + + +class GotOcr2ImagesKwargs(ImagesKwargs, total=False): + box: Optional[Union[List, Tuple[float, float], Tuple[float, float, float, float]]] + color: Optional[str] + num_image_tokens: Optional[int] + multi_page: Optional[bool] + crop_to_patches: Optional[bool] + min_patches: Optional[int] + max_patches: Optional[int] + + +class GotOcr2ProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: GotOcr2TextKwargs + images_kwargs: GotOcr2ImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "format": False, + }, + "images_kwargs": { + "num_image_tokens": 256, + "multi_page": False, + "crop_to_patches": False, + "min_patches": 1, + "max_patches": 6, + }, + } + + +def load_box_annotation(box: Union[List, Tuple], image_size: Tuple[int, int]) -> List: + """ + Load the box annotation and convert it to the format [x1, y1, x2, y2] in the range [0, 1000].""" + width, height = image_size + if len(box) == 4: + box[0] = int(box[0] / width * 1000) + box[1] = int(box[1] / height * 1000) + box[2] = int(box[2] / width * 1000) + box[3] = int(box[3] / height * 1000) + else: + raise ValueError("Box must be a list or tuple of lists in the form [x1, y1, x2, y2].") + + return list(box) + + +class GotOcr2Processor(ProcessorMixin): + r""" + Constructs a GotOcr2 processor which wraps a [`GotOcr2ImageProcessor`] and + [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and + tokenizer functionalities. See the [`~GotOcr2Processor.__call__`] and [`~GotOcr2Processor.decode`] for more information. + Args: + image_processor ([`GotOcr2ImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + image_processor_class = "GotOcr2ImageProcessor" + tokenizer_class = "PreTrainedTokenizerFast" + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + self.img_start_token = "" + self.img_end_token = "" + self.img_pad_token = "" + self.system_query = "system\nYou should follow the instructions carefully and explain your answers in detail." + + def _check_call_arguments(self, images, box, color, multi_page, crop_to_patches): + if images is None: + raise ValueError("Images are required to be passed to the processor.") + + if not isinstance(box, (list, tuple)): + raise ValueError("Box must be a list or tuple of lists in the form [x1, y1, x2, y2].") + + if multi_page or crop_to_patches: + if multi_page and crop_to_patches: + raise ValueError("Cannot set both `multi_page` and `crop_to_patches` to `True`.") + if box[0] is not None or color is not None: + raise ValueError("Cannot pass `box` or `color` with multi-page inference.") + + if box[0] is not None and color is not None: + raise ValueError("Both `box` and `color` cannot be set at the same time.") + + def _make_list_of_inputs(self, images, text, box, color, multi_page): + if not isinstance(images, (list, tuple)): + if multi_page: + logger.warning("Multi-page inference is enabled but only one image is passed.") + images = [images] + elif isinstance(images[0], (list, tuple)) and not multi_page: + raise ValueError("Nested images are only supported with `multi_page` set to `True`.") + elif not isinstance(images[0], (list, tuple)) and multi_page: + images = [images] + + if text is not None: + if not isinstance(text, (list, tuple)): + text = [text] + if len(text) != len(images): + raise ValueError("The number of `text` must match the number of images.") + + if not isinstance(box[0], (list, tuple)): + # Use the same box for all images + box = [box for _ in range(len(images))] + if not isinstance(color, (list, tuple)): + color = [color for _ in range(len(images))] + if len(box) != len(images): + raise ValueError("The number of `box` must match the number of images.") + if len(color) != len(images): + raise ValueError("The number of `color` must match the number of images.") + + return images, text, box, color + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[GotOcr2ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text if `text` + is not `None`, otherwise encode default OCR queries which depends on the `format`, `box`, `color`, `multi_page` and + `crop_to_patches` arguments. To prepare the vision inputs, this method forwards the `images` and `kwrags` arguments to + GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + format (`bool`, *optional*): + If set, will add the format token to the query, and the model will return the OCR result with formatting. + box (`List[float]`, `List[Tuple[float, float]]`, `List[Tuple[float, float, float, float]]`, *optional*): + The box annotation to be added to the query. If a list of floats or a tuple of floats is provided, it + will be interpreted as [x1, y1, x2, y2]. If a list of tuples is provided, each tuple should be in the + form (x1, y1, x2, y2). + color (`str`, *optional*): + The color annotation to be added to the query. The model will return the OCR result within the box with + the specified color. + multi_page (`bool`, *optional*): + If set, will enable multi-page inference. The model will return the OCR result across multiple pages. + crop_to_patches (`bool`, *optional*): + If set, will crop the image to patches. The model will return the OCR result upon the patch reference. + min_patches (`int`, *optional*): + The minimum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to + `True`. + max_patches (`int`, *optional*): + The maximum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to + `True`. + + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + output_kwargs = self._merge_kwargs( + GotOcr2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + format = output_kwargs["text_kwargs"].pop("format") + num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens") + box = output_kwargs["images_kwargs"].pop("box", [None]) + color = output_kwargs["images_kwargs"].pop("color", None) + multi_page = output_kwargs["images_kwargs"].pop("multi_page") + crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches") + min_patches = output_kwargs["images_kwargs"].pop("min_patches") + max_patches = output_kwargs["images_kwargs"].pop("max_patches") + + self._check_call_arguments(images, box, color, multi_page, crop_to_patches) + images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page) + + # Load images as we need to know the image size + images = load_images(images) + if text is None: + text = [] + for index, (image_group, box_single, color_single) in enumerate(zip(images, box, color)): + if crop_to_patches: + image_group = self.image_processor.crop_image_to_patches( + image_group, + patch_size=output_kwargs["images_kwargs"].get("size"), + min_patches=min_patches, + max_patches=max_patches, + ) + images[index] = image_group + num_images = len(image_group) if (multi_page or crop_to_patches) else 1 + if box_single[0] is not None: + box_single = load_box_annotation(box_single, image_group.size) + query = ( + f"{f'[{color_single}] ' if color_single is not None else ''}" + f"{str(box_single) if box_single[0] is not None else ''} " + "OCR" + f"{' with format' if format else ''}" + f"{' across multi pages' if multi_page else ''}" + f"{' upon the patch reference' if crop_to_patches else ''}" + ": " + ) + prompt = ( + "<|im_start|>" + + self.system_query + + "<|im_end|>" + + "<|im_start|>user\n" + + self.img_start_token + + self.img_pad_token * num_image_tokens * num_images + + self.img_end_token + + "\n" + + query + + "<|im_end|><|im_start|>assistant\n" + ) + text.append(prompt) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + if multi_page or crop_to_patches: + # flatten images + images = [image for image_group in images for image in image_group] + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["GotOcr2Processor"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c6057088b7d506..16dec56c50bb54 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4515,6 +4515,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class GotOcr2ForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class GotOcr2Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class GotOcr2PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class GPT2DoubleHeadsModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 3ebda4404aae9c..db7b1b9f2e66b8 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -282,6 +282,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class GotOcr2ImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class GroundingDinoImageProcessor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index bf56578a164c94..4143e5fafbf527 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1613,7 +1613,7 @@ def test_generate_from_inputs_embeds(self, _, num_beams): # There are a few exception patterns in this test: # 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed requires_inputs_ids = any( - model_name in model_class.__name__.lower() for model_name in ["idefics", "qwen2vl"] + model_name in model_class.__name__.lower() for model_name in ["idefics", "qwen2vl", "gotocr2"] ) # 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex # than calling the embedding layer with `input_ids`. Subcases of this exception: @@ -1633,7 +1633,7 @@ def test_generate_from_inputs_embeds(self, _, num_beams): inputs_dict.pop("pixel_values_images", None) # 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds` has_complex_embeds_computation = any( - model_name in model_class.__name__.lower() for model_name in ["moshi", "qwen2vl"] + model_name in model_class.__name__.lower() for model_name in ["moshi", "qwen2vl", "gotocr2"] ) # 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate, # we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input. diff --git a/tests/models/got_ocr2/__init__.py b/tests/models/got_ocr2/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/got_ocr2/test_image_processing_got_ocr2.py b/tests/models/got_ocr2/test_image_processing_got_ocr2.py new file mode 100644 index 00000000000000..c4e75feee660db --- /dev/null +++ b/tests/models/got_ocr2/test_image_processing_got_ocr2.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_vision_available(): + from transformers import GotOcr2ImageProcessor + + +class GotOcr2ImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=None, + do_normalize=True, + do_pad=False, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + do_convert_rgb=True, + ): + super().__init__() + size = size if size is not None else {"height": 20, "width": 20} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_pad = do_pad + self.do_convert_rgb = do_convert_rgb + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + "do_pad": self.do_pad, + } + + def expected_output_image_shape(self, images): + return self.num_channels, self.size["height"], self.size["width"] + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = GotOcr2ImageProcessor if is_vision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = GotOcr2ImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processor = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processor, "do_resize")) + self.assertTrue(hasattr(image_processor, "size")) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "image_mean")) + self.assertTrue(hasattr(image_processor, "image_std")) + self.assertTrue(hasattr(image_processor, "do_convert_rgb")) + + def test_crop_to_patches(self): + image_processor = self.image_processing_class(**self.image_processor_dict) + image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)[0] + processed_images = image_processor.crop_image_to_patches(image, 1, 6, use_thumbnail=True) + self.assertEqual(len(processed_images), 5) + self.assertEqual(processed_images[0].size, (20, 20)) diff --git a/tests/models/got_ocr2/test_modeling_got_ocr2.py b/tests/models/got_ocr2/test_modeling_got_ocr2.py new file mode 100644 index 00000000000000..beccbab74fae9c --- /dev/null +++ b/tests/models/got_ocr2/test_modeling_got_ocr2.py @@ -0,0 +1,352 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch GotOcr2 model.""" + +import unittest + +from transformers import ( + AutoProcessor, + GotOcr2Config, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import cleanup, require_torch, slow, torch_device + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + GotOcr2ForConditionalGeneration, + ) + + +if is_vision_available(): + from transformers.image_utils import load_image + + +class GotOcr2VisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=3, + seq_length=7, + num_channels=3, + ignore_index=-100, + image_size=64, + bos_token_id=0, + eos_token_id=0, + pad_token_id=0, + image_token_id=1, + hidden_act="silu", + hidden_size=128, + vocab_size=99, + intermediate_size=37, + max_position_embeddings=512, + max_window_layers=3, + model_type="got_ocr2", + num_attention_heads=4, + num_hidden_layers=4, + num_key_value_heads=2, + rope_theta=10000, + tie_word_embeddings=True, + is_training=True, + vision_config={ + "num_hidden_layers": 2, + "output_channels": 64, + "hidden_act": "quick_gelu", + "hidden_size": 32, + "mlp_ratio": 4, + "num_attention_heads": 4, + "patch_size": 2, + "image_size": 64, + }, + rope_scaling={"type": "mrope", "mrope_section": [2, 1, 1]}, + ): + self.parent = parent + self.ignore_index = ignore_index + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.image_token_id = image_token_id + self.hidden_act = hidden_act + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.max_window_layers = max_window_layers + self.model_type = model_type + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.num_key_value_heads = num_key_value_heads + self.rope_theta = rope_theta + self.tie_word_embeddings = tie_word_embeddings + self.vision_config = vision_config + self.rope_scaling = rope_scaling + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.is_training = is_training + self.vocab_size = vocab_size + self.num_image_tokens = 64 + self.seq_length = seq_length + self.num_image_tokens + + def get_config(self): + return GotOcr2Config( + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + hidden_act=self.hidden_act, + max_position_embeddings=self.max_position_embeddings, + vision_config=self.vision_config, + model_type=self.model_type, + max_window_layers=self.max_window_layers, + rope_scaling=self.rope_scaling, + tie_word_embeddings=self.tie_word_embeddings, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + image_token_id=self.image_token_id, + vocab_size=self.vocab_size, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + # input_ids[:, -1] = self.pad_token_id + input_ids[input_ids == self.image_token_id] = self.pad_token_id + input_ids[:, : self.num_image_tokens] = self.image_token_id + + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + def create_and_check_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask): + model = GotOcr2ForConditionalGeneration(config=config) + model.to(torch_device) + model.half() + model.eval() + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values.to(torch.bfloat16), + return_dict=True, + )["logits"] + self.parent.assertFalse(torch.isnan(logits).any().item()) + + def create_and_check_model_fp16_autocast_forward(self, config, input_ids, pixel_values, attention_mask): + config.torch_dtype = torch.float16 + model = GotOcr2ForConditionalGeneration(config=config) + model.to(torch_device) + model.eval() + with torch.autocast(device_type="cuda", dtype=torch.float16): + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values.to(torch.bfloat16), + return_dict=True, + )["logits"] + self.parent.assertFalse(torch.isnan(logits).any().item()) + + +@require_torch +class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (GotOcr2ForConditionalGeneration,) if is_torch_available() else () + all_generative_model_classes = (GotOcr2ForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "image-to-text": GotOcr2ForConditionalGeneration, + "image-text-to-text": GotOcr2ForConditionalGeneration, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + + def setUp(self): + self.model_tester = GotOcr2VisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=GotOcr2Config, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + @unittest.skip( + reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs" + ) + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @unittest.skip( + reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format" + ) + def test_past_key_values_format(self): + pass + + @unittest.skip( + reason="GotOcr2 needs a dynamic control flow to pass pixel values to the forward function only in the first generation step" + ) + def test_generate_compile_1_end_to_end(self): + pass + + +@require_torch +class GotOcr2IntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + def test_small_model_integration_test_got_ocr_stop_strings(self): + model_id = "yonigozlan/GOT-OCR-2.0-hf" + model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/fixtures_ocr/resolve/main/iam_picture.jpeg" + ) + + inputs = self.processor(image, return_tensors="pt").to(torch_device) + generate_ids = model.generate( + **inputs, + do_sample=False, + num_beams=1, + tokenizer=self.processor.tokenizer, + stop_strings="<|im_end|>", + max_new_tokens=4096, + ) + decoded_output = self.processor.decode( + generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True + ) + expected_output = "industre" + self.assertEqual(decoded_output, expected_output) + + @slow + def test_small_model_integration_test_got_ocr_format(self): + model_id = "yonigozlan/GOT-OCR-2.0-hf" + model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg" + ) + + inputs = self.processor(image, return_tensors="pt", format=True).to(torch_device) + generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4) + decoded_output = self.processor.decode( + generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True + ) + expected_output = "\\title{\nR" + self.assertEqual(decoded_output, expected_output) + + @slow + def test_small_model_integration_test_got_ocr_fine_grained(self): + model_id = "yonigozlan/GOT-OCR-2.0-hf" + model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" + ) + + inputs = self.processor(image, return_tensors="pt", color="green").to(torch_device) + generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4) + decoded_output = self.processor.decode( + generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True + ) + expected_output = "You should keep in" + self.assertEqual(decoded_output, expected_output) + + @slow + def test_small_model_integration_test_got_ocr_crop_to_patches(self): + model_id = "yonigozlan/GOT-OCR-2.0-hf" + model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png" + ) + + inputs = self.processor(image, return_tensors="pt", crop_to_patches=True).to(torch_device) + generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4) + decoded_output = self.processor.decode( + generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True + ) + expected_output = "on developing architectural improvements" + self.assertEqual(decoded_output, expected_output) + + @slow + def test_small_model_integration_test_got_ocr_multi_pages(self): + model_id = "yonigozlan/GOT-OCR-2.0-hf" + model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + image1 = load_image( + "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png" + ) + image2 = load_image( + "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" + ) + + inputs = self.processor([image1, image2], return_tensors="pt", multi_page=True).to(torch_device) + generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4) + decoded_output = self.processor.decode( + generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True + ) + expected_output = "on developing architectural improvements" + self.assertEqual(decoded_output, expected_output) + + @slow + def test_small_model_integration_test_got_ocr_batched(self): + model_id = "yonigozlan/GOT-OCR-2.0-hf" + model = GotOcr2ForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + image1 = load_image( + "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png" + ) + image2 = load_image( + "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg" + ) + + inputs = self.processor([image1, image2], return_tensors="pt").to(torch_device) + generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4) + decoded_output = self.processor.batch_decode( + generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True + ) + expected_output = ["Reducing the number", "R&D QUALITY"] + self.assertEqual(decoded_output, expected_output) diff --git a/tests/models/got_ocr2/test_processor_got_ocr2.py b/tests/models/got_ocr2/test_processor_got_ocr2.py new file mode 100644 index 00000000000000..9a34c6404ba590 --- /dev/null +++ b/tests/models/got_ocr2/test_processor_got_ocr2.py @@ -0,0 +1,77 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import tempfile +import unittest + +from transformers import AutoProcessor, GotOcr2Processor, PreTrainedTokenizerFast +from transformers.testing_utils import require_vision +from transformers.utils import is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from transformers import GotOcr2ImageProcessor + + +@require_vision +class GotOcr2ProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = GotOcr2Processor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + + image_processor = GotOcr2ImageProcessor() + tokenizer = PreTrainedTokenizerFast.from_pretrained("yonigozlan/GOT-OCR-2.0-hf") + processor_kwargs = self.prepare_processor_dict() + processor = GotOcr2Processor(image_processor, tokenizer, **processor_kwargs) + processor.save_pretrained(self.tmpdirname) + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def test_ocr_queries(self): + processor = self.get_processor() + image_input = self.prepare_image_inputs() + inputs = processor(image_input, return_tensors="pt") + self.assertEqual(inputs["input_ids"].shape, (1, 286)) + self.assertEqual(inputs["pixel_values"].shape, (1, 3, 384, 384)) + + inputs = processor(image_input, return_tensors="pt", format=True) + self.assertEqual(inputs["input_ids"].shape, (1, 288)) + self.assertEqual(inputs["pixel_values"].shape, (1, 3, 384, 384)) + + inputs = processor(image_input, return_tensors="pt", color="red") + self.assertEqual(inputs["input_ids"].shape, (1, 290)) + self.assertEqual(inputs["pixel_values"].shape, (1, 3, 384, 384)) + + inputs = processor(image_input, return_tensors="pt", box=[0, 0, 100, 100]) + self.assertEqual(inputs["input_ids"].shape, (1, 303)) + self.assertEqual(inputs["pixel_values"].shape, (1, 3, 384, 384)) + + inputs = processor([image_input, image_input], return_tensors="pt", multi_page=True, format=True) + self.assertEqual(inputs["input_ids"].shape, (1, 547)) + self.assertEqual(inputs["pixel_values"].shape, (2, 3, 384, 384)) + + inputs = processor(image_input, return_tensors="pt", crop_to_patches=True) + self.assertEqual(inputs["input_ids"].shape, (1, 1826)) + self.assertEqual(inputs["pixel_values"].shape, (7, 3, 384, 384)) diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index 93ed33ae774458..9ce86950d5f391 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -328,7 +328,7 @@ def test_beam_search_low_memory(self): pass @unittest.skip( - reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs" + reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs" ) def test_generate_from_inputs_embeds_with_static_cache(self): pass diff --git a/utils/check_repo.py b/utils/check_repo.py index 3dbe59f192293a..e2c8b1b54f0fc3 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -139,6 +139,7 @@ "Qwen2VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2VLForConditionalGeneration. "MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests "MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests + "GotOcr2Model", # Building part of bigger (tested) model. Tested implicitly through GotOcr2ForConditionalGeneration. ] )