-
Notifications
You must be signed in to change notification settings - Fork 164
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* [Add] new pytorch model mt5 * [Fix] None value Constant * [Fix] torch None value * [Update] dockerfile for torch yolov5 * [Add] pytorch yolov5 * [Update] yolov5 model zoo * [Update] yolov5 req pip
- Loading branch information
Showing
10 changed files
with
508 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# 转换 PyTorch 的 Yolov5 至 PaddlePaddle 模型 | ||
|
||
我们可以使用 [ultralytics/yolov5](https://github.com/ultralytics/yolov5) 和 [model files](https://pytorch.org/hub/ultralytics_yolov5/) 进行模型的转换。 | ||
|
||
## 模型转换 | ||
|
||
以下为示例代码: | ||
|
||
|
||
``` python | ||
|
||
import torch | ||
import numpy as np | ||
|
||
# 1. load model | ||
torch_model = torch.hub.load( | ||
'../dataset/yolov5/yolov5', | ||
source='local', | ||
model='custom', | ||
path='../dataset/yolov5/yolov5s.pt') | ||
torch_model.eval() | ||
|
||
# 2. load input img | ||
img = np.load("../dataset/yolov5/input.npy") | ||
|
||
# 3. trace once | ||
# https://github.com/ultralytics/yolov5/issues/9341 | ||
# torch need to load yolo model twice!!! | ||
# so, we load one time before converting to paddle!!! | ||
try: | ||
torch.jit.trace(torch_model, torch.tensor(img)) | ||
except: | ||
pass | ||
|
||
|
||
# 4. convert model | ||
from x2paddle.convert import pytorch2paddle | ||
|
||
save_dir = "pd_model_trace" | ||
jit_type = "trace" | ||
|
||
pytorch2paddle(torch_model, | ||
save_dir, | ||
jit_type, | ||
[torch.tensor(img)], | ||
disable_feedback=True) | ||
|
||
``` | ||
|
||
大体可分为四个步骤: | ||
|
||
1. 加载模型 | ||
|
||
可以使用 [YOLOv5](https://pytorch.org/hub/ultralytics_yolov5/) 提供的加载模型方式,也可以采用上述代码中的本地加载方式。 | ||
|
||
这里使用接口 `torch.hub.load`,其中,各个参数的意义为: | ||
|
||
- `'../dataset/yolov5/yolov5'`,表示 `ultralytics/yolov5` 的 repo 所在的目录。因此,需要先手动把 [ultralytics/yolov5](https://github.com/ultralytics/yolov5) 拉取到相应的目录中。 | ||
- `source='local'`,表示使用本地 repo,对应上述第一个参数。 | ||
- `model='custom'`,表示此处使用本地的模型,也可以把 [model files](https://pytorch.org/hub/ultralytics_yolov5/) 中的模型下载到相应目录。 | ||
- `path='../dataset/yolov5/yolov5s.pt'`,表示本地模型文件的位置,对应上述第三个参数。 | ||
|
||
2. 加载测试数据 | ||
|
||
此处加载的数据类型为 `numpy`。注意,不同的数据类型,模型的输出结果类型不同。 | ||
|
||
3. 第一次 trace | ||
|
||
根据 https://github.com/ultralytics/yolov5/issues/9341 中的讨论,yolov5 模型在使用 `torch.jit.trace` 时可能会报错,而且需要 trace 两次! | ||
|
||
因此,此处务必先 trace 一次,否则后续转换出错。 | ||
|
||
4. 转换模型 | ||
|
||
直接调用接口转换模型即可。 | ||
|
||
由于 (3) 中提到的问题,因此,建议采用接口调用的方式进行模型转换,CLI 的命令模式转换方式可能出错。 | ||
|
||
另外,yolov5 的模型使用 `script` 方式转换也可能出错,请关注 PyTorch 与 [ultralytics/yolov5](https://github.com/ultralytics/yolov5) 是否有跟进此问题。 | ||
|
||
## 模型使用 | ||
|
||
转换成功后,会在当前目录生成 `pd_model_trace` ,相关模型结构代码与模型文件存储在此处,可直接使用。如: | ||
|
||
``` python | ||
|
||
import paddle | ||
import numpy as np | ||
|
||
img = np.load("../dataset/yolov5/input.npy") | ||
|
||
# trace | ||
paddle.enable_static() | ||
exe = paddle.static.Executor(paddle.CPUPlace()) | ||
[prog, inputs, outputs] = paddle.static.load_inference_model( | ||
path_prefix="pd_model_trace/inference_model/model", executor=exe) | ||
result = exe.run(prog, feed={inputs[0]: img}, fetch_list=outputs) | ||
|
||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# Converting PyTorch's Yolov5 to PaddlePaddle | ||
|
||
We can use [ultralytics/yolov5](https://github.com/ultralytics/yolov5) and [model files](https://pytorch.org/hub/ultralytics_yolov5/) for model conversion. | ||
|
||
## Model conversion | ||
|
||
The following is sample code: | ||
|
||
|
||
``` python | ||
|
||
|
||
import torch | ||
import numpy as np | ||
|
||
# 1. load model | ||
torch_model = torch.hub.load( | ||
'../dataset/yolov5/yolov5', | ||
source='local', | ||
model='custom', | ||
path='../dataset/yolov5/yolov5s.pt') | ||
torch_model.eval() | ||
|
||
# 2. load input img | ||
img = np.load("../dataset/yolov5/input.npy") | ||
|
||
# 3. trace once | ||
# https://github.com/ultralytics/yolov5/issues/9341 | ||
# torch need to load yolo model twice!!! | ||
# so, we load one time before converting to paddle!!! | ||
try: | ||
torch.jit.trace(torch_model, torch.tensor(img)) | ||
except: | ||
pass | ||
|
||
|
||
# 4. convert model | ||
from x2paddle.convert import pytorch2paddle | ||
|
||
save_dir = "pd_model_trace" | ||
jit_type = "trace" | ||
|
||
pytorch2paddle(torch_model, | ||
save_dir, | ||
jit_type, | ||
[torch.tensor(img)], | ||
disable_feedback=True) | ||
|
||
``` | ||
|
||
It can be roughly divided into four steps: | ||
|
||
1. Load the model | ||
|
||
The model can be loaded using the method provided by [YOLOv5](https://pytorch.org/hub/ultralytics_yolov5/) or locally. | ||
|
||
Here the interface `torch.hub.load` is used, where the meaning of each parameter is: | ||
|
||
- `'... /dataset/yolov5/yolov5'`, which indicates the directory where the repo for `ultralytics/yolov5` is located. Therefore, you need to manually pull [ultralytics/yolov5](https://github.com/ultralytics/yolov5) into the appropriate directory first. | ||
- `source='local'` means to use a local repo, which corresponds to the first parameter above. | ||
- `model='custom'`, means use local model, download the model in [model files](https://pytorch.org/hub/ultralytics_yolov5/) to the corresponding directory. | ||
- `path='... /dataset/yolov5/yolov5s.pt'`, indicates the location of the local model files, which corresponds to third parameter above. | ||
|
||
2. Loading test data | ||
|
||
The data type loaded here is `numpy`. Note that the model outputs different types of results for different data types. | ||
|
||
3. First trace | ||
|
||
According to the discussion in https://github.com/ultralytics/yolov5/issues/9341, the yolov5 model may report an error when using `torch.jit.trace` and needs to be traced twice! | ||
|
||
Therefore, it is important to trace once here, otherwise subsequent conversions will fail. | ||
|
||
4. Converting Models | ||
|
||
The interface of conversion model can be called directly. | ||
|
||
Because of the problems mentioned in (3), it is recommended to use the interface call to convert the model, as the CLI command mode conversion may cause errors. | ||
|
||
Yolov5 models can not be converted using `script`, please check with PyTorch and [ultralytics/yolov5](https://github.com/ultralytics/yolov5) to see if they have followed up on this issue. | ||
|
||
## Model usage | ||
|
||
After successful conversion, `pd_model_trace` will be generated in the current directory, where the related model structure code and model files are stored and can be used directly.For example: | ||
|
||
``` python | ||
|
||
import paddle | ||
import numpy as np | ||
|
||
img = np.load("../dataset/yolov5/input.npy") | ||
|
||
# trace | ||
paddle.enable_static() | ||
exe = paddle.static.Executor(paddle.CPUPlace()) | ||
[prog, inputs, outputs] = paddle.static.load_inference_model( | ||
path_prefix="pd_model_trace/inference_model/model", executor=exe) | ||
result = exe.run(prog, feed={inputs[0]: img}, fetch_list=outputs) | ||
|
||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import os | ||
import argparse | ||
import pickle | ||
import numpy as np | ||
import sys | ||
|
||
sys.path.append('../tools/') | ||
|
||
from predict import BenchmarkPipeline | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description='Model inference') | ||
parser.add_argument('--batch_size', | ||
dest='batch_size', | ||
help='Mini batch size of one gpu or cpu.', | ||
type=int, | ||
default=1) | ||
|
||
def str2bool(v): | ||
return v.lower() in ("true", "t", "1") | ||
|
||
parser.add_argument("--use_gpu", type=str2bool, default=True) | ||
parser.add_argument("--enable_trt", | ||
type=str2bool, | ||
default=True, | ||
help="enable trt") | ||
parser.add_argument("--cpu_threads", type=int, default=1) | ||
parser.add_argument("--enable_mkldnn", type=str2bool, default=True) | ||
return parser.parse_args() | ||
|
||
|
||
def main(args): | ||
data = np.load("../dataset/yolov5/input.npy") | ||
pytorch_result = np.load("../dataset/yolov5/output.npy") | ||
# trace | ||
benchmark_pipeline = BenchmarkPipeline( | ||
model_dir="pd_model_trace/inference_model/", | ||
model_name='yolov5_trace', | ||
use_gpu=args.use_gpu, | ||
enable_trt=args.enable_trt, | ||
cpu_threads=args.cpu_threads, | ||
enable_mkldnn=args.enable_mkldnn) | ||
benchmark_pipeline.run_benchmark(data=data, | ||
pytorch_result=pytorch_result, | ||
warmup=1, | ||
repeats=1) | ||
benchmark_pipeline.analysis_operators( | ||
model_dir="pd_model_trace/inference_model/") | ||
benchmark_pipeline.report() | ||
|
||
|
||
if __name__ == '__main__': | ||
args = parse_args() | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import torch | ||
import numpy as np | ||
|
||
torch_model = torch.hub.load('../dataset/yolov5/yolov5', | ||
model='custom', | ||
source='local', | ||
path='../dataset/yolov5/yolov5s.pt') | ||
torch_model.eval() | ||
|
||
img = np.load("../dataset/yolov5/input.npy") | ||
|
||
save_dir = "pd_model_trace" | ||
jit_type = "trace" | ||
|
||
# https://github.com/ultralytics/yolov5/issues/9341 | ||
# torch need to load yolo model twice!!! | ||
# so, we load one time before converting to paddle!!! | ||
try: | ||
torch.jit.trace(torch_model, torch.tensor(img)) | ||
except: | ||
pass | ||
|
||
from x2paddle.convert import pytorch2paddle | ||
|
||
pytorch2paddle(torch_model, | ||
save_dir, | ||
jit_type, [torch.tensor(img)], | ||
disable_feedback=True) |
Oops, something went wrong.