Skip to content

Commit

Permalink
Add model doc
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Nov 25, 2024
1 parent a1556dd commit f8e1ac9
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 41 deletions.
221 changes: 214 additions & 7 deletions docs/source/en/model_doc/got_ocr2.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<!--Copyright 2024 The Qwen Team and The HuggingFace Team. All rights reserved.
<!--Copyright 2024 StepFun and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
Expand All @@ -18,20 +18,227 @@ rendered properly in your Markdown viewer.

## Overview

The GOT-OCR2 model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>
The GOT-OCR2 model was proposed in [General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model](https://arxiv.org/abs/2409.01704) by Haoran Wei, Chenglong Liu, Jinyue Chen, Jia Wang, Lingyu Kong, Yanming Xu, Zheng Ge, Liang Zhao, Jianjian Sun, Yuang Peng, Chunrui Han, Xiangyu Zhang.

The abstract from the paper is the following:

*<INSERT PAPER ABSTRACT HERE>*
*Traditional OCR systems (OCR-1.0) are increasingly unable to meet people’snusage due to the growing demand for intelligent processing of man-made opticalncharacters. In this paper, we collectively refer to all artificial optical signals (e.g., plain texts, math/molecular formulas, tables, charts, sheet music, and even geometric shapes) as "characters" and propose the General OCR Theory along with an excellent model, namely GOT, to promote the arrival of OCR-2.0. The GOT, with 580M parameters, is a unified, elegant, and end-to-end model, consisting of a high-compression encoder and a long-contexts decoder. As an OCR-2.0 model, GOT can handle all the above "characters" under various OCR tasks. On the input side, the model supports commonly used scene- and document-style images in slice and whole-page styles. On the output side, GOT can generate plain or formatted results (markdown/tikz/smiles/kern) via an easy prompt. Besides, the model enjoys interactive OCR features, i.e., region-level recognition guided by coordinates or colors. Furthermore, we also adapt dynamic resolution and multipage OCR technologies to GOT for better practicality. In experiments, we provide sufficient results to prove the superiority of our model.*

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/got_ocr_overview.png"
alt="drawing" width="600"/>

<small> GOT-OCR2 training stages. Taken from the <a href="https://arxiv.org/abs/2409.01704">original paper.</a> </small>


Tips:

<INSERT TIPS ABOUT MODEL HERE>
GOT-OCR2 works on a wide range of tasks, including plain document OCR, scene text OCR, formatted document OCR, and even OCR for tables, charts, mathematical formulas, geometric shapes, molecular formulas and sheet music. While this implementation of the model will only output plain text, the outputs can be further processed to render the desired format, with packages like `pdftex`, `mathpix`, `matplotlib`, `tikz`, `verovio` or `pyecharts`.
The model can also be used for interactive OCR, where the user can specify the region to be recognized by providing the coordinates or the color of the region's bounding box.

This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan).
The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2.0).

## Usage example

### Plain text inference

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda")
>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
>>> inputs = processor(image, return_tensors="pt").to("cuda", dtype=model.dtype)

>>> generate_ids = model.generate(
>>> **inputs,
>>> do_sample=False,
>>> tokenizer=processor.tokenizer,
>>> stop_strings="<|im_end|>",
>>> max_new_tokens=4096,
>>> )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"R&D QUALITY IMPROVEMENT\nSUGGESTION/SOLUTION FORM\nName/Phone Ext. : (...)"
```

### Plain text inference batched

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf")
>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf")

>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"

>>> inputs = processor([image1, image2], return_tensors="pt")

>>> generate_ids = model.generate(
>>> **inputs,
>>> do_sample=False,
>>> tokenizer=processor.tokenizer,
>>> stop_strings="<|im_end|>",
>>> max_new_tokens=4,
>>> )

>>> processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
["Reducing the number", "R&D QUALITY"]
```

### Formatted text inference

GOT-OCR2 can also generate formatted text, such as markdown or LaTeX. Here is an example of how to generate formatted text:

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda")
>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/latex.png"
>>> inputs = processor(image, return_tensors="pt", format=True).to("cuda", dtype=model.dtype)

>>> generate_ids = model.generate(
>>> **inputs,
>>> do_sample=False,
>>> tokenizer=processor.tokenizer,
>>> stop_strings="<|im_end|>",
>>> max_new_tokens=4096,
>>> )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"\\author{\nHanwen Jiang* \\(\\quad\\) Arjun Karpur \\({ }^{\\dagger} \\quad\\) Bingyi Cao \\({ }^{\\dagger} \\quad\\) (...)"
```

### Inference on multiple pages

Although it might be reasonable in most cases to use a “for loop” for multi-page processing, some text data with formatting across several pages make it necessary to process all pages at once. GOT introduces a multi-page OCR (without “for loop”) feature, where multiple pages can be processed by the model at once, whith the output being one continuous text.
Here is an example of how to process multiple pages at once:


```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda")
>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf")

>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page1.png"
>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page2.png"
>>> inputs = processor([image1, image2], return_tensors="pt", format=True).to("cuda", dtype=model.dtype)

>>> generate_ids = model.generate(
>>> **inputs,
>>> do_sample=False,
>>> tokenizer=processor.tokenizer,
>>> stop_strings="<|im_end|>",
>>> max_new_tokens=4096,
>>> )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"\\title{\nGeneral OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model\n}\n\\author{\nHaoran Wei (...)"
```

### Inference on cropped patches

GOT supports a 1024×1024 input resolution, which is sufficient for most OCR tasks, such as scene OCR or processing A4-sized PDF pages. However, certain scenarios, like horizontally stitched two-page PDFs commonly found in academic papers or images with unusual aspect ratios, can lead to accuracy issues when processed as a single image. To address this, GOT can dynamically crop an image into patches, process them all at once, and merge the results for better accuracy with such inputs.
Here is an example of how to process cropped patches:

```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf", torch_dtype=torch.bfloat16).to("cuda")
>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png"
>>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3).to("cuda", dtype=model.dtype)

>>> generate_ids = model.generate(
>>> **inputs,
>>> do_sample=False,
>>> tokenizer=processor.tokenizer,
>>> stop_strings="<|im_end|>",
>>> max_new_tokens=4096,
>>> )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"on developing architectural improvements to make learnable matching methods generalize.\nMotivated by the above observations, (...)"
```

### Inference on a specific region

GOT supports interactive OCR, where the user can specify the region to be recognized by providing the coordinates or the color of the region's bounding box. Here is an example of how to process a specific region:

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda")
>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
>>> inputs = processor(image, return_tensors="pt", color="green") # or box=[x1, y1, x2, y2] for coordinates (image pixels)
>>> inputs = inputs.to("cuda", dtype=model.dtype)

>>> generate_ids = model.generate(
>>> **inputs,
>>> do_sample=False,
>>> tokenizer=processor.tokenizer,
>>> stop_strings="<|im_end|>",
>>> max_new_tokens=4096,
>>> )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"You should keep in mind what features from the module should be used, especially \nwhen you’re planning to sell a template."
```

### Inference on general OCR data example: sheet music

Although this implementation of the model will only output plain text, the outputs can be further processed to render the desired format, with packages like `pdftex`, `mathpix`, `matplotlib`, `tikz`, `verovio` or `pyecharts`.
Here is an example of how to process sheet music:

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import verovio

>>> model = AutoModelForImageTextToText.from_pretrained("yonigozlan/GOT-OCR-2.0-hf").to("cuda")
>>> processor = AutoProcessor.from_pretrained("yonigozlan/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/sheet_music.png"
>>> inputs = processor(image, return_tensors="pt", format=True).to("cuda", dtype=model.dtype)

This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
>>> generate_ids = model.generate(
>>> **inputs,
>>> do_sample=False,
>>> tokenizer=processor.tokenizer,
>>> stop_strings="<|im_end|>",
>>> max_new_tokens=4096,
>>> )

>>> outputs = processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
>>> tk = verovio.toolkit()
>>> tk.loadData(outputs)
>>> tk.setOptions(
>>> {
>>> "pageWidth": 2100,
>>> "pageHeight": 800,
>>> "footer": "none",
>>> "barLineWidth": 0.5,
>>> "beamMaxSlope": 15,
>>> "staffLineWidth": 0.2,
>>> "spacingStaff": 6,
>>> }
>>> )
>>> tk.getPageCount()
>>> svg = tk.renderToSVG()
>>> svg = svg.replace('overflow="inherit"', 'overflow="visible"')
>>> with open("output.svg", "w") as f:
>>> f.write(svg)
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/sheet_music.svg"
alt="drawing" width="600"/>

## GotOcr2Config

Expand Down
18 changes: 9 additions & 9 deletions docs/source/en/model_doc/qwen2_vl.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ rendered properly in your Markdown viewer.

## Overview

The [Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/) model is a major update to [Qwen-VL](https://arxiv.org/pdf/2308.12966) from the Qwen team at Alibaba Research.
The [Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/) model is a major update to [Qwen-VL](https://arxiv.org/pdf/2308.12966) from the Qwen team at Alibaba Research.

The abstract from the blog is the following:

Expand Down Expand Up @@ -231,7 +231,7 @@ In case of limited GPU RAM, one can reduce the resolution as follows:

```python
min_pixels = 256*28*28
max_pixels = 1024*28*28
max_pixels = 1024*28*28
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
```
This ensures each image gets encoded using a number between 256-1024 tokens. The 28 comes from the fact that the model uses a patch size of 14 and a temporal patch size of 2 (14 x 2 = 28).
Expand All @@ -245,7 +245,7 @@ conversation = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "image"},
{"type": "text", "text": "Hello, how are you?"}
]
},
Expand All @@ -256,10 +256,10 @@ conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "Can you describe these images and video?"},
{"type": "image"},
{"type": "image"},
{"type": "video"},
{"type": "text", "text": "Can you describe these images and video?"},
{"type": "image"},
{"type": "image"},
{"type": "video"},
{"type": "text", "text": "These are from my vacation."}
]
},
Expand Down Expand Up @@ -300,8 +300,8 @@ To load and run a model using Flash Attention-2, simply add `attn_implementation
from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-7B-Instruct",
torch_dtype=torch.bfloat16,
"Qwen/Qwen2-VL-7B-Instruct",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
```
Expand Down
41 changes: 30 additions & 11 deletions src/transformers/models/got_ocr2/convert_got_ocr2_weights_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,18 @@
GotOcr2Config,
GotOcr2ForConditionalGeneration,
GotOcr2ImageProcessor,
GotOcr2Processor,
PreTrainedTokenizerFast,
is_vision_available,
)
from transformers.convert_slow_tokenizer import TikTokenConverter
from transformers.tokenization_utils import AddedToken


if is_vision_available():
from transformers.image_utils import load_image


# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
# Vision encoder mapping
Expand Down Expand Up @@ -142,20 +148,34 @@ def write_model(
print("Loading the checkpoint in a GotOcr2ForConditionalGeneration model.")
model = GotOcr2ForConditionalGeneration(config)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
model = model.to(torch.bfloat16)
print("model dtype:", model.dtype)
print("Missing keys:", missing_keys)
print("Unexpected keys:", unexpected_keys)

print("Saving the model.")
model.save_pretrained(model_path)
if push_to_hub:
model.push_to_hub("yonigozlan/GotOcr2-hf", use_temp_dir=True)
model.push_to_hub("yonigozlan/GOT-OCR-2.0-hf", use_temp_dir=True)
del state_dict, model

# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
GotOcr2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
model = GotOcr2ForConditionalGeneration.from_pretrained(model_path, device_map="auto")
processor = GotOcr2Processor.from_pretrained(model_path)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
)

inputs = processor(image, return_tensors="pt", format=True).to(model.device, dtype=model.dtype)
generate_ids = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=4)
decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
expected_output = "\\title{\nR"
print("Decoded output:", decoded_output)
assert decoded_output == expected_output
print("Model reloaded successfully.")
del model

# generation config
if instruct:
Expand Down Expand Up @@ -253,7 +273,7 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False,
tokenizer.save_pretrained(save_dir)

if push_to_hub:
tokenizer.push_to_hub("yonigozlan/GotOcr2-hf", use_temp_dir=True)
tokenizer.push_to_hub("yonigozlan/GOT-OCR-2.0-hf", use_temp_dir=True)

if instruct:
print("Saving chat template...")
Expand All @@ -275,7 +295,7 @@ def write_image_processor(save_dir: str, push_to_hub: bool = False):

image_processor.save_pretrained(save_dir)
if push_to_hub:
image_processor.push_to_hub("yonigozlan/GotOcr2-hf", use_temp_dir=True)
image_processor.push_to_hub("yonigozlan/GOT-OCR-2.0-hf", use_temp_dir=True)


def main():
Expand All @@ -300,13 +320,6 @@ def main():
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
instruct=args.instruct,
push_to_hub=args.push_to_hub,
)

write_tokenizer(
tokenizer_path="qwen.tiktoken",
save_dir=args.output_dir,
Expand All @@ -318,6 +331,12 @@ def main():
save_dir=args.output_dir,
push_to_hub=args.push_to_hub,
)
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
instruct=args.instruct,
push_to_hub=args.push_to_hub,
)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit f8e1ac9

Please sign in to comment.