Skip to content

Commit

Permalink
Latexocr paddle (PaddlePaddle#13401)
Browse files Browse the repository at this point in the history
* commit_test

* modified:   configs/rec/rec_latex_ocr.yml
	deleted:    ppocr/modeling/backbones/rec_resnetv2.py

* ntuple_solve

* style

* style

* style

* style

* style

* style

* style

* style

* style

* delete comment

* cla_email
  • Loading branch information
liuhongen1234567 authored Jul 22, 2024
1 parent c556b90 commit cf26f23
Show file tree
Hide file tree
Showing 34 changed files with 4,442 additions and 1 deletion.
126 changes: 126 additions & 0 deletions configs/rec/rec_latex_ocr.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
Global:
use_gpu: True
epoch_num: 500
log_smooth_window: 20
print_batch_step: 100
save_model_dir: ./output/rec/latex_ocr/
save_epoch_step: 5
max_seq_len: 512
# evaluation is run every 60000 iterations (22 epoch)(batch_size = 56)
eval_batch_step: [0, 60000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/datasets/pme_demo/0000013.png
infer_mode: False
use_space_char: False
rec_char_dict_path: ppocr/utils/dict/latex_ocr_tokenizer.json
save_res_path: ./output/rec/predicts_latexocr.txt

Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Const
learning_rate: 0.0001

Architecture:
model_type: rec
algorithm: LaTeXOCR
in_channels: 1
Transform:
Backbone:
name: HybridTransformer
img_size: [192, 672]
patch_size: 16
num_classes: 0
embed_dim: 256
depth: 4
num_heads: 8
input_channel: 1
is_predict: False
is_export: False
Head:
name: LaTeXOCRHead
pad_value: 0
is_export: False
decoder_args:
attn_on_attn: True
cross_attend: True
ff_glu: True
rel_pos_bias: False
use_scalenorm: False

Loss:
name: LaTeXOCRLoss

PostProcess:
name: LaTeXOCRDecode
rec_char_dict_path: ppocr/utils/dict/latex_ocr_tokenizer.json

Metric:
name: LaTeXOCRMetric
main_indicator: exp_rate
cal_blue_score: False

Train:
dataset:
name: LaTeXOCRDataSet
data: ./train_data/LaTeXOCR/latexocr_train.pkl
min_dimensions: [32, 32]
max_dimensions: [672, 192]
batch_size_per_pair: 56
keep_smaller_batches: False
transforms:
- DecodeImage:
channel_first: False
- MinMaxResize:
min_dimensions: [32, 32]
max_dimensions: [672, 192]
- LatexTrainTransform:
bitmap_prob: .04
- NormalizeImage:
mean: [0.7931, 0.7931, 0.7931]
std: [0.1738, 0.1738, 0.1738]
order: 'hwc'
- LatexImageFormat:
- KeepKeys:
keep_keys: ['image']
loader:
shuffle: True
batch_size_per_card: 1
drop_last: False
num_workers: 0
collate_fn: LaTeXOCRCollator

Eval:
dataset:
name: LaTeXOCRDataSet
data: ./train_data/LaTeXOCR/latexocr_val.pkl
min_dimensions: [32, 32]
max_dimensions: [672, 192]
batch_size_per_pair: 10
keep_smaller_batches: True
transforms:
- DecodeImage:
channel_first: False
- MinMaxResize:
min_dimensions: [32, 32]
max_dimensions: [672, 192]
- LatexTestTransform:
- NormalizeImage:
mean: [0.7931, 0.7931, 0.7931]
std: [0.1738, 0.1738, 0.1738]
order: 'hwc'
- LatexImageFormat:
- KeepKeys:
keep_keys: ['image']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1
num_workers: 0
collate_fn: LaTeXOCRCollator
Binary file added doc/datasets/pme_demo/0000013.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added doc/datasets/pme_demo/0000295.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added doc/datasets/pme_demo/0000562.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions doc/doc_ch/algorithm_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,21 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广

已支持的公式识别算法列表(戳链接获取使用教程):
- [x] [CAN](./algorithm_rec_can.md)
- [x] [LaTeX-OCR](./algorithm_rec_latex_ocr.md)

在CROHME手写公式数据集上,算法效果如下:

|模型 |骨干网络|配置文件|ExpRate|下载链接|
| ----- | ----- | ----- | ----- | ----- |
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_d28_can_train.tar)|

在LaTeX-OCR印刷公式数据集上,算法效果如下:

| 模型 | 骨干网络 |配置文件 | BLEU score | normed edit distance | ExpRate |下载链接|
|-----------|------------| ----- |:-----------:|:---------------------:|:---------:| ----- |
| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)|


<a name="2"></a>

## 2. 端到端算法
Expand Down
171 changes: 171 additions & 0 deletions doc/doc_ch/algorithm_rec_latex_ocr.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# 印刷数学公式识别算法-LaTeX-OCR

- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 pickle 标签文件生成](#3-1)
- [3.2 训练](#3-2)
- [3.3 评估](#3-3)
- [3.4 预测](#3-4)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
## 1. 算法简介

原始项目:
> [https://github.com/lukas-blecher/LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)


<a name="model"></a>
`LaTeX-OCR`使用[`LaTeX-OCR印刷公式数据集`](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO)进行训练,在对应测试集上的精度如下:

| 模型 | 骨干网络 |配置文件 | BLEU score | normed edit distance | ExpRate |下载链接|
|-----------|------------| ----- |:-----------:|:---------------------:|:---------:| ----- |
| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)|

<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。

<a name="3"></a>
## 3. 模型训练、评估、预测

<a name="3-1"></a>

### 3.1 pickle 标签文件生成
[谷歌云盘](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO)中下载 formulae.zip 和 math.txt,之后,使用如下命令,生成 pickle 标签文件。

```shell
# 创建 LaTeX-OCR 数据集目录
mkdir -p train_data/LaTeXOCR
# 解压formulae.zip ,并拷贝math.txt
unzip -d train_data/LaTeXOCR path/formulae.zip
cp path/math.txt train_data/LaTeXOCR
# 将原始的 .txt 文件转换为 .pkl 文件,从而对不同尺度的图像进行分组
# 训练集转换
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/train --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
# 验证集转换
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/val --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
# 测试集转换
python ppocr/utils/formula_utils/math_txt2pkl.py --image_dir=train_data/LaTeXOCR/test --mathtxt_path=train_data/LaTeXOCR/math.txt --output_dir=train_data/LaTeXOCR/
```

### 3.2 模型训练

请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`LaTeX-OCR`识别模型时需要**更换配置文件**`LaTeX-OCR`[配置文件](../../configs/rec/rec_latex_ocr.yml)

#### 启动训练


具体地,在完成数据准备后,便可以启动训练,训练命令如下:
```shell
#单卡训练 (默认训练方式)
python3 tools/train.py -c configs/rec/rec_latex_ocr.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_latex_ocr.yml
```

**注意:**

- 默认每训练22个epoch(60000次iteration)进行1次评估,若您更改训练的batch_size,或更换数据集,请在训练时作出如下修改
```
python3 tools/train.py -c configs/rec/rec_latex_ocr.yml -o Global.eval_batch_step=[0,{length_of_dataset//batch_size*22}]
```

<a name="3-2"></a>
### 3.3 评估

可下载已训练完成的[模型文件](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar),使用如下命令进行评估:

```shell
# 注意将pretrained_model的路径设置为本地路径。若使用自行训练保存的模型,请注意修改路径和文件名为{path/to/weights}/{model_name}。
# 验证集评估
python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True
# 测试集评估
python3 tools/eval.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Metric.cal_blue_score=True Eval.dataset.data=./train_data/LaTeXOCR/latexocr_test.pkl
```

<a name="3-3"></a>
### 3.4 预测

使用如下命令进行单张图片预测:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/infer_rec.py -c configs/rec/rec_latex_ocr.yml -o Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True Global.infer_img='./doc/datasets/pme_demo/0000013.png' Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/datasets/pme_demo/'。
```

<a name="4"></a>
## 4. 推理部署

<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar) ),可以使用如下命令进行转换:

```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c configs/rec/rec_latex_ocr.yml -o Global.pretrained_model=./rec_latex_ocr_train/best_accuracy.pdparams Global.save_inference_dir=./inference/rec_latex_ocr_infer/ Architecture.Backbone.is_predict=True Architecture.Backbone.is_export=True Architecture.Head.is_export=True

# 目前的静态图模型支持的最大输出长度为512
```
**注意:**
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请检查配置文件中的`rec_char_dict_path`是否为所需要的字典文件。
- [转换后模型下载地址](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_infer.tar)

转换成功后,在目录下有三个文件:
```
/inference/rec_latex_ocr_infer/
├── inference.pdiparams # 识别inference模型的参数文件
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
└── inference.pdmodel # 识别inference模型的program文件
```

执行如下命令进行模型推理:

```shell
python3 tools/infer/predict_rec.py --image_dir='./doc/datasets/pme_demo/0000295.png' --rec_algorithm="LaTeXOCR" --rec_batch_num=1 --rec_model_dir="./inference/rec_latex_ocr_infer/" --rec_char_dict_path="./ppocr/utils/dict/latex_ocr_tokenizer.json"

# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/datasets/pme_demo/'。
```
&nbsp;

![测试图片样例](../datasets/pme_demo/0000295.png)

执行命令后,上面图像的预测结果(识别的文本)会打印到屏幕上,示例如下:
```shell
Predicts of ./doc/datasets/pme_demo/0000295.png:\zeta_{0}(\nu)=-{\frac{\nu\varrho^{-2\nu}}{\pi}}\int_{\mu}^{\infty}d\omega\int_{C_{+}}d z{\frac{2z^{2}}{(z^{2}+\omega^{2})^{\nu+1}}}{\tilde{\Psi}}(\omega;z)e^{i\epsilon z}~~~,
```


**注意**

- 需要注意预测图像为**白底黑字**,即手写公式部分为黑色,背景为白色的图片。
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中 LaTeX-OCR 的预处理为您的预处理方法。


<a name="4-2"></a>
### 4.2 C++推理部署

由于C++预处理后处理还未支持 LaTeX-OCR,所以暂未支持

<a name="4-3"></a>
### 4.3 Serving服务化部署

暂不支持

<a name="4-4"></a>
### 4.4 更多推理部署

暂不支持

<a name="5"></a>
## 5. FAQ

1. LaTeX-OCR 数据集来自于[LaTeXOCR源repo](https://github.com/lukas-blecher/LaTeX-OCR)
9 changes: 9 additions & 0 deletions doc/doc_en/algorithm_overview_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ On the TextZoom public dataset, the effect of the algorithm is as follows:
Supported formula recognition algorithms (Click the link to get the tutorial):

- [x] [CAN](./algorithm_rec_can_en.md)
- [x] [LaTeX-OCR](./algorithm_rec_latex_ocr_en.md)


On the CROHME handwritten formula dataset, the effect of the algorithm is as follows:

Expand All @@ -145,6 +147,13 @@ On the CROHME handwritten formula dataset, the effect of the algorithm is as fol
|CAN|DenseNet|[rec_d28_can.yml](../../configs/rec/rec_d28_can.yml)|51.72%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_d28_can_train.tar)|


On the LaTeX-OCR printed formula dataset, the effect of the algorithm is as follows:

| Model | Backbone |config| BLEU score | normed edit distance | ExpRate |Download link|
|-----------|----------| ---- |:-----------:|:---------------------:|:---------:| ----- |
| LaTeX-OCR | Hybrid ViT |[rec_latex_ocr.yml](../../configs/rec/rec_latex_ocr.yml)| 0.8821 | 0.0823 | 40.01% |[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_latex_ocr_train.tar)|


<a name="2"></a>

## 2. End-to-end OCR Algorithms
Expand Down
Loading

0 comments on commit cf26f23

Please sign in to comment.