ptq english
tp-nan committed Aug 30, 2023
1 parent c6d9c10 commit 4e45d57
Showing 8 changed files with 134 additions and 125 deletions.
118 changes: 63 additions & 55 deletions docs/faq/onnx.mdx
@@ -6,46 +6,45 @@ type: explainer


:::note
As NVIDIA has shifted its focus to [`TorchTensorrt`](https://github.com/pytorch/TensorRT), [`torch2trt`](https://github.com/NVIDIA-AI-IOT/torch2trt) is no longer actively maintained, and TensorRT now treats ONNX as its first-class input format. In all of our known practices, TorchPipe can completely replace `torch2trt` through static ONNX composition, dynamic ONNX, pre-generated TensorRT models, and other methods.
:::

## Torch to ONNX Conversion

The framework primarily supports dynamic `batch`, or static `batch` with `batchsize==1`. In practice, some models cannot be converted to dynamic shapes, or the conversion is error-prone, so we also support [**loading multiple models with different static batch sizes at the same time**](../Intra-node/schedule#single_node_combine) to simulate dynamic batching (see the sketch below). The following instructions mainly apply to exporting models with a dynamic batch size.

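A minimal sketch of this static-batch alternative, assuming a `torchvision` ResNet-50 and arbitrary batch sizes of 1, 2, and 4 (the model, batch sizes, and file names are placeholders, not part of the original docs):

```python
import torch
from torchvision import models

model = models.resnet50(weights=None).eval()

# Export one ONNX file per static batch size; a scheduler can then combine them
# to approximate dynamic batching. No dynamic_axes are passed, so each file
# keeps a fixed batch size.
for bs in (1, 2, 4):
    dummy = torch.randn(bs, 3, 224, 224)
    torch.onnx.export(
        model,
        dummy,
        f"resnet50_bs{bs}.onnx",
        opset_version=17,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output_0"],
    )
```
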
:::caution Exporting Dynamic Batch Size Models
- The following operations make dynamic batch unavailable: ``x.view(int(x.size(0)), -1)``. Check whether the model code hard-codes the batch dimension, e.g. ``x.view(int(x.size(0)), -1, 1, 1)`` or ``x.reshape(int(x.size(0)), -1, 1, 1)``; such patterns can break dynamic batch after the ONNX conversion. Note that in Transformer-like networks, the batch dimension is not necessarily dimension 0.
- When the batch dimension is marked as dynamic, older TensorRT versions handle it less well and produce more redundant operators. For example, ``x.view(x.size(0), -1)`` introduces Gather and other operators into the ONNX graph to compute the first dimension of x. It can be rewritten as ``x = x.view(-1, int(x.size(1)*x.size(2)*x.size(3)))`` or ``x = torch.flatten(x, 1)``. This step is optional.
- For some models (TensorRT 8.5.1, LSTM, and Transformer), making the batch and non-batch dimensions dynamic at the same time may consume more resources:
- For LayerNorm layers and Transformer-like networks with a dynamic batch dimension, opset >= 17 and TensorRT >= 8.6.1 are recommended.

```bash
# When both the batch and non-batch dimensions are dynamic, inference takes 9 ms (input size optShapes=input:1x1000x80,mask:1x1x1000):
/opt/tensorrt/bin/trtexec --onnx=test_fp32.onnx --shapes=input:1x1000x80,mask:1x1x1000 --workspace=64000 \
--minShapes=input:1x20x80,mask:1x1x20 \
--optShapes=input:1x1000x80,mask:1x1x1000 \
--maxShapes=input:4x2000x80,mask:4x1x2000


# With batchsize==1 fixed, it takes only 4.6 ms:
/opt/tensorrt/bin/trtexec --onnx=test_fp32.onnx --shapes=input:1x1000x80,mask:1x1x1000 --workspace=64000 \
--minShapes=input:1x20x80,mask:1x1x20 \
--optShapes=input:1x1000x80,mask:1x1x1000 \
--maxShapes=input:1x2000x80,mask:1x1x2000
```
In this case, it is recommended to **discretize only one of the dimensions**.

:::

:::tip Best Practices
- Whenever possible, keep the batch dimension as dimension 0 and leave its length at the default (i.e., -1), so that redundant operators can be removed.
- Optimize the model with onnx-simplify.
- [A smaller optimization range usually means faster speed and lower resource consumption](https://github.com/NVIDIA/TensorRT/issues/1166#issuecomment-815551064).
:::


After modifying the network, you can use the following code to convert the PyTorch model to an ONNX model:

```python
x = torch.randn(1,*input_shape).cuda()
@@ -60,8 +59,8 @@
torch.onnx.export(torch_model,
onnx_save_path,
opset_version=17,
do_constant_folding=True,
input_names=["input"], # 输入名
output_names=[f"output_{i}" for i in range(out_size)], # 输出名
input_names=["input"], # input name
output_names=[f"output_{i}" for i in range(out_size)], # output names
dynamic_axes=out)

import onnx
@@ -73,59 +72,68 @@
model_simp, check = onnx_simplifier.simplify(onnx_model, check_n = 0)
onnx.save(model_simp, onnx_save_path)
```
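The `dynamic_axes=out` argument above refers to a dictionary built in code that is elided in this diff. As an illustration only (the axis names and the `out_size` value are assumptions, not the original code), such a mapping typically marks dimension 0 of each input and output as dynamic:

```python
# Hypothetical dynamic-axes mapping for torch.onnx.export:
# only dimension 0 (the batch dimension) is marked as dynamic.
out_size = 1  # assumed number of model outputs
out = {"input": {0: "batch_size"}}
for i in range(out_size):
    out[f"output_{i}"] = {0: "batch_size"}
```

Additional axes (e.g. a sequence-length dimension) can also be marked dynamic, but as noted in the caution above this may cost more resources.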

<details><summary>为了方便,针对这步我们提供了torchpipe.utils.models.onnx_export小工具</summary>
<details><summary>`torchpipe.utils.models.onnx_export`(effective from 0.3.2b1):</summary>

- This tool converts a PyTorch model to an ONNX model and saves it locally. It supports only a single input.
- It supports dynamic batch and applies onnx-simplify optimization.


```python
def onnx_export(model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction], onnx_path, input = None, opset = 17):
```
:::tip Parameters
- **model** - PyTorch model.
- **onnx_path** - Path to save the ONNX model.
- **input** - Model input. Defaults to torch.randn(1,3,224,224) if not set.
- **opset** - ONNX opset.
:::

<details><summary>Example Code</summary>

```python
import os, tempfile
from torchvision import models
import torch
import torchpipe

## export onnx
m = models.resnet50(weights=None).eval()
onnx_path = os.path.join("/tmp", f"resnet50.onnx")
input = torch.randn(1, 3, 224, 224)
torchpipe.utils.models.onnx_export(m, onnx_path, input, opset=17)

onnx_path = os.path.join(tempfile.gettempdir(), "resnet50.onnx")
torchpipe.utils.models.onnx_export(m, onnx_path, torch.randn(1, 3, 224, 224), opset=17)
```
</details>
</details>

### Troubleshooting Conversion Failures


Conversion from PyTorch to ONNX often fails. The following approaches may help:


- Keep dynamic dimensions dynamic. For example, for YOLOX:

```python
x = x.view(int(x.size(0)), -1, 1, 1)
# Change to:
x = x.flatten(1).unsqueeze(2).unsqueeze(2)

x = x.view(int(x.size(0)), -1)
# Change to:
x = x.view(-1, int(x.size(1)*x.size(2)*x.size(3)))
```

- Change boolean values to float:

```python
tgt_padding_mask = (tgt_in == self.eos_id)
# Change to:
tgt_padding_mask = (tgt_in == self.eos_id).float()
```

- Simplify the model with [onnx-simplify](#onnx-smi).
- Try different version combinations:
    - Use the latest versions whenever possible, e.g. ONNX opset >= 14 and TensorRT >= 8.2.
    - For TensorRT 7, ONNX 1.9.0 and opset 11 are recommended.
- Try converting the model with trtexec:

```bash

@@ -137,33 +145,33 @@
--saveEngine=test_fp32.trt
```

## ONNX-Related Tools

### [onnx-simplify](https://github.com/daquexian/onnx-simplifier) {#onnx-smi}
A tool for simplifying the model structure:

```bash
pip install onnx onnxsim
onnxsim input.onnx output.onnx
```
### [netron](https://github.com/lutzroeder/netron)

A tool for visualizing ONNX models.

Install it with `pip install netron`, then run `netron [FILE]` or call `netron.start('[FILE]')`.

### [ONNX GraphSurgeon](https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon)

ONNX GraphSurgeon is a tool officially released with TensorRT for modifying ONNX graph structures.

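A minimal sketch of a typical GraphSurgeon workflow, assuming placeholder file names and using the library's documented `import_onnx`/`export_onnx` API:

```python
import onnx
import onnx_graphsurgeon as gs  # pip install onnx_graphsurgeon

# "model.onnx" is a placeholder path.
graph = gs.import_onnx(onnx.load("model.onnx"))

# Inspect the graph: print each node's op type and name.
for node in graph.nodes:
    print(node.op, node.name)

# Remove dangling nodes/tensors, re-sort topologically, and save the result.
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "model_modified.onnx")
```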
### [Polygraphy](https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy)

Polygraphy is an official NVIDIA tool for testing TensorRT and ONNX models. It supports model conversion, helps debug FP16 precision loss, and can keep specified layers from running in FP16.

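As a rough sketch of how such a comparison can be set up with Polygraphy's Python API (the model path and tolerances below are assumptions for illustration), TensorRT FP16 results can be compared against ONNX Runtime to locate precision loss:

```python
from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx
from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, NetworkFromOnnxPath, TrtRunner
from polygraphy.comparator import Comparator, CompareFunc

onnx_path = "test_fp32.onnx"  # placeholder path

# Run the same model with ONNX Runtime (FP32) and TensorRT (FP16).
runners = [
    OnnxrtRunner(SessionFromOnnx(onnx_path)),
    TrtRunner(EngineFromNetwork(NetworkFromOnnxPath(onnx_path), config=CreateConfig(fp16=True))),
]

results = Comparator.run(runners)
# Tolerances are examples only; loosen them as appropriate for FP16.
Comparator.compare_accuracy(results, compare_func=CompareFunc.simple(atol=1e-2, rtol=1e-2))
```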

## Reference Links
- [PyTorch to ONNX Conversion Tutorial](https://zhuanlan.zhihu.com/p/498425043)
- [Modifying and Debugging ONNX Models](https://zhuanlan.zhihu.com/p/516920606)
- [TensorRT Tutorial | Based on version 8.6.1](https://www.bilibili.com/video/BV1jj411Z7wG/?spm_id_from=333.999.0.0&vd_source=c31de98543aa977b5899e24bdd5d8f89)
- [Quantization Tutorial](https://github.com/NVIDIA/TensorRT/tree/release/8.6/quickstart/quantization_tutorial)
2 changes: 1 addition & 1 deletion docs/introduction.md
@@ -11,7 +11,7 @@ To enhance the peak throughput of deep learning serving, various challenges must be

There are some industry practices, such as [triton inference server](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#ensemble-models), [Alimama high_service(in chinese)](https://mp.weixin.qq.com/s/Fd2GNXqO3wl3FrA7Wli3jA), and [Meituan Vision GPU Inference Service Deployment Architecture Optimization Practice(in chinese)](https://zhuanlan.zhihu.com/p/605094862).

One common complaint from users of the Triton Inference Server is that in a system with multiple intertwined nodes, a lot of business logic needs to be completed on the client side and then called through RPC to the server, which can be cumbersome. For performance reasons, unconventional methods such as shared memory, ensemble, and [Business Logic Scripting (BLS)](https://github.com/triton-inference-server/python_backend#business-logic-scripting) must be considered.

To address these issues, TorchPipe provides a thread-safe function interface for the PyTorch frontend and fine-grained backend extensions for users, by delving into PyTorch's C++ computational backend and CUDA stream management and by modeling a domain-specific language for multi-node pipelines.

