diff --git a/.github/workflows/linux-x64-cpu-clang.yml b/.github/workflows/linux-x64-cpu-clang.yml index b8f5005f5384..2abadec9d468 100644 --- a/.github/workflows/linux-x64-cpu-clang.yml +++ b/.github/workflows/linux-x64-cpu-clang.yml @@ -12,6 +12,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' pull_request: branches: [master] @@ -25,6 +26,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' jobs: linux-clang: diff --git a/.github/workflows/linux-x64-cpu-gcc.yml b/.github/workflows/linux-x64-cpu-gcc.yml index aa961ee7ab63..7bf18ea4c6cc 100644 --- a/.github/workflows/linux-x64-cpu-gcc.yml +++ b/.github/workflows/linux-x64-cpu-gcc.yml @@ -12,6 +12,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' pull_request: branches: [master] @@ -25,6 +26,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' jobs: linux-gcc: diff --git a/.github/workflows/linux-x64-gpu-clang.yml b/.github/workflows/linux-x64-gpu-clang.yml index b0a1cbfbdc47..a95e42f3e14d 100644 --- a/.github/workflows/linux-x64-gpu-clang.yml +++ b/.github/workflows/linux-x64-gpu-clang.yml @@ -12,6 +12,7 @@ on: - 'src/layer/vulkan/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' pull_request: branches: [master] @@ -25,6 +26,7 @@ on: - 'src/layer/vulkan/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' jobs: linux-clang-gpu: diff --git a/.github/workflows/linux-x64-gpu-gcc.yml b/.github/workflows/linux-x64-gpu-gcc.yml index 005be9797678..a5c8351b9cfa 100644 --- a/.github/workflows/linux-x64-gpu-gcc.yml +++ b/.github/workflows/linux-x64-gpu-gcc.yml @@ -12,6 +12,7 @@ on: - 'src/layer/vulkan/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' pull_request: branches: [master] @@ -25,6 +26,7 @@ on: - 'src/layer/vulkan/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' jobs: linux-gcc-gpu: diff --git a/.github/workflows/macos-x64-cpu.yml b/.github/workflows/macos-x64-cpu.yml index 73da777d887b..29738c77893d 100644 --- a/.github/workflows/macos-x64-cpu.yml +++ b/.github/workflows/macos-x64-cpu.yml @@ -11,6 +11,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' pull_request: branches: [master] @@ -23,6 +24,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' env: DEVELOPER_DIR: /Applications/Xcode_12.4.app/Contents/Developer diff --git a/.github/workflows/macos-x64-gpu.yml b/.github/workflows/macos-x64-gpu.yml index 2431b1dab27b..df9ba0515582 100644 --- a/.github/workflows/macos-x64-gpu.yml +++ b/.github/workflows/macos-x64-gpu.yml @@ -12,6 +12,7 @@ on: - 'src/layer/vulkan/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' pull_request: branches: [master] @@ -25,6 +26,7 @@ on: - 'src/layer/vulkan/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' env: DEVELOPER_DIR: /Applications/Xcode_12.4.app/Contents/Developer diff --git a/.github/workflows/pnnx.yml b/.github/workflows/pnnx.yml new file mode 100644 index 000000000000..2dc10aa7f32d --- /dev/null +++ b/.github/workflows/pnnx.yml @@ -0,0 +1,66 @@ +name: pnnx +on: + push: + branches: [master] + paths: + - '.github/workflows/pnnx.yml' + - 'tools/pnnx/**' + - '!tools/pnnx/README.md' + pull_request: + branches: [master] + paths: + - '.github/workflows/pnnx.yml' + - 'tools/pnnx/**' + - '!tools/pnnx/README.md' +jobs: + ubuntu: + runs-on: ubuntu-20.04 + + strategy: + fail-fast: false + matrix: + include: + - torch-version: 1.8.1 + torchvision-version: 0.9.1 + + - torch-version: 1.9.1 + torchvision-version: 0.10.1 + + - torch-version: 1.10.0 + torchvision-version: 0.11.1 + + steps: + - name: cancel-previous-runs + uses: styfle/cancel-workflow-action@0.9.1 + with: + access_token: ${{ secrets.GITHUB_TOKEN }} + + - name: setup pytorch-${{ matrix.torch-version }} + run: | + pip install torch==${{ matrix.torch-version }}+cpu torchvision==${{ matrix.torchvision-version }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + + - uses: actions/checkout@v2 + with: + submodules: true + + - name: build-ncnn + run: | + python -m pip install --upgrade pip + pip install pytest setuptools wheel twine + mkdir build && cd build + cmake -DCMAKE_BUILD_TYPE=Release -DNCNN_PYTHON=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF .. + cmake --build . -j 2 + cd .. + pip install . + + - name: build-pnnx + run: | + cd tools/pnnx + mkdir build && cd build + cmake -DTorch_INSTALL_DIR="$HOME/.local/lib/python3.8/site-packages/torch" -DCMAKE_BUILD_TYPE=Release .. + cmake --build . -j 2 + + - name: test + run: | + cd tools/pnnx + cd build && ctest --output-on-failure diff --git a/.github/workflows/windows-x64-cpu-vs2015.yml b/.github/workflows/windows-x64-cpu-vs2015.yml index a310ccf72ace..b8dd93a2436c 100644 --- a/.github/workflows/windows-x64-cpu-vs2015.yml +++ b/.github/workflows/windows-x64-cpu-vs2015.yml @@ -11,6 +11,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' pull_request: branches: [master] @@ -23,6 +24,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' jobs: windows-vs2015: diff --git a/.github/workflows/windows-x64-cpu-vs2017.yml b/.github/workflows/windows-x64-cpu-vs2017.yml index b640a483e552..cb121d075920 100644 --- a/.github/workflows/windows-x64-cpu-vs2017.yml +++ b/.github/workflows/windows-x64-cpu-vs2017.yml @@ -11,6 +11,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' pull_request: branches: [master] @@ -23,6 +24,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' jobs: windows-vs2017: diff --git a/.github/workflows/windows-x64-cpu-vs2019.yml b/.github/workflows/windows-x64-cpu-vs2019.yml index 02d4c9e0fd88..8390bbb03bdd 100644 --- a/.github/workflows/windows-x64-cpu-vs2019.yml +++ b/.github/workflows/windows-x64-cpu-vs2019.yml @@ -11,6 +11,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' pull_request: branches: [master] @@ -23,6 +24,7 @@ on: - 'src/layer/x86/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' jobs: windows-vs2019: diff --git a/.github/workflows/windows-x64-gpu-vs2017.yml b/.github/workflows/windows-x64-gpu-vs2017.yml index fd4f726096e4..567b0c46da11 100644 --- a/.github/workflows/windows-x64-gpu-vs2017.yml +++ b/.github/workflows/windows-x64-gpu-vs2017.yml @@ -12,6 +12,7 @@ on: - 'src/layer/vulkan/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' pull_request: branches: [master] @@ -25,6 +26,7 @@ on: - 'src/layer/vulkan/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' jobs: windows-vs2017-gpu: diff --git a/.github/workflows/windows-x64-gpu-vs2019.yml b/.github/workflows/windows-x64-gpu-vs2019.yml index 163810d698c2..ab4e243b0967 100644 --- a/.github/workflows/windows-x64-gpu-vs2019.yml +++ b/.github/workflows/windows-x64-gpu-vs2019.yml @@ -12,6 +12,7 @@ on: - 'src/layer/vulkan/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' pull_request: branches: [master] @@ -25,6 +26,7 @@ on: - 'src/layer/vulkan/**' - 'tests/**' - 'tools/**' + - '!tools/pnnx/**' - 'examples/**' jobs: windows-vs2019-gpu: diff --git a/tools/pnnx/CMakeLists.txt b/tools/pnnx/CMakeLists.txt new file mode 100644 index 000000000000..c9dc3efe0d86 --- /dev/null +++ b/tools/pnnx/CMakeLists.txt @@ -0,0 +1,44 @@ +project(pnnx) +cmake_minimum_required(VERSION 3.10) + +# c++14 is required for using torch headers +set(CMAKE_CXX_STANDARD 14) + +#set(CMAKE_BUILD_TYPE debug) +#set(CMAKE_BUILD_TYPE relwithdebinfo) +#set(CMAKE_BUILD_TYPE release) + +option(PNNX_COVERAGE "build for coverage" OFF) + +#set(Torch_INSTALL_DIR "/home/nihui/.local/lib/python3.9/site-packages/torch" CACHE STRING "") +#set(Torch_INSTALL_DIR "/home/nihui/osd/pnnx/pytorch-v1.10.0/build/install" CACHE STRING "") +set(Torch_INSTALL_DIR "/home/nihui/osd/pnnx/libtorch" CACHE STRING "") +set(TorchVision_INSTALL_DIR "/home/nihui/osd/vision/build/install" CACHE STRING "") + +set(Torch_DIR "${Torch_INSTALL_DIR}/share/cmake/Torch") +set(TorchVision_DIR "${TorchVision_INSTALL_DIR}/share/cmake/TorchVision") + +find_package(Torch REQUIRED) +find_package(TorchVision QUIET) + +message(STATUS "Torch_VERSION = ${Torch_VERSION}") +message(STATUS "Torch_VERSION_MAJOR = ${Torch_VERSION_MAJOR}") +message(STATUS "Torch_VERSION_MINOR = ${Torch_VERSION_MINOR}") +message(STATUS "Torch_VERSION_PATCH = ${Torch_VERSION_PATCH}") + +if(Torch_VERSION VERSION_LESS "1.8") + message(FATAL_ERROR "pnnx only supports PyTorch >= 1.8") +endif() + +if(TorchVision_FOUND) + message(STATUS "Building with TorchVision") +else() + message(WARNING "Building without TorchVision") +endif() + +include_directories(${TORCH_INCLUDE_DIRS}) + +add_subdirectory(src) + +enable_testing() +add_subdirectory(tests) diff --git a/tools/pnnx/README.md b/tools/pnnx/README.md new file mode 100644 index 000000000000..1d45702f8751 --- /dev/null +++ b/tools/pnnx/README.md @@ -0,0 +1,624 @@ +# PNNX +PyTorch Neural Network eXchange(PNNX) is an open standard for PyTorch model interoperability. PNNX provides an open model format for PyTorch. It defines computation graph as well as high level operators strictly matches PyTorch. + +# Rationale +PyTorch is currently one of the most popular machine learning frameworks. We need to deploy the trained AI model to various hardware and environments more conveniently and easily. + +Before PNNX, we had the following methods: + +1. export to ONNX, and deploy with ONNX-runtime +2. export to ONNX, and convert onnx to inference-framework specific format, and deploy with TensorRT/OpenVINO/ncnn/etc. +3. export to TorchScript, and deploy with libtorch + +As far as we know, ONNX has the ability to express the PyTorch model and it is an open standard. People usually use ONNX as an intermediate representation between PyTorch and the inference platform. However, ONNX still has the following fatal problems, which makes the birth of PNNX necessary: + +1. ONNX does not have a human-readable and editable file representation, making it difficult for users to easily modify the computation graph or add custom operators. +2. The operator definition of ONNX is not completely in accordance with PyTorch. When exporting some PyTorch operators, glue operators are often added passively by ONNX, which makes the computation graph inconsistent with PyTorch and may impact the inference efficiency. +3. There are a large number of additional parameters designed to be compatible with various ML frameworks in the operator definition in ONNX. These parameters increase the burden of inference implementation on hardware and software. + +PNNX tries to define a set of operators and a simple and easy-to-use format that are completely contrasted with the python api of PyTorch, so that the conversion and interoperability of PyTorch models are more convenient. + +# Features + +1. [Human readable and editable format](#the-pnnxparam-format) +2. [Plain model binary in storage zip](#the-pnnxbin-format) +3. [One-to-one mapping of PNNX operators and PyTorch python api](#pnnx-operator) +4. [Preserve math expression as one operator](#pnnx-expression-operator) +5. [Preserve torch function as one operator](#pnnx-torch-function-operator) +6. [Preserve miscellaneous module as one operator](#pnnx-module-operator) +7. [Inference via exported PyTorch python code](#pnnx-python-inference) +8. [Tensor shape propagation](#pnnx-shape-propagation) +9. [Model optimization](#pnnx-model-optimization) +10. [Custom operator support](#pnnx-custom-operator) + +# Build TorchScript to PNNX converter + +1. Install PyTorch and TorchVision c++ library +2. Build PNNX with cmake + +# Usage + +1. Export your model to TorchScript + +```python +import torch +import torchvision.models as models + +net = models.resnet18(pretrained=True) +net = net.eval() + +x = torch.rand(1, 3, 224, 224) + +mod = torch.jit.trace(net, x) +torch.jit.save(mod, "resnet18.pt") +``` + +2. Convert TorchScript to PNNX + +```shell +pnnx resnet18.pt inputshape=[1,3,224,224] +``` + +Normally, you will get six files + +```resnet18.pnnx.param``` PNNX graph definition + +```resnet18.pnnx.bin``` PNNX model weight + +```resnet18_pnnx.py``` PyTorch script for inference, the python code for model construction and weight initialization + +```resnet18.ncnn.param``` ncnn graph definition + +```resnet18.ncnn.bin``` ncnn model weight + +```resnet18_ncnn.py``` pyncnn script for inference + +3. Visualize PNNX with Netron + +Open https://netron.app/ in browser, and drag resnet18.pnnx.param into it. + +4. PNNX command line options + +``` +Usage: pnnx [model.pt] [(key=value)...] + pnnxparam=model.pnnx.param + pnnxbin=model.pnnx.bin + pnnxpy=model_pnnx.py + ncnnparam=model.ncnn.param + ncnnbin=model.ncnn.bin + optlevel=2 + device=cpu/gpu + inputshape=[1,3,224,224],... + inputshape2=[1,3,320,320],... + customop=/home/nihui/.cache/torch_extensions/fused/fused.so,... + moduleop=models.common.Focus,models.yolo.Detect,... +Sample usage: pnnx mobilenet_v2.pt inputshape=[1,3,224,224] + pnnx yolov5s.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320] device=gpu moduleop=models.common.Focus,models.yolo.Detect +``` + +# The pnnx.param format +### example +``` +7767517 +4 3 +pnnx.Input input 0 1 0 +nn.Conv2d conv_0 1 1 0 1 bias=1 dilation=(1,1) groups=1 in_channels=12 kernel_size=(3,3) out_channels=16 padding=(0,0) stride=(1,1) @bias=(16)f32 @weight=(16,12,3,3)f32 +nn.Conv2d conv_1 1 1 1 2 bias=1 dilation=(1,1) groups=1 in_channels=16 kernel_size=(2,2) out_channels=20 padding=(2,2) stride=(2,2) @bias=(20)f32 @weight=(20,16,2,2)f32 +pnnx.Output output 1 0 2 +``` +### overview +``` +[magic] +``` +* magic number : 7767517 +``` +[operator count] [operand count] +``` +* operator count : count of the operator line follows +* operand count : count of all operands +### operator line +``` +[type] [name] [input count] [output count] [input operands] [output operands] [operator params] +``` +* type : type name, such as Conv2d ReLU etc +* name : name of this operator +* input count : count of the operands this operator needs as input +* output count : count of the operands this operator produces as output +* input operands : name list of all the input blob names, separated by space +* output operands : name list of all the output blob names, separated by space +* operator params : key=value pair list, separated by space, operator weights are prefixed by ```@``` symbol, tensor shapes are prefixed by ```#``` symbol, input parameter keys are prefixed by ```$``` + +# The pnnx.bin format + +pnnx.bin file is a zip file with store-only mode(no compression) + +weight binary file has its name composed by operator name and weight name + +For example, ```nn.Conv2d conv_0 1 1 0 1 bias=1 dilation=(1,1) groups=1 in_channels=12 kernel_size=(3,3) out_channels=16 padding=(0,0) stride=(1,1) @bias=(16) @weight=(16,12,3,3)``` would pull conv_0.weight and conv_0.bias into pnnx.bin zip archive. + +weight binaries can be listed or modified with any archive application eg. 7zip + +![pnnx.bin](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/pnnx.bin.png) + +# PNNX operator +PNNX always preserve operators from what PyTorch python api provides. + +Here is the netron visualization comparision among ONNX, TorchScript and PNNX with the original PyTorch python code shown. + +```python +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.attention = nn.MultiheadAttention(embed_dim=256, num_heads=32) + + def forward(self, x): + x, _ = self.attention(x, x, x) + return x +``` + +|ONNX|TorchScript|PNNX| +|----|---|---| +|![MultiheadAttention.onnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/MultiheadAttention.onnx.png)|![MultiheadAttention.pt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/MultiheadAttention.pt.png)|![MultiheadAttention.pnnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/MultiheadAttention.pnnx.png)| + +# PNNX expression operator +PNNX trys to preserve expression from what PyTorch python code writes. + +Here is the netron visualization comparision among ONNX, TorchScript and PNNX with the original PyTorch python code shown. + +```python +import torch + +def foo(x, y): + return torch.sqrt((2 * x + y) / 12) +``` + +|ONNX|TorchScript|PNNX| +|---|---|---| +|![math.onnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/math.onnx.png)|![math.pt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/math.pt.png)|![math.pnnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/math.pnnx.png)| + +# PNNX torch function operator +PNNX trys to preserve torch functions and Tensor member functions as one operator from what PyTorch python api provides. + +Here is the netron visualization comparision among ONNX, TorchScript and PNNX with the original PyTorch python code shown. + +```python +import torch +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.normalize(x, eps=1e-3) + return x +``` + +|ONNX|TorchScript|PNNX| +|---|---|---| +|![function.onnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/function.onnx.png)|![function.pt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/function.pt.png)|![function.pnnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/function.pnnx.png)| + + +# PNNX module operator +Users could ask PNNX to keep module as one big operator when it has complex logic. + +The process is optional and could be enabled via moduleop command line option. + +After pass_level0, all modules will be presented in terminal output, then you can pick the intersting ones as module operators. +``` +############# pass_level0 +inline module = models.common.Bottleneck +inline module = models.common.C3 +inline module = models.common.Concat +inline module = models.common.Conv +inline module = models.common.Focus +inline module = models.common.SPP +inline module = models.yolo.Detect +inline module = utils.activations.SiLU +``` + +```bash +pnnx yolov5s.pt inputshape=[1,3,640,640] moduleop=models.common.Focus,models.yolo.Detect +``` + +Here is the netron visualization comparision among ONNX, TorchScript and PNNX with the original PyTorch python code shown. + +```python +import torch +import torch.nn as nn + +class Focus(nn.Module): + # Focus wh information into c-space + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, g, act) + + def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) + return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) +``` + +|ONNX|TorchScript|PNNX|PNNX with module operator| +|---|---|---|---| +|![focus.onnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/focus.onnx.png)|![focus.pt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/focus.pt.png)|![focus.pnnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/focus.pnnx.png)|![focus.pnnx2](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/focus.pnnx2.png)| + + +# PNNX python inference + +A python script will be generated by default when converting torchscript to pnnx. + +This script is the python code representation of PNNX and can be used for model inference. + +There are some utility functions for loading weight binary from pnnx.bin. + +You can even export the model torchscript AGAIN from this generated code! + +```python +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.linear_0 = nn.Linear(in_features=128, out_features=256, bias=True) + self.linear_1 = nn.Linear(in_features=256, out_features=4, bias=True) + + def forward(self, x): + x = self.linear_0(x) + x = F.leaky_relu(x, 0.15) + x = self.linear_1(x) + return x +``` + +```python +import os +import numpy as np +import tempfile, zipfile +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.linear_0 = nn.Linear(bias=True, in_features=128, out_features=256) + self.linear_1 = nn.Linear(bias=True, in_features=256, out_features=4) + + archive = zipfile.ZipFile('../../function.pnnx.bin', 'r') + self.linear_0.bias = self.load_pnnx_bin_as_parameter(archive, 'linear_0.bias', (256), 'float32') + self.linear_0.weight = self.load_pnnx_bin_as_parameter(archive, 'linear_0.weight', (256,128), 'float32') + self.linear_1.bias = self.load_pnnx_bin_as_parameter(archive, 'linear_1.bias', (4), 'float32') + self.linear_1.weight = self.load_pnnx_bin_as_parameter(archive, 'linear_1.weight', (4,256), 'float32') + archive.close() + + def load_pnnx_bin_as_parameter(self, archive, key, shape, dtype): + return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype)) + + def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype): + _, tmppath = tempfile.mkstemp() + tmpf = open(tmppath, 'wb') + with archive.open(key) as keyfile: + tmpf.write(keyfile.read()) + tmpf.close() + m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy() + os.remove(tmppath) + return torch.from_numpy(m) + + def forward(self, v_x_1): + v_7 = self.linear_0(v_x_1) + v_input_1 = F.leaky_relu(input=v_7, negative_slope=0.150000) + v_12 = self.linear_1(v_input_1) + return v_12 +``` + +# PNNX shape propagation +Users could ask PNNX to resolve all tensor shapes in model graph and constify some common expressions involved when tensor shapes are known. + +The process is optional and could be enabled via inputshape command line option. + +```bash +pnnx shufflenet_v2_x1_0.pt inputshape=[1,3,224,224] +``` + +```python +def channel_shuffle(x: Tensor, groups: int) -> Tensor: + batchsize, num_channels, height, width = x.size() + channels_per_group = num_channels // groups + + # reshape + x = x.view(batchsize, groups, channels_per_group, height, width) + + x = torch.transpose(x, 1, 2).contiguous() + + # flatten + x = x.view(batchsize, -1, height, width) + + return x +``` + +|without shape propagation|with shape propagation| +|---|---| +|![noshapeinfer](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/noshapeinfer.png)|![shapeinfer](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/shapeinfer.pnnx.png)| + + +# PNNX model optimization + +|ONNX|TorchScript|PNNX without optimization|PNNX with optimization| +|---|---|---|---| +|![optlessonnx](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/optless.onnx.png)|![optlesspt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/optless.pt.png)|![optless](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/optless.pnnx.png)|![opt](https://raw.githubusercontent.com/nihui/ncnn-assets/master/pnnx/opt.pnnx.png)| + + +# PNNX custom operator + +```python +import os + +import torch +from torch.autograd import Function +from torch.utils.cpp_extension import load, _import_module_from_library + +module_path = os.path.dirname(__file__) +upfirdn2d_op = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'upfirdn2d.cpp'), + os.path.join(module_path, 'upfirdn2d_kernel.cu'), + ], + is_python_module=False +) + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + pad_x0 = pad[0] + pad_x1 = pad[1] + pad_y0 = pad[0] + pad_y1 = pad[1] + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + out_h = (in_h * up + pad_y0 + pad_y1 - kernel_h) // down + 1 + out_w = (in_w * up + pad_x0 + pad_x1 - kernel_w) // down + 1 + + out = torch.ops.upfirdn2d_op.upfirdn2d(input, kernel, up, up, down, down, pad_x0, pad_x1, pad_y0, pad_y1) + + out = out.view(-1, channel, out_h, out_w) + + return out +``` + +```cpp +#include + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int64_t up_x, int64_t up_y, int64_t down_x, int64_t down_y, + int64_t pad_x0, int64_t pad_x1, int64_t pad_y0, int64_t pad_y1) { + // operator body +} + +TORCH_LIBRARY(upfirdn2d_op, m) { + m.def("upfirdn2d", upfirdn2d); +} +``` + + + +# Supported PyTorch operator status + +| torch.nn | Is Supported | Export to ncnn | +|---------------------------|----|---| +|nn.AdaptiveAvgPool1d | :heavy_check_mark: | +|nn.AdaptiveAvgPool2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.AdaptiveAvgPool3d | :heavy_check_mark: | +|nn.AdaptiveMaxPool1d | :heavy_check_mark: | +|nn.AdaptiveMaxPool2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.AdaptiveMaxPool3d | :heavy_check_mark: | +|nn.AlphaDropout | | +|nn.AvgPool1d | :heavy_check_mark: | +|nn.AvgPool2d | :heavy_check_mark: | :heavy_check_mark:* | +|nn.AvgPool3d | :heavy_check_mark: | +|nn.BatchNorm1d | :heavy_check_mark: | :heavy_check_mark: | +|nn.BatchNorm2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.BatchNorm3d | :heavy_check_mark: | +|nn.Bilinear | | +|nn.CELU | :heavy_check_mark: | +|nn.ChannelShuffle | :heavy_check_mark: | :heavy_check_mark: | +|nn.ConstantPad1d | :heavy_check_mark: | :heavy_check_mark: | +|nn.ConstantPad2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.ConstantPad3d | :heavy_check_mark: | +|nn.Conv1d | :heavy_check_mark: | :heavy_check_mark: | +|nn.Conv2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.Conv3d | :heavy_check_mark: | +|nn.ConvTranspose1d | :heavy_check_mark: | +|nn.ConvTranspose2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.ConvTranspose3d | :heavy_check_mark: | +|nn.CosineSimilarity | | +|nn.Dropout | | :heavy_check_mark:* | +|nn.Dropout2d | | +|nn.Dropout3d | | +|nn.ELU | :heavy_check_mark: | :heavy_check_mark: | +|nn.Embedding | :heavy_check_mark: | :heavy_check_mark: | +|nn.EmbeddingBag | | +|nn.Flatten | :heavy_check_mark: | +|nn.Fold | | +|nn.FractionalMaxPool2d | | +|nn.FractionalMaxPool3d | | +|nn.GELU | :heavy_check_mark: | :heavy_check_mark: | +|nn.GroupNorm | :heavy_check_mark: | :heavy_check_mark: | +|nn.GRU | :heavy_check_mark: | :heavy_check_mark: | +|nn.GRUCell | | +|nn.Hardshrink | :heavy_check_mark: | +|nn.Hardsigmoid | :heavy_check_mark: | :heavy_check_mark: | +|nn.Hardswish | :heavy_check_mark: | :heavy_check_mark: | +|nn.Hardtanh | :heavy_check_mark: | :heavy_check_mark: | +|nn.Identity | | +|nn.InstanceNorm1d | :heavy_check_mark: | +|nn.InstanceNorm2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.InstanceNorm3d | :heavy_check_mark: | +|nn.LayerNorm | :heavy_check_mark: | :heavy_check_mark: | +|nn.LazyBatchNorm1d | | +|nn.LazyBatchNorm2d | | +|nn.LazyBatchNorm3d | | +|nn.LazyConv1d | | +|nn.LazyConv2d | | +|nn.LazyConv3d | | +|nn.LazyConvTranspose1d | | +|nn.LazyConvTranspose2d | | +|nn.LazyConvTranspose3d | | +|nn.LazyLinear | | +|nn.LeakyReLU | :heavy_check_mark: | :heavy_check_mark: | +|nn.Linear | :heavy_check_mark: | :heavy_check_mark: | +|nn.LocalResponseNorm | :heavy_check_mark: | :heavy_check_mark: | +|nn.LogSigmoid | :heavy_check_mark: | +|nn.LogSoftmax | :heavy_check_mark: | +|nn.LPPool1d | :heavy_check_mark: | +|nn.LPPool2d | :heavy_check_mark: | +|nn.LSTM | :heavy_check_mark: | :heavy_check_mark: | +|nn.LSTMCell | | +|nn.MaxPool1d | :heavy_check_mark: | +|nn.MaxPool2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.MaxPool3d | :heavy_check_mark: | +|nn.MaxUnpool1d | | +|nn.MaxUnpool2d | | +|nn.MaxUnpool3d | | +|nn.Mish | :heavy_check_mark: | :heavy_check_mark: | +|nn.MultiheadAttention | :heavy_check_mark: | :heavy_check_mark:* | +|nn.PairwiseDistance | | +|nn.PixelShuffle | :heavy_check_mark: | :heavy_check_mark: | +|nn.PixelUnshuffle | :heavy_check_mark: | :heavy_check_mark: | +|nn.PReLU | :heavy_check_mark: | :heavy_check_mark: | +|nn.ReflectionPad1d | :heavy_check_mark: | :heavy_check_mark: | +|nn.ReflectionPad2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.ReLU | :heavy_check_mark: | :heavy_check_mark: | +|nn.ReLU6 | :heavy_check_mark: | :heavy_check_mark: | +|nn.ReplicationPad1d | :heavy_check_mark: | :heavy_check_mark: | +|nn.ReplicationPad2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.ReplicationPad3d | :heavy_check_mark: | +|nn.RNN | :heavy_check_mark: | :heavy_check_mark:* | +|nn.RNNBase | | +|nn.RNNCell | | +|nn.RReLU | :heavy_check_mark: | +|nn.SELU | :heavy_check_mark: | :heavy_check_mark: | +|nn.Sigmoid | :heavy_check_mark: | :heavy_check_mark: | +|nn.SiLU | :heavy_check_mark: | :heavy_check_mark: | +|nn.Softmax | :heavy_check_mark: | :heavy_check_mark: | +|nn.Softmax2d | | +|nn.Softmin | :heavy_check_mark: | +|nn.Softplus | :heavy_check_mark: | +|nn.Softshrink | :heavy_check_mark: | +|nn.Softsign | :heavy_check_mark: | +|nn.SyncBatchNorm | | +|nn.Tanh | :heavy_check_mark: | :heavy_check_mark: | +|nn.Tanhshrink | :heavy_check_mark: | +|nn.Threshold | :heavy_check_mark: | +|nn.Transformer | | +|nn.TransformerDecoder | | +|nn.TransformerDecoderLayer | | +|nn.TransformerEncoder | | +|nn.TransformerEncoderLayer | | +|nn.Unflatten | | +|nn.Unfold | | +|nn.Upsample | :heavy_check_mark: | :heavy_check_mark: | +|nn.UpsamplingBilinear2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.UpsamplingNearest2d | :heavy_check_mark: | :heavy_check_mark: | +|nn.ZeroPad2d | :heavy_check_mark: | :heavy_check_mark: | + + +| torch.nn.functional | Is Supported | Export to ncnn | +|---------------------------|----|----| +|F.adaptive_avg_pool1d | :heavy_check_mark: | +|F.adaptive_avg_pool2d | :heavy_check_mark: | :heavy_check_mark: | +|F.adaptive_avg_pool3d | :heavy_check_mark: | +|F.adaptive_max_pool1d | :heavy_check_mark: | +|F.adaptive_max_pool2d | :heavy_check_mark: | :heavy_check_mark: | +|F.adaptive_max_pool3d | :heavy_check_mark: | +|F.affine_grid | :heavy_check_mark: | +|F.alpha_dropout | | +|F.avg_pool1d | :heavy_check_mark: | +|F.avg_pool2d | :heavy_check_mark: | +|F.avg_pool3d | :heavy_check_mark: | +|F.batch_norm | :heavy_check_mark: | :heavy_check_mark: | +|F.bilinear | | +|F.celu | :heavy_check_mark: | +|F.conv1d | :heavy_check_mark: | +|F.conv2d | :heavy_check_mark: | :heavy_check_mark:* | +|F.conv3d | :heavy_check_mark: | +|F.conv_transpose1d | :heavy_check_mark: | +|F.conv_transpose2d | :heavy_check_mark: | +|F.conv_transpose3d | :heavy_check_mark: | +|F.cosine_similarity | | +|F.dropout | | +|F.dropout2d | | +|F.dropout3d | | +|F.elu | :heavy_check_mark: | :heavy_check_mark: | +|F.elu_ | :heavy_check_mark: | :heavy_check_mark: | +|F.embedding | | +|F.embedding_bag | | +|F.feature_alpha_dropout | | +|F.fold | | +|F.fractional_max_pool2d | | +|F.fractional_max_pool3d | | +|F.gelu | :heavy_check_mark: | :heavy_check_mark: | +|F.glu | | +|F.grid_sample | :heavy_check_mark: | +|F.group_norm | :heavy_check_mark: | :heavy_check_mark: | +|F.gumbel_softmax | | +|F.hardshrink | :heavy_check_mark: | +|F.hardsigmoid | :heavy_check_mark: | :heavy_check_mark: | +|F.hardswish | :heavy_check_mark: | :heavy_check_mark: | +|F.hardtanh | :heavy_check_mark: | :heavy_check_mark: | +|F.hardtanh_ | :heavy_check_mark: | :heavy_check_mark: | +|F.instance_norm | :heavy_check_mark: | :heavy_check_mark: | +|F.interpolate | :heavy_check_mark: | :heavy_check_mark: | +|F.layer_norm | :heavy_check_mark: | :heavy_check_mark: | +|F.leaky_relu | :heavy_check_mark: | :heavy_check_mark: | +|F.leaky_relu_ | :heavy_check_mark: | :heavy_check_mark: | +|F.linear | :heavy_check_mark: | :heavy_check_mark:* | +|F.local_response_norm | :heavy_check_mark: | :heavy_check_mark: | +|F.logsigmoid | :heavy_check_mark: | +|F.log_softmax | :heavy_check_mark: | +|F.lp_pool1d | :heavy_check_mark: | +|F.lp_pool2d | :heavy_check_mark: | +|F.max_pool1d | :heavy_check_mark: | +|F.max_pool2d | :heavy_check_mark: | +|F.max_pool3d | :heavy_check_mark: | +|F.max_unpool1d | | +|F.max_unpool2d | | +|F.max_unpool3d | | +|F.mish | :heavy_check_mark: | :heavy_check_mark: | +|F.normalize | :heavy_check_mark: | :heavy_check_mark: | +|F.one_hot | | +|F.pad | :heavy_check_mark: | :heavy_check_mark: | +|F.pairwise_distance | | +|F.pdist | | +|F.pixel_shuffle | :heavy_check_mark: | :heavy_check_mark: | +|F.pixel_unshuffle | :heavy_check_mark: | :heavy_check_mark: | +|F.prelu | :heavy_check_mark: | :heavy_check_mark: | +|F.relu | :heavy_check_mark: | :heavy_check_mark: | +|F.relu_ | :heavy_check_mark: | :heavy_check_mark: | +|F.relu6 | :heavy_check_mark: | :heavy_check_mark: | +|F.rrelu | :heavy_check_mark: | +|F.rrelu_ | :heavy_check_mark: | +|F.selu | :heavy_check_mark: | :heavy_check_mark: | +|F.sigmoid | :heavy_check_mark: | :heavy_check_mark: | +|F.silu | :heavy_check_mark: | :heavy_check_mark: | +|F.softmax | :heavy_check_mark: | :heavy_check_mark: | +|F.softmin | :heavy_check_mark: | +|F.softplus | :heavy_check_mark: | +|F.softshrink | :heavy_check_mark: | +|F.softsign | :heavy_check_mark: | +|F.tanh | :heavy_check_mark: | :heavy_check_mark: | +|F.tanhshrink | :heavy_check_mark: | +|F.threshold | :heavy_check_mark: | +|F.threshold_ | :heavy_check_mark: | +|F.unfold | | +|F.upsample | :heavy_check_mark: | :heavy_check_mark: | +|F.upsample_bilinear | :heavy_check_mark: | :heavy_check_mark: | +|F.upsample_nearest | :heavy_check_mark: | :heavy_check_mark: | diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt new file mode 100644 index 000000000000..61e565515990 --- /dev/null +++ b/tools/pnnx/src/CMakeLists.txt @@ -0,0 +1,378 @@ + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +set(pnnx_pass_level0_SRCS + pass_level0/constant_unpooling.cpp + pass_level0/inline_block.cpp + pass_level0/shape_inference.cpp +) + +set(pnnx_pass_level1_SRCS + pass_level1/nn_AdaptiveAvgPool1d.cpp + pass_level1/nn_AdaptiveAvgPool2d.cpp + pass_level1/nn_AdaptiveAvgPool3d.cpp + pass_level1/nn_AdaptiveMaxPool1d.cpp + pass_level1/nn_AdaptiveMaxPool2d.cpp + pass_level1/nn_AdaptiveMaxPool3d.cpp + pass_level1/nn_AvgPool1d.cpp + pass_level1/nn_AvgPool2d.cpp + pass_level1/nn_AvgPool3d.cpp + pass_level1/nn_BatchNorm1d.cpp + pass_level1/nn_BatchNorm2d.cpp + pass_level1/nn_BatchNorm3d.cpp + pass_level1/nn_CELU.cpp + pass_level1/nn_ChannelShuffle.cpp + pass_level1/nn_ConstantPad1d.cpp + pass_level1/nn_ConstantPad2d.cpp + pass_level1/nn_ConstantPad3d.cpp + pass_level1/nn_Conv1d.cpp + pass_level1/nn_Conv2d.cpp + pass_level1/nn_Conv3d.cpp + pass_level1/nn_ConvTranspose1d.cpp + pass_level1/nn_ConvTranspose2d.cpp + pass_level1/nn_ConvTranspose3d.cpp + pass_level1/nn_Dropout.cpp + pass_level1/nn_ELU.cpp + pass_level1/nn_Embedding.cpp + pass_level1/nn_GELU.cpp + pass_level1/nn_GroupNorm.cpp + pass_level1/nn_GRU.cpp + pass_level1/nn_Hardshrink.cpp + pass_level1/nn_Hardsigmoid.cpp + pass_level1/nn_Hardswish.cpp + pass_level1/nn_Hardtanh.cpp + pass_level1/nn_InstanceNorm1d.cpp + pass_level1/nn_InstanceNorm2d.cpp + pass_level1/nn_InstanceNorm3d.cpp + pass_level1/nn_LayerNorm.cpp + pass_level1/nn_LeakyReLU.cpp + pass_level1/nn_Linear.cpp + pass_level1/nn_LocalResponseNorm.cpp + pass_level1/nn_LogSigmoid.cpp + pass_level1/nn_LogSoftmax.cpp + pass_level1/nn_LPPool1d.cpp + pass_level1/nn_LPPool2d.cpp + pass_level1/nn_LSTM.cpp + pass_level1/nn_MaxPool1d.cpp + pass_level1/nn_MaxPool2d.cpp + pass_level1/nn_MaxPool3d.cpp + pass_level1/nn_maxunpool2d.cpp + pass_level1/nn_Mish.cpp + pass_level1/nn_MultiheadAttention.cpp + pass_level1/nn_PixelShuffle.cpp + pass_level1/nn_PixelUnshuffle.cpp + pass_level1/nn_PReLU.cpp + pass_level1/nn_ReflectionPad1d.cpp + pass_level1/nn_ReflectionPad2d.cpp + pass_level1/nn_ReLU.cpp + pass_level1/nn_ReLU6.cpp + pass_level1/nn_ReplicationPad1d.cpp + pass_level1/nn_ReplicationPad2d.cpp + pass_level1/nn_ReplicationPad3d.cpp + pass_level1/nn_RNN.cpp + pass_level1/nn_RReLU.cpp + pass_level1/nn_SELU.cpp + pass_level1/nn_Sigmoid.cpp + pass_level1/nn_SiLU.cpp + pass_level1/nn_Softmax.cpp + pass_level1/nn_Softmin.cpp + pass_level1/nn_Softplus.cpp + pass_level1/nn_Softshrink.cpp + pass_level1/nn_Softsign.cpp + pass_level1/nn_Tanh.cpp + pass_level1/nn_Tanhshrink.cpp + pass_level1/nn_Threshold.cpp + pass_level1/nn_Upsample.cpp + pass_level1/nn_UpsamplingBilinear2d.cpp + pass_level1/nn_UpsamplingNearest2d.cpp + pass_level1/nn_ZeroPad2d.cpp + + pass_level1/nn_quantized_Conv2d.cpp + pass_level1/nn_quantized_DeQuantize.cpp + pass_level1/nn_quantized_Linear.cpp + pass_level1/nn_quantized_Quantize.cpp +) + +set(pnnx_pass_level2_SRCS + pass_level2/F_adaptive_avg_pool1d.cpp + pass_level2/F_adaptive_avg_pool2d.cpp + pass_level2/F_adaptive_avg_pool3d.cpp + pass_level2/F_adaptive_max_pool1d.cpp + pass_level2/F_adaptive_max_pool2d.cpp + pass_level2/F_adaptive_max_pool3d.cpp + pass_level2/F_affine_grid.cpp + pass_level2/F_avg_pool1d.cpp + pass_level2/F_avg_pool2d.cpp + pass_level2/F_avg_pool3d.cpp + pass_level2/F_batch_norm.cpp + pass_level2/F_celu.cpp + pass_level2/F_conv1d.cpp + pass_level2/F_conv2d.cpp + pass_level2/F_conv3d.cpp + pass_level2/F_conv_transpose1d.cpp + pass_level2/F_conv_transpose2d.cpp + pass_level2/F_conv_transpose3d.cpp + pass_level2/F_elu.cpp + pass_level2/F_gelu.cpp + pass_level2/F_grid_sample.cpp + pass_level2/F_group_norm.cpp + pass_level2/F_hardshrink.cpp + pass_level2/F_hardsigmoid.cpp + pass_level2/F_hardswish.cpp + pass_level2/F_hardtanh.cpp + pass_level2/F_instance_norm.cpp + pass_level2/F_interpolate.cpp + pass_level2/F_layer_norm.cpp + pass_level2/F_leaky_relu.cpp + pass_level2/F_linear.cpp + pass_level2/F_local_response_norm.cpp + pass_level2/F_log_softmax.cpp + pass_level2/F_logsigmoid.cpp + pass_level2/F_lp_pool1d.cpp + pass_level2/F_lp_pool2d.cpp + pass_level2/F_max_pool1d.cpp + pass_level2/F_max_pool2d.cpp + pass_level2/F_max_pool3d.cpp + pass_level2/F_mish.cpp + pass_level2/F_normalize.cpp + pass_level2/F_pad.cpp + pass_level2/F_pixel_shuffle.cpp + pass_level2/F_pixel_unshuffle.cpp + pass_level2/F_prelu.cpp + pass_level2/F_relu.cpp + pass_level2/F_relu6.cpp + pass_level2/F_rrelu.cpp + pass_level2/F_selu.cpp + pass_level2/F_sigmoid.cpp + pass_level2/F_silu.cpp + pass_level2/F_softmax.cpp + pass_level2/F_softmin.cpp + pass_level2/F_softplus.cpp + pass_level2/F_softshrink.cpp + pass_level2/F_softsign.cpp + pass_level2/F_tanh.cpp + pass_level2/F_tanhshrink.cpp + pass_level2/F_threshold.cpp + pass_level2/F_upsample_bilinear.cpp + pass_level2/F_upsample_nearest.cpp + pass_level2/F_upsample.cpp + pass_level2/Tensor_contiguous.cpp + pass_level2/Tensor_new_empty.cpp + pass_level2/Tensor_repeat.cpp + pass_level2/Tensor_reshape.cpp + pass_level2/Tensor_select.cpp + pass_level2/Tensor_slice.cpp + pass_level2/Tensor_view.cpp + pass_level2/torch_cat.cpp + pass_level2/torch_chunk.cpp + pass_level2/torch_clamp.cpp + pass_level2/torch_flatten.cpp + pass_level2/torch_mean.cpp + pass_level2/torch_sum.cpp + pass_level2/torch_split.cpp + pass_level2/torch_squeeze.cpp + pass_level2/torch_permute.cpp + pass_level2/torch_transpose.cpp + pass_level2/torch_unsqueeze.cpp + + pass_level2/nn_quantized_FloatFunctional.cpp +) + +set(pnnx_pass_level3_SRCS + pass_level3/eliminate_tuple_pair.cpp + pass_level3/expand_quantization_modules.cpp + pass_level3/fuse_attribute_expression.cpp + pass_level3/fuse_cat_tensors.cpp + pass_level3/fuse_chunk_split_unpack.cpp + pass_level3/fuse_expression.cpp + pass_level3/fuse_rnn_unpack.cpp +) + +set(pnnx_pass_level4_SRCS + pass_level4/canonicalize.cpp + pass_level4/dead_code_elimination.cpp + pass_level4/fuse_custom_op.cpp +) + +set(pnnx_pass_level5_SRCS + pass_level5/eliminate_slice.cpp + pass_level5/eliminate_view_reshape.cpp + pass_level5/eval_expression.cpp + pass_level5/fuse_channel_shuffle.cpp + pass_level5/fuse_constant_expression.cpp + pass_level5/fuse_conv2d_batchnorm2d.cpp + pass_level5/fuse_convtranspose2d_batchnorm2d.cpp + pass_level5/fuse_contiguous_view.cpp + pass_level5/fuse_linear_batchnorm1d.cpp + pass_level5/fuse_slice_indices.cpp + pass_level5/unroll_rnn_op.cpp +) + +set(pnnx_pass_ncnn_SRCS + pass_ncnn/convert_attribute.cpp + pass_ncnn/convert_custom_op.cpp + pass_ncnn/convert_input.cpp + pass_ncnn/convert_torch_cat.cpp + pass_ncnn/convert_torch_chunk.cpp + pass_ncnn/convert_torch_split.cpp + pass_ncnn/eliminate_output.cpp + pass_ncnn/expand_expression.cpp + pass_ncnn/insert_split.cpp + pass_ncnn/chain_multi_output.cpp + pass_ncnn/solve_batch_index.cpp + + pass_ncnn/eliminate_noop.cpp + pass_ncnn/fuse_convolution_activation.cpp + pass_ncnn/fuse_convolution1d_activation.cpp + pass_ncnn/fuse_convolutiondepthwise_activation.cpp + pass_ncnn/fuse_convolutiondepthwise1d_activation.cpp + pass_ncnn/fuse_deconvolution_activation.cpp + pass_ncnn/fuse_deconvolutiondepthwise_activation.cpp + pass_ncnn/fuse_innerproduct_activation.cpp + + pass_ncnn/F_adaptive_avg_pool2d.cpp + pass_ncnn/F_adaptive_max_pool2d.cpp + pass_ncnn/F_batch_norm.cpp + pass_ncnn/F_conv2d.cpp + pass_ncnn/F_elu.cpp + pass_ncnn/F_gelu.cpp + pass_ncnn/F_group_norm.cpp + pass_ncnn/F_hardsigmoid.cpp + pass_ncnn/F_hardswish.cpp + pass_ncnn/F_hardtanh.cpp + pass_ncnn/F_instance_norm.cpp + pass_ncnn/F_interpolate.cpp + pass_ncnn/F_layer_norm.cpp + pass_ncnn/F_leaky_relu.cpp + pass_ncnn/F_linear.cpp + pass_ncnn/F_local_response_norm.cpp + pass_ncnn/F_mish.cpp + pass_ncnn/F_normalize.cpp + pass_ncnn/F_pad.cpp + pass_ncnn/F_pixel_shuffle.cpp + pass_ncnn/F_pixel_unshuffle.cpp + pass_ncnn/F_prelu.cpp + pass_ncnn/F_relu.cpp + pass_ncnn/F_relu6.cpp + pass_ncnn/F_selu.cpp + pass_ncnn/F_sigmoid.cpp + pass_ncnn/F_silu.cpp + pass_ncnn/F_softmax.cpp + pass_ncnn/F_tanh.cpp + pass_ncnn/F_upsample_bilinear.cpp + pass_ncnn/F_upsample_nearest.cpp + pass_ncnn/F_upsample.cpp + pass_ncnn/nn_AdaptiveAvgPool2d.cpp + pass_ncnn/nn_AdaptiveMaxPool2d.cpp + pass_ncnn/nn_AvgPool2d.cpp + pass_ncnn/nn_BatchNorm1d.cpp + pass_ncnn/nn_BatchNorm2d.cpp + pass_ncnn/nn_ChannelShuffle.cpp + pass_ncnn/nn_ConstantPad1d.cpp + pass_ncnn/nn_ConstantPad2d.cpp + pass_ncnn/nn_Conv1d.cpp + pass_ncnn/nn_Conv2d.cpp + pass_ncnn/nn_ConvTranspose2d.cpp + pass_ncnn/nn_Dropout.cpp + pass_ncnn/nn_ELU.cpp + pass_ncnn/nn_Embedding.cpp + pass_ncnn/nn_GELU.cpp + pass_ncnn/nn_GroupNorm.cpp + pass_ncnn/nn_GRU.cpp + pass_ncnn/nn_Hardsigmoid.cpp + pass_ncnn/nn_Hardswish.cpp + pass_ncnn/nn_Hardtanh.cpp + pass_ncnn/nn_InstanceNorm2d.cpp + pass_ncnn/nn_LayerNorm.cpp + pass_ncnn/nn_LeakyReLU.cpp + pass_ncnn/nn_Linear.cpp + pass_ncnn/nn_LocalResponseNorm.cpp + pass_ncnn/nn_LSTM.cpp + pass_ncnn/nn_MaxPool2d.cpp + pass_ncnn/nn_Mish.cpp + pass_ncnn/nn_MultiheadAttention.cpp + pass_ncnn/nn_PixelShuffle.cpp + pass_ncnn/nn_PixelUnshuffle.cpp + pass_ncnn/nn_PReLU.cpp + pass_ncnn/nn_ReflectionPad1d.cpp + pass_ncnn/nn_ReflectionPad2d.cpp + pass_ncnn/nn_ReLU.cpp + pass_ncnn/nn_ReLU6.cpp + pass_ncnn/nn_ReplicationPad1d.cpp + pass_ncnn/nn_ReplicationPad2d.cpp + pass_ncnn/nn_RNN.cpp + pass_ncnn/nn_SELU.cpp + pass_ncnn/nn_Sigmoid.cpp + pass_ncnn/nn_SiLU.cpp + pass_ncnn/nn_Softmax.cpp + pass_ncnn/nn_Tanh.cpp + pass_ncnn/nn_Upsample.cpp + pass_ncnn/nn_UpsamplingBilinear2d.cpp + pass_ncnn/nn_UpsamplingNearest2d.cpp + pass_ncnn/nn_ZeroPad2d.cpp + pass_ncnn/Tensor_contiguous.cpp + pass_ncnn/Tensor_reshape.cpp + pass_ncnn/Tensor_slice.cpp + pass_ncnn/Tensor_view.cpp + pass_ncnn/torch_clamp.cpp + pass_ncnn/torch_flatten.cpp + pass_ncnn/torch_mean.cpp + pass_ncnn/torch_permute.cpp + pass_ncnn/torch_squeeze.cpp + pass_ncnn/torch_transpose.cpp + pass_ncnn/torch_unsqueeze.cpp +) + +set(pnnx_SRCS + main.cpp + ir.cpp + storezip.cpp + utils.cpp + + pass_level0.cpp + pass_level1.cpp + pass_level2.cpp + pass_level3.cpp + pass_level4.cpp + pass_level5.cpp + + pass_ncnn.cpp + + ${pnnx_pass_level0_SRCS} + ${pnnx_pass_level1_SRCS} + ${pnnx_pass_level2_SRCS} + ${pnnx_pass_level3_SRCS} + ${pnnx_pass_level4_SRCS} + ${pnnx_pass_level5_SRCS} + + ${pnnx_pass_ncnn_SRCS} +) + +if(NOT MSVC) + add_definitions(-Wall -Wextra) +endif() + +add_executable(pnnx ${pnnx_SRCS}) + +if(PNNX_COVERAGE) + target_compile_options(pnnx PUBLIC -coverage -fprofile-arcs -ftest-coverage) + target_link_libraries(pnnx PUBLIC -coverage -lgcov) +endif() + +if(TorchVision_FOUND) + target_link_libraries(pnnx PRIVATE TorchVision::TorchVision) +endif() + +if(WIN32) + target_link_libraries(pnnx PRIVATE ${TORCH_LIBRARIES}) +else() + target_link_libraries(pnnx PRIVATE ${TORCH_LIBRARIES} dl) +endif() + +#set_target_properties(pnnx PROPERTIES COMPILE_FLAGS -fsanitize=address) +#set_target_properties(pnnx PROPERTIES LINK_FLAGS -fsanitize=address) + +set_target_properties(pnnx PROPERTIES INSTALL_RPATH "$ORIGIN/") +set_target_properties(pnnx PROPERTIES MACOSX_RPATH TRUE) + +install(TARGETS pnnx RUNTIME DESTINATION bin) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp new file mode 100644 index 000000000000..b35918f4d1f6 --- /dev/null +++ b/tools/pnnx/src/ir.cpp @@ -0,0 +1,2194 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +#include +#include +#include +#include +#include +#include + +#include + +#include "storezip.h" + +namespace pnnx { + +static const char* type_to_string(int type) +{ + if (type == 1) return "f32"; + if (type == 2) return "f64"; + if (type == 3) return "f16"; + if (type == 4) return "i32"; + if (type == 5) return "i64"; + if (type == 6) return "i16"; + if (type == 7) return "i8"; + if (type == 8) return "u8"; + return "null"; +} + +static const char* type_to_numpy_string(int type) +{ + if (type == 1) return "float32"; + if (type == 2) return "float64"; + if (type == 3) return "float16"; + if (type == 4) return "int32"; + if (type == 5) return "int64"; + if (type == 6) return "int16"; + if (type == 7) return "int8"; + if (type == 8) return "uint8"; + return "null"; +} + +static size_t type_to_elemsize(int type) +{ + if (type == 1) return 4; + if (type == 2) return 8; + if (type == 3) return 2; + if (type == 4) return 4; + if (type == 5) return 8; + if (type == 6) return 2; + if (type == 7) return 1; + if (type == 8) return 1; + return 0; // null +} + +static int string_to_type(const char* s) +{ + if (strcmp(s, "f32") == 0) return 1; + if (strcmp(s, "f64") == 0) return 2; + if (strcmp(s, "f16") == 0) return 3; + if (strcmp(s, "i32") == 0) return 4; + if (strcmp(s, "i64") == 0) return 5; + if (strcmp(s, "i16") == 0) return 6; + if (strcmp(s, "i8") == 0) return 7; + if (strcmp(s, "u8") == 0) return 8; + return 0; // null +} + +int get_at_tensor_type(const at::ScalarType& st) +{ + if (st == c10::ScalarType::Float) return 1; + if (st == c10::ScalarType::Double) return 2; + if (st == c10::ScalarType::Half) return 3; + if (st == c10::ScalarType::Int) return 4; + if (st == c10::ScalarType::QInt32) return 4; + if (st == c10::ScalarType::Long) return 5; + if (st == c10::ScalarType::Short) return 6; + if (st == c10::ScalarType::Char) return 7; + if (st == c10::ScalarType::QInt8) return 7; + if (st == c10::ScalarType::Byte) return 8; + if (st == c10::ScalarType::QUInt8) return 8; + return 0; // unknown type +} + +Parameter::Parameter(const torch::jit::Node* value_node) +{ + type = 0; + + if (value_node->kind() == c10::prim::Constant) + { + if (!value_node->hasAttribute(torch::jit::attr::value)) + { + fprintf(stderr, "no attribute value\n"); + return; + } + + switch (value_node->output()->type()->kind()) + { + case c10::TypeKind::NoneType: + { + type = 0; + break; + } + case c10::TypeKind::BoolType: + { + type = 1; + b = value_node->i(torch::jit::attr::value); + break; + } + case c10::TypeKind::IntType: + { + type = 2; + i = (int)value_node->i(torch::jit::attr::value); + break; + } + case c10::TypeKind::FloatType: + { + type = 3; + f = (float)value_node->f(torch::jit::attr::value); + break; + } + case c10::TypeKind::StringType: + { + type = 4; + s = value_node->s(torch::jit::attr::value); + break; + } + case c10::TypeKind::TensorType: + { + at::Tensor t = value_node->t(torch::jit::attr::value); + + if (t.dim() == 0) + { + if (t.scalar_type() == c10::ScalarType::Long) + { + type = 2; + i = (int)t.item(); + } + else if (t.scalar_type() == c10::ScalarType::Int) + { + type = 2; + i = t.item(); + } + else if (t.scalar_type() == c10::ScalarType::Double) + { + type = 3; + f = (float)t.item(); + } + else if (t.scalar_type() == c10::ScalarType::Float) + { + type = 3; + f = t.item(); + } + else + { + fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = 0\n", value_node->kind().toDisplayString()); + } + } + else + { + const int ndim = (int)t.dim(); + + type = 8; + fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = %d\n", value_node->kind().toDisplayString(), ndim); + } + + break; + } + default: + { + fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString()); + break; + } + } + } + else if (value_node->kind() == c10::prim::ListConstruct) + { + switch (value_node->output()->type()->cast()->getElementType()->kind()) + { + case c10::TypeKind::IntType: + { + type = 5; + for (const auto& x : value_node->inputs()) + { + ai.push_back((int)x->node()->i(torch::jit::attr::value)); + } + break; + } + case c10::TypeKind::FloatType: + { + type = 6; + for (const auto& x : value_node->inputs()) + { + af.push_back((float)x->node()->f(torch::jit::attr::value)); + } + break; + } + case c10::TypeKind::StringType: + { + type = 7; + for (const auto& x : value_node->inputs()) + { + as.push_back(x->node()->s(torch::jit::attr::value)); + } + break; + } + default: + { + fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString()); + break; + } + } + } + else + { + fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString()); + } +} + +Parameter::Parameter(const torch::jit::Value* value) + : Parameter(value->node()) +{ +} + +Attribute::Attribute(const at::Tensor& t) +{ + type = get_at_tensor_type(t.scalar_type()); + + const int ndim = (int)t.dim(); + shape.resize(ndim); + for (int i = 0; i < ndim; i++) + shape[i] = t.size(i); + + if (shape.size() > 0) + { + int size = shape[0]; + for (size_t i = 1; i < shape.size(); i++) + { + size *= shape[i]; + } + + data.resize(size * type_to_elemsize(type)); + memcpy((void*)data.data(), (const void*)t.cpu().contiguous().data_ptr(), data.size()); + } +} + +Attribute::Attribute(const std::initializer_list& _shape, const std::vector& t) +{ + type = 1; + shape = _shape; + + if (shape.size() > 0) + { + int size = shape[0]; + for (size_t i = 1; i < shape.size(); i++) + { + size *= shape[i]; + } + + data.resize(size * type_to_elemsize(type)); + memcpy((void*)data.data(), (const void*)t.data(), data.size()); + } +} + +Parameter Parameter::parse_from_string(const std::string& value) +{ + Parameter p; + p.type = 0; + + if (value == "None" || value == "()" || value == "[]") + { + return p; + } + + if (value == "True" || value == "False") + { + // bool + p.type = 1; + p.b = value == "True"; + return p; + } + + if (value[0] == '(' || value[0] == '[') + { + // list + std::string lc = value.substr(1, value.size() - 2); + std::istringstream lcss(lc); + + while (!lcss.eof()) + { + std::string elem; + std::getline(lcss, elem, ','); + + if ((elem[0] != '-' && (elem[0] < '0' || elem[0] > '9')) || (elem[0] == '-' && (elem[1] < '0' || elem[1] > '9'))) + { + // string + p.type = 7; + p.as.push_back(elem); + } + else if (elem.find('.') != std::string::npos || elem.find('e') != std::string::npos) + { + // float + p.type = 6; + p.af.push_back(std::stof(elem)); + } + else + { + // integer + p.type = 5; + p.ai.push_back(std::stoi(elem)); + } + } + return p; + } + + if ((value[0] != '-' && (value[0] < '0' || value[0] > '9')) || (value[0] == '-' && (value[1] < '0' || value[1] > '9'))) + { + // string + p.type = 4; + p.s = value; + return p; + } + + if (value.find('.') != std::string::npos || value.find('e') != std::string::npos) + { + // float + p.type = 3; + p.f = std::stof(value); + return p; + } + + // integer + p.type = 2; + p.i = std::stoi(value); + return p; +} + +Graph::Graph() +{ +} + +Graph::~Graph() +{ + for (auto x : ops) + delete x; + + for (auto x : operands) + delete x; + + ops.clear(); + operands.clear(); +} + +Graph::Graph(const Graph& /*rhs*/) +{ +} + +Graph& Graph::operator=(const Graph& /*rhs*/) +{ + return *this; +} + +static void load_parameter(Operator* op, const std::string& key, const std::string& value) +{ + op->params[key] = Parameter::parse_from_string(value); +} + +static void load_input_key(Operator* op, const std::string& key, const std::string& value) +{ + op->inputnames.resize(op->inputs.size()); + + for (size_t i = 0; i < op->inputs.size(); i++) + { + const Operand* oprand = op->inputs[i]; + if (oprand->name == value) + { + op->inputnames[i] = key; + break; + } + } +} + +static void load_shape(Operator* op, const std::string& key, const std::string& value) +{ + Operand* operand = 0; + for (auto r : op->inputs) + { + if (r->name == key) + { + operand = r; + break; + } + } + + if (!operand) + { + for (auto r : op->outputs) + { + if (r->name == key) + { + operand = r; + break; + } + } + } + + if (!operand) + { + fprintf(stderr, "no such operand %s for operator %s\n", key.c_str(), op->name.c_str()); + return; + } + + // type + std::string typestr = value.substr(value.find_last_of(')') + 1); + operand->type = string_to_type(typestr.c_str()); + + // shape + std::string lc = value.substr(1, value.find_last_of(')') - 1); + std::istringstream lcss(lc); + + operand->shape.clear(); + while (!lcss.eof()) + { + std::string elem; + std::getline(lcss, elem, ','); + + if (elem == "?") + { + operand->shape.push_back(-1); + } + else + { + int i = std::stoi(elem); + operand->shape.push_back(i); + } + } +} + +static void load_attribute(Operator* op, const std::string& key, const std::string& value, StoreZipReader& szr) +{ + Attribute& a = op->attrs[key]; + + // type + std::string typestr = value.substr(value.find_last_of(')') + 1); + a.type = string_to_type(typestr.c_str()); + + if (a.type == 0) + return; + + // shape + std::string lc = value.substr(1, value.find_last_of(')') - 1); + std::istringstream lcss(lc); + + a.shape.clear(); + while (!lcss.eof()) + { + std::string elem; + std::getline(lcss, elem, ','); + + int i = std::stoi(elem); + a.shape.push_back(i); + } + + if (a.shape.empty()) + return; + + // data + size_t size = 1; + for (int i : a.shape) + { + size *= i; + } + + size_t bytesize = size * type_to_elemsize(a.type); + + std::string filename = op->name + "." + key; + + size_t filesize = szr.get_file_size(filename); + + if (filesize == 0) + { + // no such file + return; + } + + if (filesize != bytesize) + { + fprintf(stderr, "file size not match expect %lu but got %lu\n", bytesize, filesize); + } + + a.data.resize(bytesize); + szr.read_file(filename, (char*)a.data.data()); +} + +int Graph::load(const std::string& parampath, const std::string& binpath) +{ + std::ifstream is(parampath, std::ios::in | std::ios::binary); + if (!is.good()) + { + fprintf(stderr, "open failed\n"); + return -1; + } + + StoreZipReader szr; + if (szr.open(binpath) != 0) + { + fprintf(stderr, "open failed\n"); + return -1; + } + + int magic = 0; + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + iss >> magic; + } + + int operator_count = 0; + int operand_count = 0; + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + iss >> operator_count >> operand_count; + } + + for (int i = 0; i < operator_count; i++) + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + std::string type; + std::string name; + int input_count = 0; + int output_count = 0; + + iss >> type >> name >> input_count >> output_count; + + Operator* op = new_operator(type, name); + + for (int j = 0; j < input_count; j++) + { + std::string operand_name; + iss >> operand_name; + + Operand* r = get_operand(operand_name); + r->consumers.push_back(op); + op->inputs.push_back(r); + } + + for (int j = 0; j < output_count; j++) + { + std::string operand_name; + iss >> operand_name; + + Operand* r = new_operand(operand_name); + r->producer = op; + op->outputs.push_back(r); + } + + // key=value + while (!iss.eof()) + { + std::string param; + iss >> param; + + std::string key; + std::string value; + std::istringstream pss(param); + std::getline(pss, key, '='); + std::getline(pss, value); + + if (key[0] == '@') + { + // attribute + load_attribute(op, key.substr(1), value, szr); + } + else if (key[0] == '$') + { + // operand input key + load_input_key(op, key.substr(1), value); + } + else if (key[0] == '#') + { + // operand shape + load_shape(op, key.substr(1), value); + } + else + { + // parameter + load_parameter(op, key, value); + } + } + } + + return 0; +} + +int Graph::save(const std::string& parampath, const std::string& binpath) +{ + FILE* paramfp = fopen(parampath.c_str(), "wb"); + if (!paramfp) + { + fprintf(stderr, "fopen %s failed\n", parampath.c_str()); + return -1; + } + + StoreZipWriter szw; + if (szw.open(binpath) != 0) + { + fprintf(stderr, "open failed\n"); + return -1; + } + + // magic + fprintf(paramfp, "7767517\n"); + + // op count and oprand count + fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size()); + + for (const Operator* op : ops) + { + fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size()); + + for (const Operand* oprand : op->inputs) + { + fprintf(paramfp, " %s", oprand->name.c_str()); + } + + for (const Operand* oprand : op->outputs) + { + fprintf(paramfp, " %s", oprand->name.c_str()); + } + + for (const auto& it : op->params) + { + fprintf(paramfp, " %s=", it.first.c_str()); + + const Parameter& param = it.second; + if (param.type == 0) + { + fprintf(paramfp, "None"); + } + if (param.type == 1) + { + if (param.b) + fprintf(paramfp, "True"); + else + fprintf(paramfp, "False"); + } + if (param.type == 2) + { + fprintf(paramfp, "%d", param.i); + } + if (param.type == 3) + { + fprintf(paramfp, "%e", param.f); + } + if (param.type == 4) + { + fprintf(paramfp, "%s", param.s.c_str()); + } + if (param.type == 5) + { + fprintf(paramfp, "("); + for (size_t i = 0; i < param.ai.size(); i++) + { + fprintf(paramfp, "%d", param.ai[i]); + if (i + 1 != param.ai.size()) + fprintf(paramfp, ","); + } + fprintf(paramfp, ")"); + } + if (param.type == 6) + { + fprintf(paramfp, "("); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(paramfp, "%e", param.af[i]); + if (i + 1 != param.af.size()) + fprintf(paramfp, ","); + } + fprintf(paramfp, ")"); + } + if (param.type == 7) + { + fprintf(paramfp, "("); + for (size_t i = 0; i < param.as.size(); i++) + { + fprintf(paramfp, "%s", param.as[i].c_str()); + if (i + 1 != param.as.size()) + fprintf(paramfp, ","); + } + fprintf(paramfp, ")"); + } + } + + for (const auto& it : op->attrs) + { + fprintf(paramfp, " @%s=", it.first.c_str()); + + const Attribute& attr = it.second; + fprintf(paramfp, "("); + for (int i = 0; i < (int)attr.shape.size() - 1; i++) + { + fprintf(paramfp, "%d,", attr.shape[i]); + } + if (attr.shape.size() > 0) + fprintf(paramfp, "%d", attr.shape[attr.shape.size() - 1]); + fprintf(paramfp, ")"); + + fprintf(paramfp, type_to_string(attr.type)); + + std::string filename = op->name + "." + it.first; + szw.write_file(filename, attr.data.data(), attr.data.size()); + } + + if (op->inputnames.size() == op->inputs.size()) + { + for (size_t i = 0; i < op->inputs.size(); i++) + { + if (op->inputnames[i].empty()) + continue; + + const Operand* oprand = op->inputs[i]; + fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str()); + } + } + + for (const Operand* oprand : op->inputs) + { + if (oprand->shape.empty()) + continue; + + fprintf(paramfp, " #%s=", oprand->name.c_str()); + + fprintf(paramfp, "("); + for (int i = 0; i < (int)oprand->shape.size() - 1; i++) + { + if (oprand->shape[i] == -1) + fprintf(paramfp, "?,"); + else + fprintf(paramfp, "%d,", oprand->shape[i]); + } + if (oprand->shape.size() > 0) + { + if (oprand->shape[oprand->shape.size() - 1] == -1) + fprintf(paramfp, "?"); + else + fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); + } + fprintf(paramfp, ")"); + + fprintf(paramfp, type_to_string(oprand->type)); + } + + for (const Operand* oprand : op->outputs) + { + if (oprand->shape.empty()) + continue; + + fprintf(paramfp, " #%s=", oprand->name.c_str()); + + fprintf(paramfp, "("); + for (int i = 0; i < (int)oprand->shape.size() - 1; i++) + { + if (oprand->shape[i] == -1) + fprintf(paramfp, "?,"); + else + fprintf(paramfp, "%d,", oprand->shape[i]); + } + if (oprand->shape.size() > 0) + { + if (oprand->shape[oprand->shape.size() - 1] == -1) + fprintf(paramfp, "?"); + else + fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); + } + fprintf(paramfp, ")"); + + fprintf(paramfp, type_to_string(oprand->type)); + } + + fprintf(paramfp, "\n"); + } + + fclose(paramfp); + + return 0; +} + +static std::string sanitize_identifier(const std::string& s) +{ + std::string ss = s; + for (size_t i = 0; i < ss.size(); i++) + { + if (ss[i] == '.' || ss[i] == ':') + ss[i] = '_'; + } + + return ss; +} + +static std::string expand_expression(const Operator* op) +{ + std::string expr = op->params.at("expr").s; + + // split into tokens + std::vector tokens; + { + std::string t; + for (size_t i = 0; i < expr.size(); i++) + { + char ch = expr[i]; + + if (ch == '[') // list + { + t += ch; + tokens.push_back(t); + t.clear(); + } + else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') + { + if (!t.empty()) + { + tokens.push_back(t); + t.clear(); + } + } + else + { + t += ch; + } + } + + if (!t.empty()) + { + tokens.push_back(t); + } + } + + // scan and stack + std::stack exprstack; + for (int i = (int)tokens.size() - 1; i >= 0; i--) + { + const std::string& t = tokens[i]; + + if (t == "size") + { + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = a + ".size(" + b + ")"; + exprstack.push(r); + } + else if (t == "int" || t == "sqrt" || t == "rsqrt" || t == "neg") + { + std::string unaryop; + if (t == "int") unaryop = "int"; + if (t == "sqrt") unaryop = "torch.sqrt"; + if (t == "rsqrt") unaryop = "torch.rsqrt"; + if (t == "neg") unaryop = "torch.neg"; + + std::string a = exprstack.top(); + exprstack.pop(); + + std::string r = unaryop + "(" + a + ")"; + exprstack.push(r); + } + else if (t == "pow") + { + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = a + ".pow(" + b + ")"; + exprstack.push(r); + } + else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide") + { + std::string binaryop; + if (t == "add") binaryop = "+"; + if (t == "sub") binaryop = "-"; + if (t == "mul") binaryop = "*"; + if (t == "div") binaryop = "/"; + if (t == "floor_divide") binaryop = "//"; + + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = std::string("(") + a + " " + binaryop + " " + b + ")"; + exprstack.push(r); + } + else if (t == "[") // list + { + std::vector elements; + while (!exprstack.empty()) + { + std::string a = exprstack.top(); + exprstack.pop(); + + elements.push_back(a); + } + + std::string r = "["; + for (int j = 0; j < (int)elements.size() - 1; j++) + { + r += elements[j]; + if (j + 1 != (int)elements.size()) + r += ", "; + } + if (!elements.empty()) + { + r += elements[elements.size() - 1]; + } + r += "]"; + + exprstack.push(r); + } + else if (t[0] == '@') + { + int input_index = std::stoi(t.substr(1)); + std::string varid = std::string("v_") + sanitize_identifier(op->inputs[input_index]->name); + exprstack.push(varid); + } + else + { + // literal + exprstack.push(t); + } + } + + std::string r = exprstack.top(); + exprstack.pop(); + + return r; +} + +static std::string make_slice_expression(const Operator* op) +{ + for (size_t j = 0; j < op->inputnames.size(); j++) + { + fprintf(stderr, "make_slice_expression %s %s\n", op->inputnames[j].c_str(), op->inputs[j]->name.c_str()); + } + + std::vector dims = op->params.at("dims").ai; + + std::string r; + + int last_dim = -1; + const int ndim = (int)dims.size(); + for (int i = 0; i < ndim; i++) + { + int dim = dims[i]; + for (int j = last_dim + 1; j < dim; j++) + { + r += ":,"; + } + last_dim = dim; + + if (op->params.find("starts") != op->params.end()) + { + std::vector starts = op->params.at("starts").ai; + int start = starts[i]; + + if (start != 0) + r += std::to_string(start); + } + else + { + fprintf(stderr, "find start\n"); + // find start + for (size_t j = 0; j < op->inputnames.size(); j++) + { + if (op->inputnames[j] == "start") + { + r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); + + fprintf(stderr, "find start %s\n", op->inputs[j]->name.c_str()); + break; + } + } + } + + r += ':'; + + if (op->params.find("ends") != op->params.end()) + { + std::vector ends = op->params.at("ends").ai; + int end = ends[i]; + if (end != -1) + r += std::to_string(end); + } + else + { + // find end + for (size_t j = 0; j < op->inputnames.size(); j++) + { + if (op->inputnames[j] == "end") + { + r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); + break; + } + } + } + + if (op->params.find("steps") != op->params.end()) + { + std::vector steps = op->params.at("steps").ai; + int step = steps[i]; + if (step != 1) + { + r += ':'; + r += std::to_string(step); + } + } + else + { + // find step + for (size_t j = 0; j < op->inputnames.size(); j++) + { + if (op->inputnames[j] == "step") + { + r += ':'; + r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); + break; + } + } + } + + if (i + 1 != ndim) + r += ','; + } + + return r; +} + +int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) +{ + FILE* pyfp = fopen(pypath.c_str(), "wb"); + if (!pyfp) + { + fprintf(stderr, "fopen %s failed\n", pypath.c_str()); + return -1; + } + + fprintf(pyfp, "import os\n"); + fprintf(pyfp, "import numpy as np\n"); + fprintf(pyfp, "import tempfile, zipfile\n"); + fprintf(pyfp, "import torch\n"); + fprintf(pyfp, "import torch.nn as nn\n"); + fprintf(pyfp, "import torch.nn.functional as F\n"); + + fprintf(pyfp, "\n"); + + fprintf(pyfp, "class Model(nn.Module):\n"); + fprintf(pyfp, " def __init__(self):\n"); + fprintf(pyfp, " super(Model, self).__init__()\n"); + + fprintf(pyfp, "\n"); + + // module + { + for (const Operator* op : ops) + { + if (op->type.substr(0, 3) != "nn.") + continue; + + fprintf(pyfp, " self.%s = %s(", sanitize_identifier(op->name).c_str(), op->type.c_str()); + + int param_count = op->params.size(); + if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") + { + param_count -= 2; // ignore scale and zero_point + } + + int param_index = 0; + for (const auto& it : op->params) + { + if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") + { + if (it.first == "scale" || it.first == "zero_point") + continue; + } + + fprintf(pyfp, "%s=", it.first.c_str()); + + const Parameter& param = it.second; + if (param.type == 0) + { + fprintf(pyfp, "None"); + } + if (param.type == 1) + { + if (param.b) + fprintf(pyfp, "True"); + else + fprintf(pyfp, "False"); + } + if (param.type == 2) + { + fprintf(pyfp, "%d", param.i); + } + if (param.type == 3) + { + fprintf(pyfp, "%f", param.f); + } + if (param.type == 4) + { + if (param.s.substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.s.c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.s.c_str()); + } + } + if (param.type == 5) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.ai.size(); i++) + { + fprintf(pyfp, "%d", param.ai[i]); + if (i + 1 != param.ai.size() || param.ai.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 6) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(pyfp, "%f", param.af[i]); + if (i + 1 != param.af.size() || param.af.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 7) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.as.size(); i++) + { + if (param.as[i].substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.as[i].c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.as[i].c_str()); + } + if (i + 1 != param.as.size() || param.as.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + + param_index++; + if (param_index != param_count) + fprintf(pyfp, ", "); + } + + fprintf(pyfp, ")\n"); + } + } + + fprintf(pyfp, "\n"); + + // load weights + { + fprintf(pyfp, " archive = zipfile.ZipFile('%s', 'r')\n", pnnxbinpath.c_str()); + + for (const Operator* op : ops) + { + if (op->type.substr(0, 3) != "nn.") + continue; + + if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") + { + for (const auto& it : op->attrs) + { + if (it.first == "weight" || it.first == "bias") + { + fprintf(pyfp, " self_%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); + } + else + { + // unknown attr + continue; + } + + const Attribute& attr = it.second; + for (size_t i = 0; i < attr.shape.size(); i++) + { + fprintf(pyfp, "%d", attr.shape[i]); + if (i + 1 != attr.shape.size()) + fprintf(pyfp, ","); + } + + fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); + } + + fprintf(pyfp, " self.%s.set_weight_bias(self_%s_weight, self_%s_bias)\n", sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str()); + fprintf(pyfp, " self.%s.scale = %f\n", sanitize_identifier(op->name).c_str(), op->params.at("scale").f); + fprintf(pyfp, " self.%s.zero_point = %d\n", sanitize_identifier(op->name).c_str(), op->params.at("zero_point").i); + + continue; + } + + for (const auto& it : op->attrs) + { + if (it.first == "running_mean" || it.first == "running_var") + { + fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); + } + else + { + fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); + } + + const Attribute& attr = it.second; + for (size_t i = 0; i < attr.shape.size(); i++) + { + fprintf(pyfp, "%d", attr.shape[i]); + if (i + 1 != attr.shape.size()) + fprintf(pyfp, ","); + } + + fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); + } + } + + fprintf(pyfp, " archive.close()\n"); + } + + fprintf(pyfp, "\n"); + + // utility function + { + fprintf(pyfp, " def load_pnnx_bin_as_parameter(self, archive, key, shape, dtype, requires_grad=True):\n"); + fprintf(pyfp, " return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype), requires_grad)\n"); + fprintf(pyfp, "\n"); + fprintf(pyfp, " def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype):\n"); + fprintf(pyfp, " _, tmppath = tempfile.mkstemp()\n"); + fprintf(pyfp, " tmpf = open(tmppath, 'wb')\n"); + fprintf(pyfp, " with archive.open(key) as keyfile:\n"); + fprintf(pyfp, " tmpf.write(keyfile.read())\n"); + fprintf(pyfp, " tmpf.close()\n"); + fprintf(pyfp, " m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy()\n"); + fprintf(pyfp, " os.remove(tmppath)\n"); + fprintf(pyfp, " return torch.from_numpy(m)\n"); + } + + fprintf(pyfp, "\n"); + + // def forward + { + fprintf(pyfp, " def forward(self"); + + for (const Operator* op : ops) + { + if (op->type != "pnnx.Input") + continue; + + fprintf(pyfp, ", v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); + } + + fprintf(pyfp, "):\n"); + } + + // forward body + { + for (const Operator* op : ops) + { + if (op->type == "pnnx.Input" || op->type == "pnnx.Output") + continue; + + fprintf(pyfp, " "); + + if (op->type == "pnnx.Expression") + { + // expr + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + std::string expanded_expr = expand_expression(op); + fprintf(pyfp, " = %s\n", expanded_expr.c_str()); + } + else if (op->type == "Tensor.slice") + { + // slice expr + std::string slice_expr = make_slice_expression(op); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), slice_expr.c_str()); + } + else if (op->type == "Tensor.view" || op->type == "Tensor.reshape") + { + // view reshape + fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); + if (op->inputs.size() == 2) + { + fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); + } + else + { + const std::vector& shape = op->params.at("shape").ai; + for (size_t i = 0; i < shape.size(); i++) + { + fprintf(pyfp, "%d", shape[i]); + if (i + 1 != shape.size()) + fprintf(pyfp, ", "); + } + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "Tensor.repeat") + { + // view reshape + fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); + if (op->inputs.size() == 2) + { + fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); + } + else + { + const std::vector& sizes = op->params.at("sizes").ai; + for (size_t i = 0; i < sizes.size(); i++) + { + fprintf(pyfp, "%d", sizes[i]); + if (i + 1 != sizes.size()) + fprintf(pyfp, ", "); + } + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "torch.cat") + { + // cat + fprintf(pyfp, "v_%s = torch.cat(", sanitize_identifier(op->outputs[0]->name).c_str()); + if (op->inputs.size() == 1) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); + } + else + { + fprintf(pyfp, "("); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")"); + } + fprintf(pyfp, ", dim=%d", op->params.at("dim").i); + fprintf(pyfp, ")\n"); + } + else if (op->type == "prim::TupleUnpack") + { + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); + } + else if (op->type == "prim::TupleConstruct") + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); + fprintf(pyfp, " = ("); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "prim::ListUnpack") + { + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); + } + else if (op->type == "prim::ListConstruct") + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); + fprintf(pyfp, " = ["); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, "]\n"); + } + else if (op->type == "nn.LSTM") + { + if (op->outputs.size() == 1) + { + fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); + } + else + { + fprintf(pyfp, "v_%s, (v_%s, v_%s)", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->outputs[1]->name).c_str(), sanitize_identifier(op->outputs[2]->name).c_str()); + } + fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); + if (op->inputs.size() == 3) + { + fprintf(pyfp, ", (v_%s, v_%s)", sanitize_identifier(op->inputs[1]->name).c_str(), sanitize_identifier(op->inputs[2]->name).c_str()); + } + fprintf(pyfp, ")\n"); + } + else if (op->type.substr(0, 3) == "nn.") + { + // self.xxx() + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")\n"); + } + else if (op->type.find("::") != std::string::npos || op->type.find(".") != std::string::npos) + { + // direct + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + + if (op->type.substr(0, 7) == "Tensor.") + { + fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); + } + else + { + fprintf(pyfp, " = %s(", op->type.c_str()); + + if (op->inputnames.size() == op->inputs.size()) + { + for (size_t i = 0; i < op->inputs.size(); i++) + { + if (!op->inputnames[i].empty()) + continue; + + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + + for (size_t i = 0; i < op->inputs.size(); i++) + { + if (op->inputnames[i].empty()) + continue; + + fprintf(pyfp, "%s=v_%s", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + } + else + { + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + } + } + + int i = 0; + for (const auto& it : op->params) + { + if (op->type.substr(0, 7) == "Tensor." && i == 0) + { + fprintf(pyfp, "%s=", it.first.c_str()); + } + else + { + fprintf(pyfp, ", %s=", it.first.c_str()); + } + + i++; + + const Parameter& param = it.second; + if (param.type == 0) + { + fprintf(pyfp, "None"); + } + if (param.type == 1) + { + if (param.b) + fprintf(pyfp, "True"); + else + fprintf(pyfp, "False"); + } + if (param.type == 2) + { + fprintf(pyfp, "%d", param.i); + } + if (param.type == 3) + { + fprintf(pyfp, "%f", param.f); + } + if (param.type == 4) + { + if (param.s.substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.s.c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.s.c_str()); + } + } + if (param.type == 5) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.ai.size(); i++) + { + fprintf(pyfp, "%d", param.ai[i]); + if (i + 1 != param.ai.size() || param.ai.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 6) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(pyfp, "%f", param.af[i]); + if (i + 1 != param.af.size() || param.af.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 7) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.as.size(); i++) + { + if (param.as[i].substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.as[i].c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.as[i].c_str()); + } + if (i + 1 != param.as.size() || param.as.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + } + + fprintf(pyfp, ")\n"); + } + else + { + fprintf(stderr, "todo %s\n", op->type.c_str()); + } + } + } + + // return + { + fprintf(pyfp, " return "); + + int output_count = 0; + { + for (const Operator* op : ops) + { + if (op->type == "pnnx.Output") + output_count++; + } + } + + int output_index = 0; + for (const Operator* op : ops) + { + if (op->type != "pnnx.Output") + continue; + + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); + if (output_index + 1 != output_count) + fprintf(pyfp, ", "); + + output_index++; + } + + fprintf(pyfp, "\n"); + } + + fprintf(pyfp, "\n"); + + // export torchscript + { + fprintf(pyfp, "def export_torchscript():\n"); + fprintf(pyfp, " net = Model()\n"); + fprintf(pyfp, " net.eval()\n"); + fprintf(pyfp, "\n"); + fprintf(pyfp, " torch.manual_seed(0)\n"); + + std::vector input_names; + for (const Operator* op : ops) + { + if (op->type != "pnnx.Input") + continue; + + const Operand* r = op->outputs[0]; + std::string input_name = std::string("v_") + sanitize_identifier(r->name); + fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); + + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d", r->shape[i]); + if (i + 1 != r->shape.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")\n"); + + input_names.push_back(input_name); + } + + fprintf(pyfp, "\n"); + + if (input_names.size() == 1) + { + fprintf(pyfp, " mod = torch.jit.trace(net, %s)\n", input_names[0].c_str()); + } + else + { + fprintf(pyfp, " mod = torch.jit.trace(net, ("); + + for (size_t i = 0; i < input_names.size(); i++) + { + fprintf(pyfp, "%s", input_names[i].c_str()); + if (i + 1 != input_names.size()) + fprintf(pyfp, ", "); + } + + fprintf(pyfp, "))\n"); + } + + fprintf(pyfp, " mod.save(\"%s.pt\")\n", pypath.c_str()); + } + + fprintf(pyfp, "\n"); + + // test inference + { + fprintf(pyfp, "def test_inference():\n"); + fprintf(pyfp, " net = Model()\n"); + fprintf(pyfp, " net.eval()\n"); + fprintf(pyfp, "\n"); + fprintf(pyfp, " torch.manual_seed(0)\n"); + + std::vector input_names; + for (const Operator* op : ops) + { + if (op->type != "pnnx.Input") + continue; + + const Operand* r = op->outputs[0]; + std::string input_name = std::string("v_") + sanitize_identifier(r->name); + fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); + + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d", r->shape[i]); + if (i + 1 != r->shape.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")\n"); + + input_names.push_back(input_name); + } + + fprintf(pyfp, "\n"); + + if (input_names.size() == 1) + { + fprintf(pyfp, " return net(%s)\n", input_names[0].c_str()); + } + else + { + fprintf(pyfp, " return net("); + + for (size_t i = 0; i < input_names.size(); i++) + { + fprintf(pyfp, "%s", input_names[i].c_str()); + if (i + 1 != input_names.size()) + fprintf(pyfp, ", "); + } + + fprintf(pyfp, ")\n"); + } + } + + fclose(pyfp); + + return 0; +} + +static bool string_is_positive_integer(const std::string& t) +{ + for (size_t i = 0; i < t.size(); i++) + { + if (t[i] < '0' || t[i] > '9') + return false; + } + + return true; +} + +int Graph::ncnn(const std::string& parampath, const std::string& binpath, const std::string& pypath) +{ + FILE* paramfp = fopen(parampath.c_str(), "wb"); + if (!paramfp) + { + fprintf(stderr, "fopen %s failed\n", parampath.c_str()); + return -1; + } + + FILE* binfp = fopen(binpath.c_str(), "wb"); + if (!binfp) + { + fprintf(stderr, "fopen %s failed\n", binpath.c_str()); + fclose(paramfp); + return -1; + } + + // magic + fprintf(paramfp, "7767517\n"); + + // op count and oprand count + fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size()); + + for (const Operator* op : ops) + { + fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size()); + + for (const Operand* oprand : op->inputs) + { + fprintf(paramfp, " %s", oprand->name.c_str()); + } + + for (const Operand* oprand : op->outputs) + { + fprintf(paramfp, " %s", oprand->name.c_str()); + } + + for (const auto& it : op->params) + { + const Parameter& param = it.second; + + if (!string_is_positive_integer(it.first)) + { + fprintf(stderr, "ignore %s %s param %s=", op->type.c_str(), op->name.c_str(), it.first.c_str()); + + if (param.type == 0) + { + fprintf(stderr, "None"); + } + if (param.type == 1) + { + if (param.b) + fprintf(stderr, "True"); + else + fprintf(stderr, "False"); + } + if (param.type == 2) + { + fprintf(stderr, "%d", param.i); + } + if (param.type == 3) + { + fprintf(stderr, "%e", param.f); + } + if (param.type == 4) + { + fprintf(stderr, "%s", param.s.c_str()); + } + if (param.type == 5) + { + fprintf(stderr, "("); + for (size_t i = 0; i < param.ai.size(); i++) + { + fprintf(stderr, "%d", param.ai[i]); + if (i + 1 != param.ai.size()) + fprintf(stderr, ","); + } + fprintf(stderr, ")"); + } + if (param.type == 6) + { + fprintf(stderr, "("); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(stderr, "%e", param.af[i]); + if (i + 1 != param.af.size()) + fprintf(stderr, ","); + } + fprintf(stderr, ")"); + } + if (param.type == 7) + { + fprintf(stderr, "("); + for (size_t i = 0; i < param.as.size(); i++) + { + fprintf(stderr, "%s", param.as[i].c_str()); + if (i + 1 != param.as.size()) + fprintf(stderr, ","); + } + fprintf(stderr, ")"); + } + fprintf(stderr, "\n"); + + continue; + } + + const int idkey = std::stoi(it.first); + if (param.type == 2) + { + fprintf(paramfp, " %d=%d", idkey, param.i); + } + if (param.type == 3) + { + fprintf(paramfp, " %d=%e", idkey, param.f); + } + if (param.type == 5) + { + const int array_size = (int)param.ai.size(); + fprintf(paramfp, " %d=%d", -23300 - idkey, array_size); + for (size_t i = 0; i < param.ai.size(); i++) + { + fprintf(paramfp, ",%d", param.ai[i]); + } + } + if (param.type == 6) + { + const int array_size = (int)param.af.size(); + fprintf(paramfp, " %d=%d", -23300 - idkey, array_size); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(paramfp, ",%e", param.af[i]); + } + } + } + + for (const auto& it : op->attrs) + { + // fprintf(paramfp, " @%s=", it.first.c_str()); + + const Attribute& attr = it.second; + + fwrite(attr.data.data(), attr.data.size(), 1, binfp); + } + + // if (op->inputnames.size() == op->inputs.size()) + // { + // for (size_t i = 0; i < op->inputs.size(); i++) + // { + // const Operand* oprand = op->inputs[i]; + // fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str()); + // } + // } + + // for (const Operand* oprand : op->outputs) + // { + // if (oprand->params.find("__batch_index") == oprand->params.end()) + // continue; + // + // const int batch_index = oprand->params.at("__batch_index").i; + // + // fprintf(paramfp, " #%s=%d", oprand->name.c_str(), batch_index); + // } + + // for (const Operand* oprand : op->outputs) + // { + // if (oprand->shape.empty()) + // continue; + // + // fprintf(paramfp, " #%s=", oprand->name.c_str()); + // + // fprintf(paramfp, "("); + // for (int64_t i = 0; i < oprand->shape.size() - 1; i++) + // { + // fprintf(paramfp, "%d,", oprand->shape[i]); + // } + // if (oprand->shape.size() > 0) + // fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); + // fprintf(paramfp, ")"); + // + // fprintf(paramfp, type_to_string(oprand->type)); + // } + + fprintf(paramfp, "\n"); + } + + fclose(paramfp); + fclose(binfp); + + FILE* pyfp = fopen(pypath.c_str(), "wb"); + if (!pyfp) + { + fprintf(stderr, "fopen %s failed\n", pypath.c_str()); + return -1; + } + + fprintf(pyfp, "import numpy as np\n"); + fprintf(pyfp, "import ncnn\n"); + fprintf(pyfp, "import torch\n"); + + fprintf(pyfp, "\n"); + + // test inference + { + fprintf(pyfp, "def test_inference():\n"); + fprintf(pyfp, " torch.manual_seed(0)\n"); + + for (const Operator* op : ops) + { + if (op->type != "Input") + continue; + + const Operand* r = op->outputs[0]; + std::string input_name = r->name; + fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); + + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d", r->shape[i]); + if (i + 1 != r->shape.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")\n"); + } + + fprintf(pyfp, " out = []\n"); + fprintf(pyfp, "\n"); + + fprintf(pyfp, " with ncnn.Net() as net:\n"); + fprintf(pyfp, " net.load_param(\"%s\")\n", parampath.c_str()); + fprintf(pyfp, " net.load_model(\"%s\")\n", binpath.c_str()); + fprintf(pyfp, " outcount = len(net.output_names())\n"); + fprintf(pyfp, "\n"); + fprintf(pyfp, " with net.create_extractor() as ex:\n"); + + for (const Operator* op : ops) + { + if (op->type != "Input") + continue; + + const Operand* r = op->outputs[0]; + std::string input_name = r->name; + fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.squeeze(0).numpy()).clone())\n", input_name.c_str(), input_name.c_str()); + } + + fprintf(pyfp, "\n"); + + fprintf(pyfp, " for i in range(outcount):\n"); + fprintf(pyfp, " _, outi = ex.extract(\"out\" + str(i))\n"); + fprintf(pyfp, " out.append(torch.from_numpy(np.array(outi)).unsqueeze(0))\n"); + + fprintf(pyfp, "\n"); + + fprintf(pyfp, " if len(out) == 1:\n"); + fprintf(pyfp, " return out[0]\n"); + fprintf(pyfp, " else:\n"); + fprintf(pyfp, " return tuple(out)\n"); + } + + fclose(pyfp); + + return 0; +} + +int Graph::parse(const std::string& param) +{ + std::istringstream is(param); + if (!is.good()) + { + fprintf(stderr, "open failed\n"); + return -1; + } + + int magic = 0; + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + iss >> magic; + } + + int operator_count = 0; + int operand_count = 0; + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + iss >> operator_count >> operand_count; + } + + for (int i = 0; i < operator_count; i++) + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); + + std::string type; + std::string name; + int input_count = 0; + int output_count = 0; + + iss >> type >> name >> input_count >> output_count; + + Operator* op = new_operator(type, name); + + for (int j = 0; j < input_count; j++) + { + std::string operand_name; + iss >> operand_name; + + Operand* r = get_operand(operand_name); + r->consumers.push_back(op); + op->inputs.push_back(r); + } + + for (int j = 0; j < output_count; j++) + { + std::string operand_name; + iss >> operand_name; + + Operand* r = new_operand(operand_name); + r->producer = op; + op->outputs.push_back(r); + } + + // key=value + while (!iss.eof()) + { + std::string param; + iss >> param; + + std::string key; + std::string value; + std::istringstream pss(param); + std::getline(pss, key, '='); + std::getline(pss, value); + + if (key[0] == '@') + { + // attribute + // load_attribute(op, key.substr(1), value, szr); + } + else if (key[0] == '$') + { + // operand input key + // load_input_key(op, key.substr(1), value); + } + else if (key[0] == '#') + { + // operand shape + load_shape(op, key.substr(1), value); + } + else + { + // parameter + load_parameter(op, key, value); + } + } + } + + return 0; +} + +void Operand::remove_consumer(const Operator* c) +{ + auto it = std::find(consumers.begin(), consumers.end(), c); + consumers.erase(it); +} + +Operator* Graph::new_operator(const std::string& type, const std::string& name) +{ + Operator* op = new Operator; + op->type = type; + op->name = name; + ops.push_back(op); + return op; +} + +Operator* Graph::new_operator_before(const std::string& type, const std::string& name, const Operator* cur) +{ + Operator* op = new Operator; + op->type = type; + op->name = name; + ops.insert(std::find(ops.begin(), ops.end(), cur), op); + return op; +} + +Operand* Graph::new_operand(const torch::jit::Value* v) +{ + Operand* r = new Operand; + r->name = v->debugName(); + + auto pt = v->type()->cast(); + if (pt) + { + if (pt->scalarType().has_value() && pt->dim().has_value()) + { + r->type = get_at_tensor_type(pt->scalarType().value()); + const int ndim = (int)pt->dim().value(); + r->shape.resize(ndim); + for (int i = 0; i < ndim; i++) + { + if (pt->sizes()[i].has_value()) + r->shape[i] = (int)pt->sizes()[i].value(); + else + r->shape[i] = -1; + } + } + } + + operands.push_back(r); + return r; +} + +Operand* Graph::new_operand(const std::string& name) +{ + Operand* r = new Operand; + r->name = name; + operands.push_back(r); + return r; +} + +Operand* Graph::get_operand(const std::string& name) +{ + for (Operand* r : operands) + { + if (r->name == name) + return r; + } + + return 0; +} + +} // namespace pnnx diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h new file mode 100644 index 000000000000..bce3c36e2bce --- /dev/null +++ b/tools/pnnx/src/ir.h @@ -0,0 +1,233 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_IR_H +#define PNNX_IR_H + +#include +#include +#include +#include + +namespace torch { +namespace jit { +struct Value; +struct Node; +} // namespace jit +} // namespace torch +namespace at { +class Tensor; +} + +namespace pnnx { + +class Parameter +{ +public: + Parameter() + : type(0) + { + } + Parameter(bool _b) + : type(1), b(_b) + { + } + Parameter(int _i) + : type(2), i(_i) + { + } + Parameter(long _l) + : type(2), i(_l) + { + } + Parameter(long long _l) + : type(2), i(_l) + { + } + Parameter(float _f) + : type(3), f(_f) + { + } + Parameter(double _d) + : type(3), f(_d) + { + } + Parameter(const char* _s) + : type(4), s(_s) + { + } + Parameter(const std::string& _s) + : type(4), s(_s) + { + } + Parameter(const std::initializer_list& _ai) + : type(5), ai(_ai) + { + } + Parameter(const std::initializer_list& _ai) + : type(5) + { + for (const auto& x : _ai) + ai.push_back((int)x); + } + Parameter(const std::vector& _ai) + : type(5), ai(_ai) + { + } + Parameter(const std::initializer_list& _af) + : type(6), af(_af) + { + } + Parameter(const std::initializer_list& _af) + : type(6) + { + for (const auto& x : _af) + af.push_back((float)x); + } + Parameter(const std::vector& _af) + : type(6), af(_af) + { + } + Parameter(const std::initializer_list& _as) + : type(7) + { + for (const auto& x : _as) + as.push_back(std::string(x)); + } + Parameter(const std::initializer_list& _as) + : type(7), as(_as) + { + } + Parameter(const std::vector& _as) + : type(7), as(_as) + { + } + + Parameter(const torch::jit::Node* value_node); + Parameter(const torch::jit::Value* value); + + static Parameter parse_from_string(const std::string& value); + + // 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others + int type; + + // value + bool b; + int i; + float f; + std::string s; + std::vector ai; + std::vector af; + std::vector as; +}; + +class Attribute +{ +public: + Attribute() + : type(0) + { + } + + Attribute(const at::Tensor& t); + + Attribute(const std::initializer_list& shape, const std::vector& t); + + // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 + int type; + std::vector shape; + + std::vector data; +}; + +class Operator; +class Operand +{ +public: + void remove_consumer(const Operator* c); + + std::string name; + + Operator* producer; + std::vector consumers; + + // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 + int type; + std::vector shape; + + std::map params; + +private: + friend class Graph; + Operand() + { + } +}; + +class Operator +{ +public: + std::string type; + std::string name; + + std::vector inputs; + std::vector outputs; + + std::vector inputnames; + std::map params; + std::map attrs; + +private: + friend class Graph; + Operator() + { + } +}; + +class Graph +{ +public: + Graph(); + ~Graph(); + + int load(const std::string& parampath, const std::string& binpath); + int save(const std::string& parampath, const std::string& binpath); + + int python(const std::string& pypath, const std::string& binpath); + + int ncnn(const std::string& parampath, const std::string& binpath, const std::string& pypath); + + int parse(const std::string& param); + + Operator* new_operator(const std::string& type, const std::string& name); + + Operator* new_operator_before(const std::string& type, const std::string& name, const Operator* cur); + + Operand* new_operand(const torch::jit::Value* v); + + Operand* new_operand(const std::string& name); + + Operand* get_operand(const std::string& name); + + std::vector ops; + std::vector operands; + +private: + Graph(const Graph& rhs); + Graph& operator=(const Graph& rhs); +}; + +} // namespace pnnx + +#endif // PNNX_IR_H diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp new file mode 100644 index 000000000000..b87d8ea256d2 --- /dev/null +++ b/tools/pnnx/src/main.cpp @@ -0,0 +1,359 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include + +#if _WIN32 +#include +#else +#include +#endif + +#include +#include + +#include + +#include "ir.h" +#include "pass_level0.h" +#include "pass_level1.h" +#include "pass_level2.h" +#include "pass_level3.h" +#include "pass_level4.h" +#include "pass_level5.h" + +#include "pass_ncnn.h" + +static std::string get_basename(const std::string& path) +{ + return path.substr(0, path.find_last_of('.')); +} + +static std::vector parse_comma_string_array_list(char* s) +{ + std::vector as; + + char* pch = strtok(s, ","); + while (pch != NULL) + { + as.push_back(std::string(pch)); + + pch = strtok(NULL, ","); + } + + return as; +} + +static std::vector > parse_comma_int_array_list(char* s) +{ + std::vector > aai; + + char* pch = strtok(s, "[]"); + while (pch != NULL) + { + // parse a,b,c + int v; + int nconsumed = 0; + int nscan = sscanf(pch, "%d%n", &v, &nconsumed); + if (nscan == 1) + { + // ok we get array + pch += nconsumed; + + std::vector ai; + ai.push_back(v); + + nscan = sscanf(pch, ",%d%n", &v, &nconsumed); + while (nscan == 1) + { + pch += nconsumed; + + ai.push_back(v); + + nscan = sscanf(pch, ",%d%n", &v, &nconsumed); + } + + // array end + aai.push_back(ai); + } + + pch = strtok(NULL, "[]"); + } + + return aai; +} + +static void print_int64_array_list(const std::vector >& list) +{ + for (size_t i = 0; i < list.size(); i++) + { + const std::vector& array = list[i]; + fprintf(stderr, "["); + for (size_t j = 0; j < array.size(); j++) + { + fprintf(stderr, "%ld", array[j]); + if (j != array.size() - 1) + fprintf(stderr, ","); + } + fprintf(stderr, "]"); + if (i != list.size() - 1) + fprintf(stderr, ","); + } +} + +static void print_string_list(const std::vector& list) +{ + for (size_t i = 0; i < list.size(); i++) + { + fprintf(stderr, "%s", list[i].c_str()); + if (i + 1 != list.size()) + fprintf(stderr, ","); + } +} + +static void show_usage() +{ + fprintf(stderr, "Usage: pnnx [model.pt] [(key=value)...]\n"); + fprintf(stderr, " pnnxparam=model.pnnx.param\n"); + fprintf(stderr, " pnnxbin=model.pnnx.bin\n"); + fprintf(stderr, " pnnxpy=model_pnnx.py\n"); + fprintf(stderr, " ncnnparam=model.ncnn.param\n"); + fprintf(stderr, " ncnnbin=model.ncnn.bin\n"); + fprintf(stderr, " ncnnpy=model_ncnn.py\n"); + fprintf(stderr, " optlevel=2\n"); + fprintf(stderr, " device=cpu/gpu\n"); + fprintf(stderr, " inputshape=[1,3,224,224],...\n"); + fprintf(stderr, " inputshape2=[1,3,320,320],...\n"); +#if _WIN32 + fprintf(stderr, " customop=C:\\Users\\nihui\\AppData\\Local\\torch_extensions\\torch_extensions\\Cache\\fused\\fused.dll,...\n"); +#else + fprintf(stderr, " customop=/home/nihui/.cache/torch_extensions/fused/fused.so,...\n"); +#endif + fprintf(stderr, " moduleop=models.common.Focus,models.yolo.Detect,...\n"); + fprintf(stderr, "Sample usage: pnnx mobilenet_v2.pt inputshape=[1,3,224,224]\n"); + fprintf(stderr, " pnnx yolov5s.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320] device=gpu moduleop=models.common.Focus,models.yolo.Detect\n"); +} + +int main(int argc, char** argv) +{ + if (argc < 2) + { + show_usage(); + return -1; + } + + for (int i = 1; i < argc; i++) + { + if (argv[i][0] == '-') + { + show_usage(); + return -1; + } + } + + std::string ptpath = std::string(argv[1]); + + std::string ptbase = get_basename(ptpath); + + std::string pnnxparampath = ptbase + ".pnnx.param"; + std::string pnnxbinpath = ptbase + ".pnnx.bin"; + std::string pnnxpypath = ptbase + "_pnnx.py"; + std::string ncnnparampath = ptbase + ".ncnn.param"; + std::string ncnnbinpath = ptbase + ".ncnn.bin"; + std::string ncnnpypath = ptbase + "_ncnn.py"; + int optlevel = 2; + std::string device = "cpu"; + std::vector > input_shapes; + std::vector > input_shapes2; + std::vector customop_modules; + std::vector module_operators; + + for (int i = 2; i < argc; i++) + { + // key=value + char* kv = argv[i]; + + char* eqs = strchr(kv, '='); + if (eqs == NULL) + { + fprintf(stderr, "unrecognized arg %s\n", kv); + continue; + } + + // split k v + eqs[0] = '\0'; + const char* key = kv; + char* value = eqs + 1; + + if (strcmp(key, "pnnxparam") == 0) + pnnxparampath = std::string(value); + if (strcmp(key, "pnnxbin") == 0) + pnnxbinpath = std::string(value); + if (strcmp(key, "pnnxpy") == 0) + pnnxpypath = std::string(value); + if (strcmp(key, "ncnnparam") == 0) + ncnnparampath = std::string(value); + if (strcmp(key, "ncnnbin") == 0) + ncnnbinpath = std::string(value); + if (strcmp(key, "ncnnpy") == 0) + ncnnpypath = std::string(value); + if (strcmp(key, "optlevel") == 0) + optlevel = atoi(value); + if (strcmp(key, "device") == 0) + device = value; + if (strcmp(key, "inputshape") == 0) + input_shapes = parse_comma_int_array_list(value); + if (strcmp(key, "inputshape2") == 0) + input_shapes2 = parse_comma_int_array_list(value); + if (strcmp(key, "customop") == 0) + customop_modules = parse_comma_string_array_list(value); + if (strcmp(key, "moduleop") == 0) + module_operators = parse_comma_string_array_list(value); + } + + // print options + { + fprintf(stderr, "pnnxparam = %s\n", pnnxparampath.c_str()); + fprintf(stderr, "pnnxbin = %s\n", pnnxbinpath.c_str()); + fprintf(stderr, "pnnxpy = %s\n", pnnxpypath.c_str()); + fprintf(stderr, "ncnnparam = %s\n", ncnnparampath.c_str()); + fprintf(stderr, "ncnnbin = %s\n", ncnnbinpath.c_str()); + fprintf(stderr, "ncnnpy = %s\n", ncnnpypath.c_str()); + fprintf(stderr, "optlevel = %d\n", optlevel); + fprintf(stderr, "device = %s\n", device.c_str()); + fprintf(stderr, "inputshape = "); + print_int64_array_list(input_shapes); + fprintf(stderr, "\n"); + fprintf(stderr, "inputshape2 = "); + print_int64_array_list(input_shapes2); + fprintf(stderr, "\n"); + fprintf(stderr, "customop = "); + print_string_list(customop_modules); + fprintf(stderr, "\n"); + fprintf(stderr, "moduleop = "); + print_string_list(module_operators); + fprintf(stderr, "\n"); + } + + // at::AutoNonVariableTypeMode nonVarTypeModeGuard(true); + // torch::autograd::AutoGradMode guard(false); + + for (auto m : customop_modules) + { + fprintf(stderr, "load custom module %s\n", m.c_str()); +#if _WIN32 + HMODULE handle = LoadLibraryExA(m.c_str(), NULL, LOAD_WITH_ALTERED_SEARCH_PATH); + if (!handle) + { + fprintf(stderr, "LoadLibraryExA %s failed %s\n", m.c_str(), GetLastError()); + } +#else + void* handle = dlopen(m.c_str(), RTLD_LAZY); + if (!handle) + { + fprintf(stderr, "dlopen %s failed %s\n", m.c_str(), dlerror()); + } +#endif + } + + std::vector input_tensors; + for (auto shape : input_shapes) + { + at::Tensor t = torch::ones(shape); + if (device == "gpu") + t = t.cuda(); + + input_tensors.push_back(t); + } + + std::vector input_tensors2; + for (auto shape : input_shapes2) + { + at::Tensor t = torch::ones(shape); + if (device == "gpu") + t = t.cuda(); + + input_tensors2.push_back(t); + } + + torch::jit::Module mod = torch::jit::load(ptpath); + + mod.eval(); + + // mod.dump(true, false, false); + // mod.dump(true, true, true); + + auto g = mod.get_method("forward").graph(); + + // g->dump(); + + fprintf(stderr, "############# pass_level0\n"); + + pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators); + + // g->dump(); + + fprintf(stderr, "############# pass_level1\n"); + + pnnx::Graph pnnx_graph; + pnnx::pass_level1(mod, g, pnnx_graph); + + // g->dump(); + + fprintf(stderr, "############# pass_level2\n"); + + pnnx::pass_level2(pnnx_graph); + + pnnx_graph.save("debug.param", "debug.bin"); + + if (optlevel >= 1) + { + fprintf(stderr, "############# pass_level3\n"); + + pnnx::pass_level3(pnnx_graph); + + fprintf(stderr, "############# pass_level4\n"); + + pnnx::pass_level4(pnnx_graph); + } + + pnnx_graph.save("debug2.param", "debug2.bin"); + + if (optlevel >= 2) + { + fprintf(stderr, "############# pass_level5\n"); + + pnnx::pass_level5(pnnx_graph); + } + + pnnx_graph.save(pnnxparampath, pnnxbinpath); + + pnnx_graph.python(pnnxpypath, pnnxbinpath); + + // if (optlevel >= 2) + { + fprintf(stderr, "############# pass_ncnn\n"); + + pnnx::pass_ncnn(pnnx_graph); + + pnnx_graph.ncnn(ncnnparampath, ncnnbinpath, ncnnpypath); + } + + // pnnx::Graph pnnx_graph2; + + // pnnx_graph2.load("pnnx.param", "pnnx.bin"); + // pnnx_graph2.save("pnnx2.param", "pnnx2.bin"); + + return 0; +} diff --git a/tools/pnnx/src/pass_level0.cpp b/tools/pnnx/src/pass_level0.cpp new file mode 100644 index 000000000000..d098423979f6 --- /dev/null +++ b/tools/pnnx/src/pass_level0.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level0.h" + +#include "pass_level0/constant_unpooling.h" +#include "pass_level0/inline_block.h" +#include "pass_level0/shape_inference.h" + +namespace pnnx { + +void pass_level0(const torch::jit::Module& mod, std::shared_ptr& g, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators) +{ + inline_block(g, module_operators); + + constant_unpooling(g); + + if (!input_tensors.empty()) + { + shape_inference(mod, g, input_tensors, input_tensors2); + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level0.h b/tools/pnnx/src/pass_level0.h new file mode 100644 index 000000000000..91da907042f5 --- /dev/null +++ b/tools/pnnx/src/pass_level0.h @@ -0,0 +1,26 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_PASS_LEVEL0_H +#define PNNX_PASS_LEVEL0_H + +#include + +namespace pnnx { + +void pass_level0(const torch::jit::Module& mod, std::shared_ptr& g, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators); + +} // namespace pnnx + +#endif // PNNX_PASS_LEVEL0_H diff --git a/tools/pnnx/src/pass_level0/constant_unpooling.cpp b/tools/pnnx/src/pass_level0/constant_unpooling.cpp new file mode 100644 index 000000000000..036196e013c1 --- /dev/null +++ b/tools/pnnx/src/pass_level0/constant_unpooling.cpp @@ -0,0 +1,80 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "constant_unpooling.h" + +#include +#include + +namespace pnnx { + +void ConstantUnpooling(std::shared_ptr& graph, torch::jit::Block* block, std::unordered_set& constants) +{ + for (auto it = block->nodes().begin(); it != block->nodes().end();) + { + auto node = *it; + // node may be moved to a different block so advance iterator now + ++it; + + if (!node->blocks().empty()) + { + // Traverse sub-blocks. + for (auto block : node->blocks()) + { + ConstantUnpooling(graph, block, constants); + } + + continue; + } + + for (int i = 0; i < (int)node->inputs().size(); i++) + { + const auto& in = node->input(i); + + if (in->node()->kind() != c10::prim::Constant) + continue; + + // input constant node + if (constants.find(in->node()) == constants.end()) + { + constants.insert(in->node()); + continue; + } + + torch::jit::WithInsertPoint guard(node); + + std::unordered_map value_map; + auto value_map_func = [&](torch::jit::Value* v) { + return value_map.at(v); + }; + + // graph->setInsertPoint(node); + + auto* new_constant_node = graph->insertNode(graph->createClone(in->node(), value_map_func, false)); + + // fprintf(stderr, "new_constant_node %s\n", new_constant_node->outputs()[0]->debugName().c_str()); + + // create new constant node + node->replaceInput(i, new_constant_node->outputs()[0]); + } + } +} + +void constant_unpooling(std::shared_ptr& graph) +{ + std::unordered_set constants; + ConstantUnpooling(graph, graph->block(), constants); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/constant_unpooling.h b/tools/pnnx/src/pass_level0/constant_unpooling.h new file mode 100644 index 000000000000..e6b8ec869808 --- /dev/null +++ b/tools/pnnx/src/pass_level0/constant_unpooling.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include + +namespace pnnx { + +void constant_unpooling(std::shared_ptr& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/inline_block.cpp b/tools/pnnx/src/pass_level0/inline_block.cpp new file mode 100644 index 000000000000..ecb3fc3d9f5f --- /dev/null +++ b/tools/pnnx/src/pass_level0/inline_block.cpp @@ -0,0 +1,129 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "inline_block.h" +#include "../pass_level1.h" + +#include + +#include + +namespace pnnx { + +static void inlineCallTo(torch::jit::Node* to_replace, torch::jit::Function* callee) +{ + torch::jit::WithInsertPoint guard(to_replace); + + std::unordered_map value_map; + std::vector new_outputs = torch::jit::insertGraph(*to_replace->owningGraph(), *(callee->graph()), to_replace->inputs(), value_map); + + const auto& old_outputs = to_replace->outputs(); + for (size_t i = 0; i < old_outputs.size(); ++i) + { + new_outputs[i]->copyMetadata(old_outputs[i]); + + old_outputs[i]->replaceAllUsesWith(new_outputs[i]); + } + + to_replace->destroy(); +} + +static void inlineCalls(torch::jit::Block* block, const std::vector& module_operators, std::set& inlined_modules) +{ + for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;) + { + torch::jit::Node* n = *it++; + if (n->kind() == c10::prim::CallFunction) + { + auto function_constant = n->input(0)->node(); + auto fun_type = function_constant->output()->type()->expect(); + if (!fun_type->function()->isGraphFunction()) + continue; + + inlineCalls(fun_type->function()->graph()->block(), module_operators, inlined_modules); + + n->removeInput(0); + + fprintf(stderr, "inline funtion %s\n", fun_type->function()->name().c_str()); + + pnnx::inlineCallTo(n, fun_type->function()); + } + else if (n->kind() == c10::prim::CallMethod) + { + auto class_type = n->input(0)->type()->cast(); + if (!class_type) + continue; + + const std::string& function_name = n->s(torch::jit::attr::name); + torch::jit::Function& function = class_type->getMethod(function_name); + if (!function.isGraphFunction()) + continue; + + std::string class_type_str = torch::jit::removeTorchMangle(class_type->str()); + + bool skip_inline = false; + for (const auto& ow : get_global_pnnx_fuse_module_passes()) + { + if (class_type_str == ow->match_type_str()) + { + skip_inline = true; + break; + } + } + + if (skip_inline) + continue; + + std::string class_type_str_no_torch_prefix = class_type_str.substr(10); + + if (std::find(module_operators.begin(), module_operators.end(), class_type_str_no_torch_prefix) != module_operators.end()) + { + continue; + } + + inlineCalls(function.graph()->block(), module_operators, inlined_modules); + + inlined_modules.insert(class_type_str_no_torch_prefix); + + // fprintf(stderr, "inline %s\n", class_type_str_no_torch_prefix.c_str()); + // fprintf(stderr, "inline method %s %s %s\n", function.name().c_str(), class_type->str().c_str(), n->input(0)->node()->s(torch::jit::attr::name).c_str()); + + pnnx::inlineCallTo(n, &function); + } + else + { + for (auto b : n->blocks()) + { + inlineCalls(b, module_operators, inlined_modules); + } + } + } +} + +void inline_block(std::shared_ptr& graph, const std::vector& module_operators) +{ + std::set inlined_modules; + + inlineCalls(graph->block(), module_operators, inlined_modules); + + for (const auto& x : inlined_modules) + { + if (x == "torch.nn.modules.container.Sequential") + continue; + + fprintf(stderr, "inline module = %s\n", x.c_str()); + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/inline_block.h b/tools/pnnx/src/pass_level0/inline_block.h new file mode 100644 index 000000000000..40f6a23f56fd --- /dev/null +++ b/tools/pnnx/src/pass_level0/inline_block.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include + +namespace pnnx { + +void inline_block(std::shared_ptr& graph, const std::vector& module_operators); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/shape_inference.cpp b/tools/pnnx/src/pass_level0/shape_inference.cpp new file mode 100644 index 000000000000..326f87995847 --- /dev/null +++ b/tools/pnnx/src/pass_level0/shape_inference.cpp @@ -0,0 +1,138 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "shape_inference.h" + +namespace pnnx { + +void shape_inference(const torch::jit::Module& mod, std::shared_ptr& graph, const std::vector& input_tensors, const std::vector& input_tensors2) +{ + // collect all intermediate output tensors + std::vector values; + for (const auto& n : graph->nodes()) + { + for (const auto& on : n->outputs()) + { + auto tensor_type = on->type()->cast(); + if (!tensor_type) + continue; + + values.push_back(on); + } + } + + // set new graph output + auto old_output = graph->outputs()[0]; + + torch::jit::Node* new_return_node = graph->createTuple(at::ArrayRef(values)); + + graph->appendNode(new_return_node); + + graph->eraseOutput(0); + graph->registerOutput(new_return_node->outputs()[0]); + + // inference for all tensors + std::vector inputs; + for (size_t i = 0; i < input_tensors.size(); i++) + { + const at::Tensor& it = input_tensors[i]; + + inputs.push_back(it); + graph->inputs()[1 + i]->setType(c10::TensorType::create(it)); + } + + auto outputs = mod.copy().forward(inputs).toTuple(); + + if (input_tensors2.empty()) + { + // assign shape info + int index = 0; + for (auto e : outputs->elements()) + { + values[index]->setType(c10::TensorType::create(e.toTensor())); + + index++; + } + } + else + { + std::vector inputs2; + for (size_t i = 0; i < input_tensors2.size(); i++) + { + const at::Tensor& it = input_tensors2[i]; + + inputs2.push_back(it); + graph->inputs()[1 + i]->setType(c10::TensorType::create(it)); + } + + auto outputs2 = mod.copy().forward(inputs2).toTuple(); + + fprintf(stderr, "assign dynamic shape info\n"); + + // assign dynamic shape info + for (size_t i = 0; i < input_tensors.size(); i++) + { + auto type1 = c10::TensorType::create(input_tensors[i]); + auto type2 = c10::TensorType::create(input_tensors2[i]); + + std::vector sizes1 = type1->symbolic_sizes().sizes().value(); + std::vector sizes2 = type2->symbolic_sizes().sizes().value(); + + for (size_t i = 0; i < sizes1.size(); i++) + { + if (sizes1[i] == sizes2[i]) + continue; + + sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1); + } + + auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1)); + + graph->inputs()[1 + i]->setType(finaltype); + } + + int index = 0; + for (auto e : outputs->elements()) + { + auto type1 = c10::TensorType::create(e.toTensor()); + auto type2 = c10::TensorType::create(outputs2->elements()[index].toTensor()); + + std::vector sizes1 = type1->symbolic_sizes().sizes().value(); + std::vector sizes2 = type2->symbolic_sizes().sizes().value(); + + for (size_t i = 0; i < sizes1.size(); i++) + { + if (sizes1[i] == sizes2[i]) + continue; + + sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1); + } + + auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1)); + + values[index]->setType(finaltype); + + index++; + } + } + + // restore old graph output + graph->eraseOutput(0); + graph->registerOutput(old_output); + + new_return_node->removeAllInputs(); + new_return_node->destroy(); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/shape_inference.h b/tools/pnnx/src/pass_level0/shape_inference.h new file mode 100644 index 000000000000..b1feba80164a --- /dev/null +++ b/tools/pnnx/src/pass_level0/shape_inference.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include + +namespace pnnx { + +void shape_inference(const torch::jit::Module& mod, std::shared_ptr& graph, const std::vector& input_tensors, const std::vector& input_tensors2); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1.cpp b/tools/pnnx/src/pass_level1.cpp new file mode 100644 index 000000000000..7fcdd8e709e5 --- /dev/null +++ b/tools/pnnx/src/pass_level1.cpp @@ -0,0 +1,302 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include + +#include "pass_level1.h" + +namespace pnnx { + +FuseModulePass::~FuseModulePass() +{ +} + +void FuseModulePass::write(Operator* /*op*/, const std::shared_ptr& /*graph*/) const +{ +} + +void FuseModulePass::write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& /*mod*/) const +{ + write(op, graph); +} + +static std::vector g_global_pnnx_fuse_module_passes; + +const std::vector& get_global_pnnx_fuse_module_passes() +{ + return g_global_pnnx_fuse_module_passes; +} + +FuseModulePassRegister::FuseModulePassRegister(const FuseModulePass* _pass) + : pass(_pass) +{ + g_global_pnnx_fuse_module_passes.push_back(pass); +} + +FuseModulePassRegister::~FuseModulePassRegister() +{ + delete pass; +} + +void pass_level1(const torch::jit::Module& mod, const std::shared_ptr& g, Graph& pg) +{ + for (int i = 1; i < (int)g->inputs().size(); i++) + { + const auto& in = g->inputs()[i]; + + char name[32]; + sprintf(name, "pnnx_input_%d", i - 1); + + Operator* op = pg.new_operator("pnnx.Input", name); + Operand* r = pg.new_operand(in); + r->producer = op; + op->outputs.push_back(r); + } + + std::map class_type_to_names; + int pnnx_unknown_index = 0; + + for (const auto& n : g->block()->nodes()) + { + if (n->kind() == c10::prim::GetAttr) + { + // pass + std::string name = n->s(torch::jit::attr::name); + // std::string name = n->debugName(); + + auto class_type = n->output(0)->type()->cast(); + + if (class_type) + { + std::string class_type_str = class_type->str(); + class_type_to_names[class_type_str] = name; + // class_type_to_names[class_type_str] = class_type_str + "." + name; + } + else + { + // Tensor from some class + // Operator* op = pg.new_operator(n->kind().toDisplayString(), name); + Operator* op = pg.new_operator("pnnx.Attribute", name); + + for (int i = 0; i < (int)n->outputs().size(); i++) + { + const auto& on = n->output(i); + Operand* r = pg.new_operand(on); + r->producer = op; + op->outputs.push_back(r); + } + + std::deque module_names; // = split(n->input(0)->node()->s(torch::jit::attr::name), '.'); + { + auto np = n->input(0)->node(); + while (np->hasAttribute(torch::jit::attr::name)) + { + module_names.push_front(np->s(torch::jit::attr::name)); + np = np->input(0)->node(); + } + } + + std::string wrapped_name; + auto sub_mod = mod; + for (auto module_name : module_names) + { + if (wrapped_name.size() > 0) + wrapped_name = wrapped_name + "." + module_name; + else + wrapped_name = module_name; + sub_mod = sub_mod.attr(module_name).toModule(); + } + + op->name = wrapped_name; + + // op->params["this"] = n->input(i) + + // sub_mod.dump(true, true, true); + + op->attrs[name] = sub_mod.attr(name).toTensor(); + } + } + else if (n->kind() == c10::prim::Constant) // || n->kind() == c10::prim::ListConstruct) + { + char name[32]; + sprintf(name, "pnnx_%d", pnnx_unknown_index++); + + Operator* op = pg.new_operator(n->kind().toDisplayString(), name); + + for (int i = 0; i < (int)n->inputs().size(); i++) + { + const auto& in = n->input(i); + Operand* r = pg.get_operand(in->debugName()); + r->consumers.push_back(op); + op->inputs.push_back(r); + } + + for (int i = 0; i < (int)n->outputs().size(); i++) + { + const auto& on = n->output(i); + Operand* r = pg.new_operand(on); + r->producer = op; + op->outputs.push_back(r); + } + + op->params["value"] = n; + + if (op->params["value"].type == 8) + { + op->type = "pnnx.Attribute"; + + op->params.erase("value"); + + op->attrs[name] = n->t(torch::jit::attr::value); + } + } + else if (n->kind() == c10::prim::CallMethod) + { + auto class_type = n->input(0)->type()->cast(); + // const std::string& name = n->s(torch::jit::attr::name); + + // fprintf(stderr, "call %s\n", class_type->str().c_str()); + + std::string name = class_type_to_names[class_type->str()]; + + std::string class_type_str = torch::jit::removeTorchMangle(class_type->str()); + + std::string optypename = class_type_str; + for (const auto& ow : get_global_pnnx_fuse_module_passes()) + { + if (class_type_str != ow->match_type_str()) + continue; + + optypename = ow->type_str(); + break; + } + + if (optypename == class_type_str) + { + optypename = class_type_str.substr(10); + } + + Operator* op = pg.new_operator(optypename, name); + + for (int i = 1; i < (int)n->inputs().size(); i++) + { + const auto& in = n->input(i); + Operand* r = pg.get_operand(in->debugName()); + r->consumers.push_back(op); + op->inputs.push_back(r); + } + + for (int i = 0; i < (int)n->outputs().size(); i++) + { + const auto& on = n->output(i); + Operand* r = pg.new_operand(on); + r->producer = op; + op->outputs.push_back(r); + } + + for (const auto& ow : get_global_pnnx_fuse_module_passes()) + { + if (class_type_str != ow->match_type_str()) + continue; + + auto class_type = n->input(0)->type()->cast(); + torch::jit::Function& function = class_type->getMethod(n->s(torch::jit::attr::name)); + + std::deque module_names; // = split(n->input(0)->node()->s(torch::jit::attr::name), '.'); + { + auto np = n->input(0)->node(); + while (np->hasAttribute(torch::jit::attr::name)) + { + module_names.push_front(np->s(torch::jit::attr::name)); + np = np->input(0)->node(); + } + } + + std::string wrapped_name; + auto sub_mod = mod; + for (auto module_name : module_names) + { + if (wrapped_name.size() > 0) + wrapped_name = wrapped_name + "." + module_name; + else + wrapped_name = module_name; + sub_mod = sub_mod.attr(module_name).toModule(); + } + + op->name = wrapped_name; + + ow->write(op, function.graph(), sub_mod); + + break; + } + } + // else if (n->kind() == c10::prim::CallFunction) + // { + // fprintf(stderr, "function %s", n->kind().toDisplayString()); + // + // AT_ASSERT(cur->input(0)->node()->kind() == c10::prim::Constant); + // auto function_constant = cur->input(0)->node(); + // auto fun_type = function_constant->output()->type()->expect(); + // if (!fun_type->function()->isGraphFunction()) + // { + // continue; + // } + // cur->removeInput(0); + // + // fprintf(stderr, "inline funtion %s\n", fun_type->function()->name().c_str()); + // + // GRAPH_UPDATE("Inlining function '", fun_type->function()->name(), "' to ", *cur); + // GRAPH_UPDATE("Function body: ", *fun_type->function()->optimized_graph()); + // inlineCallTo(cur, fun_type->function(), false); + // break; + // } + else + { + char name[32]; + sprintf(name, "pnnx_%d", pnnx_unknown_index++); + + Operator* op = pg.new_operator(n->kind().toDisplayString(), name); + + for (int i = 0; i < (int)n->inputs().size(); i++) + { + const auto& in = n->input(i); + Operand* r = pg.get_operand(in->debugName()); + r->consumers.push_back(op); + op->inputs.push_back(r); + } + + for (int i = 0; i < (int)n->outputs().size(); i++) + { + const auto& on = n->output(i); + Operand* r = pg.new_operand(on); + r->producer = op; + op->outputs.push_back(r); + } + } + } + + for (int i = 0; i < (int)g->outputs().size(); i++) + { + const auto& in = g->outputs()[i]; + + char name[32]; + sprintf(name, "pnnx_output_%d", i); + Operator* op = pg.new_operator("pnnx.Output", name); + Operand* r = pg.get_operand(in->debugName()); + r->consumers.push_back(op); + op->inputs.push_back(r); + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1.h b/tools/pnnx/src/pass_level1.h new file mode 100644 index 000000000000..7ab9051852b6 --- /dev/null +++ b/tools/pnnx/src/pass_level1.h @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_PASS_LEVEL1_H +#define PNNX_PASS_LEVEL1_H + +#include +#include +#include "ir.h" + +namespace pnnx { + +class FuseModulePass +{ +public: + virtual ~FuseModulePass(); + + virtual const char* match_type_str() const = 0; + + virtual const char* type_str() const = 0; + + virtual void write(Operator* op, const std::shared_ptr& graph) const; + + virtual void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const; +}; + +class FuseModulePassRegister +{ +public: + FuseModulePassRegister(const FuseModulePass* pass); + ~FuseModulePassRegister(); + const FuseModulePass* pass; +}; + +const std::vector& get_global_pnnx_fuse_module_passes(); + +#define REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(CLASS) \ + static FuseModulePassRegister g_global_pnnx_fusemodulepass_##CLASS##_register(new CLASS); + +void pass_level1(const torch::jit::Module& mod, const std::shared_ptr& g, Graph& pg); + +} // namespace pnnx + +#endif // PNNX_PASS_LEVEL1_H diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool1d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool1d.cpp new file mode 100644 index 000000000000..b7ee5241dd18 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool1d.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class AdaptiveAvgPool1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.AdaptiveAvgPool1d"; + } + + const char* type_str() const + { + return "nn.AdaptiveAvgPool1d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* adaptive_avg_pool1d = find_node_by_kind(graph, "aten::adaptive_avg_pool1d"); + + op->params["output_size"] = adaptive_avg_pool1d->namedInput("output_size"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(AdaptiveAvgPool1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool2d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool2d.cpp new file mode 100644 index 000000000000..7987285a87b5 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool2d.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class AdaptiveAvgPool2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.AdaptiveAvgPool2d"; + } + + const char* type_str() const + { + return "nn.AdaptiveAvgPool2d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* adaptive_avg_pool2d = find_node_by_kind(graph, "aten::adaptive_avg_pool2d"); + + op->params["output_size"] = adaptive_avg_pool2d->namedInput("output_size"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(AdaptiveAvgPool2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool3d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool3d.cpp new file mode 100644 index 000000000000..9f34be87724f --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveAvgPool3d.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class AdaptiveAvgPool3d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.AdaptiveAvgPool3d"; + } + + const char* type_str() const + { + return "nn.AdaptiveAvgPool3d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* adaptive_avg_pool3d = find_node_by_kind(graph, "aten::adaptive_avg_pool3d"); + + op->params["output_size"] = adaptive_avg_pool3d->namedInput("output_size"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(AdaptiveAvgPool3d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool1d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool1d.cpp new file mode 100644 index 000000000000..59f838cbed3a --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool1d.cpp @@ -0,0 +1,47 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class AdaptiveMaxPool1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.AdaptiveMaxPool1d"; + } + + const char* type_str() const + { + return "nn.AdaptiveMaxPool1d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + graph->dump(); + + const torch::jit::Node* adaptive_max_pool1d = find_node_by_kind(graph, "aten::adaptive_max_pool1d"); + + op->params["output_size"] = adaptive_max_pool1d->namedInput("output_size"); + op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(AdaptiveMaxPool1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool2d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool2d.cpp new file mode 100644 index 000000000000..72286d58e941 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool2d.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class AdaptiveMaxPool2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.AdaptiveMaxPool2d"; + } + + const char* type_str() const + { + return "nn.AdaptiveMaxPool2d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* adaptive_max_pool2d = find_node_by_kind(graph, "aten::adaptive_max_pool2d"); + + op->params["output_size"] = adaptive_max_pool2d->namedInput("output_size"); + op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(AdaptiveMaxPool2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool3d.cpp b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool3d.cpp new file mode 100644 index 000000000000..faff211316dd --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_AdaptiveMaxPool3d.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class AdaptiveMaxPool3d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.AdaptiveMaxPool3d"; + } + + const char* type_str() const + { + return "nn.AdaptiveMaxPool3d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* adaptive_max_pool3d = find_node_by_kind(graph, "aten::adaptive_max_pool3d"); + + op->params["output_size"] = adaptive_max_pool3d->namedInput("output_size"); + op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(AdaptiveMaxPool3d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_AvgPool1d.cpp b/tools/pnnx/src/pass_level1/nn_AvgPool1d.cpp new file mode 100644 index 000000000000..99eb692afac8 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_AvgPool1d.cpp @@ -0,0 +1,48 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class AvgPool1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.AvgPool1d"; + } + + const char* type_str() const + { + return "nn.AvgPool1d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* avg_pool1d = find_node_by_kind(graph, "aten::avg_pool1d"); + + op->params["kernel_size"] = avg_pool1d->namedInput("kernel_size"); + op->params["stride"] = avg_pool1d->namedInput("stride"); + op->params["padding"] = avg_pool1d->namedInput("padding"); + op->params["ceil_mode"] = avg_pool1d->namedInput("ceil_mode"); + op->params["count_include_pad"] = avg_pool1d->namedInput("count_include_pad"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(AvgPool1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_AvgPool2d.cpp b/tools/pnnx/src/pass_level1/nn_AvgPool2d.cpp new file mode 100644 index 000000000000..bc75cee2f127 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_AvgPool2d.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class AvgPool2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.AvgPool2d"; + } + + const char* type_str() const + { + return "nn.AvgPool2d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* avg_pool2d = find_node_by_kind(graph, "aten::avg_pool2d"); + + op->params["kernel_size"] = avg_pool2d->namedInput("kernel_size"); + op->params["stride"] = avg_pool2d->namedInput("stride"); + op->params["padding"] = avg_pool2d->namedInput("padding"); + op->params["ceil_mode"] = avg_pool2d->namedInput("ceil_mode"); + op->params["count_include_pad"] = avg_pool2d->namedInput("count_include_pad"); + op->params["divisor_override"] = avg_pool2d->namedInput("divisor_override"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(AvgPool2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_AvgPool3d.cpp b/tools/pnnx/src/pass_level1/nn_AvgPool3d.cpp new file mode 100644 index 000000000000..0b6c10fd650a --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_AvgPool3d.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class AvgPool3d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.AvgPool3d"; + } + + const char* type_str() const + { + return "nn.AvgPool3d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* avg_pool3d = find_node_by_kind(graph, "aten::avg_pool3d"); + + op->params["kernel_size"] = avg_pool3d->namedInput("kernel_size"); + op->params["stride"] = avg_pool3d->namedInput("stride"); + op->params["padding"] = avg_pool3d->namedInput("padding"); + op->params["ceil_mode"] = avg_pool3d->namedInput("ceil_mode"); + op->params["count_include_pad"] = avg_pool3d->namedInput("count_include_pad"); + op->params["divisor_override"] = avg_pool3d->namedInput("divisor_override"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(AvgPool3d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_BatchNorm1d.cpp b/tools/pnnx/src/pass_level1/nn_BatchNorm1d.cpp new file mode 100644 index 000000000000..afe10649e905 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_BatchNorm1d.cpp @@ -0,0 +1,57 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class BatchNorm1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.batchnorm.BatchNorm1d"; + } + + const char* type_str() const + { + return "nn.BatchNorm1d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm"); + + const auto& running_mean = mod.attr("running_mean").toTensor(); + const auto& running_var = mod.attr("running_var").toTensor(); + + op->params["num_features"] = running_mean.size(0); + op->params["eps"] = bn->namedInput("eps"); + op->params["affine"] = mod.hasattr("weight") && mod.hasattr("bias"); + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = running_var; + if (mod.hasattr("weight") && mod.hasattr("bias")) + { + op->attrs["weight"] = mod.attr("weight").toTensor(); + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(BatchNorm1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_BatchNorm2d.cpp b/tools/pnnx/src/pass_level1/nn_BatchNorm2d.cpp new file mode 100644 index 000000000000..7642c0331b0e --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_BatchNorm2d.cpp @@ -0,0 +1,57 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class BatchNorm2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.batchnorm.BatchNorm2d"; + } + + const char* type_str() const + { + return "nn.BatchNorm2d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm"); + + const auto& running_mean = mod.attr("running_mean").toTensor(); + const auto& running_var = mod.attr("running_var").toTensor(); + + op->params["num_features"] = running_mean.size(0); + op->params["eps"] = bn->namedInput("eps"); + op->params["affine"] = mod.hasattr("weight") && mod.hasattr("bias"); + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = running_var; + if (mod.hasattr("weight") && mod.hasattr("bias")) + { + op->attrs["weight"] = mod.attr("weight").toTensor(); + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(BatchNorm2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_BatchNorm3d.cpp b/tools/pnnx/src/pass_level1/nn_BatchNorm3d.cpp new file mode 100644 index 000000000000..20a832fc3a26 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_BatchNorm3d.cpp @@ -0,0 +1,57 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class BatchNorm3d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.batchnorm.BatchNorm3d"; + } + + const char* type_str() const + { + return "nn.BatchNorm3d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm"); + + const auto& running_mean = mod.attr("running_mean").toTensor(); + const auto& running_var = mod.attr("running_var").toTensor(); + + op->params["num_features"] = running_mean.size(0); + op->params["eps"] = bn->namedInput("eps"); + op->params["affine"] = mod.hasattr("weight") && mod.hasattr("bias"); + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = running_var; + if (mod.hasattr("weight") && mod.hasattr("bias")) + { + op->attrs["weight"] = mod.attr("weight").toTensor(); + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(BatchNorm3d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_CELU.cpp b/tools/pnnx/src/pass_level1/nn_CELU.cpp new file mode 100644 index 000000000000..dc50b92f5093 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_CELU.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class CELU : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.CELU"; + } + + const char* type_str() const + { + return "nn.CELU"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* celu = find_node_by_kind(graph, "aten::celu"); + + op->params["alpha"] = celu->namedInput("alpha"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(CELU) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ChannelShuffle.cpp b/tools/pnnx/src/pass_level1/nn_ChannelShuffle.cpp new file mode 100644 index 000000000000..84ecf8410982 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ChannelShuffle.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ChannelShuffle : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.channelshuffle.ChannelShuffle"; + } + + const char* type_str() const + { + return "nn.ChannelShuffle"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* channel_shuffle = find_node_by_kind(graph, "aten::channel_shuffle"); + + op->params["groups"] = channel_shuffle->namedInput("groups"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ChannelShuffle) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ConstantPad1d.cpp b/tools/pnnx/src/pass_level1/nn_ConstantPad1d.cpp new file mode 100644 index 000000000000..8cd639ad4bf5 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ConstantPad1d.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ConstantPad1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.padding.ConstantPad1d"; + } + + const char* type_str() const + { + return "nn.ConstantPad1d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); + + op->params["padding"] = constant_pad_nd->namedInput("pad"); + op->params["value"] = constant_pad_nd->namedInput("value"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ConstantPad1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ConstantPad2d.cpp b/tools/pnnx/src/pass_level1/nn_ConstantPad2d.cpp new file mode 100644 index 000000000000..f987e0bef17e --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ConstantPad2d.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ConstantPad2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.padding.ConstantPad2d"; + } + + const char* type_str() const + { + return "nn.ConstantPad2d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); + + op->params["padding"] = constant_pad_nd->namedInput("pad"); + op->params["value"] = constant_pad_nd->namedInput("value"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ConstantPad2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ConstantPad3d.cpp b/tools/pnnx/src/pass_level1/nn_ConstantPad3d.cpp new file mode 100644 index 000000000000..8abf5a92545b --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ConstantPad3d.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ConstantPad3d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.padding.ConstantPad3d"; + } + + const char* type_str() const + { + return "nn.ConstantPad3d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); + + op->params["padding"] = constant_pad_nd->namedInput("pad"); + op->params["value"] = constant_pad_nd->namedInput("value"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ConstantPad3d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Conv1d.cpp b/tools/pnnx/src/pass_level1/nn_Conv1d.cpp new file mode 100644 index 000000000000..313226740c4e --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Conv1d.cpp @@ -0,0 +1,120 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +// #include "../pass_level3/fuse_expression.h" + +#include "../utils.h" + +namespace pnnx { + +class Conv1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.conv.Conv1d"; + } + + const char* type_str() const + { + return "nn.Conv1d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // { + // pnnx::Graph pnnx_graph; + // + // pnnx_graph.load(mod, graph); + // + // pnnx::fuse_expression(pnnx_graph); + // + // pnnx_graph.save("tmp.param", "tmp.bin"); + // } + + const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); + const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode"); + const torch::jit::Node* reflection_pad1d = find_node_by_kind(graph, "aten::reflection_pad1d"); + const torch::jit::Node* replication_pad1d = find_node_by_kind(graph, "aten::replication_pad1d"); + + if (convolution_mode) + { + convolution = convolution_mode; + } + + const auto& weight = mod.attr("weight").toTensor(); + + op->params["groups"] = convolution->namedInput("groups"); + op->params["in_channels"] = weight.size(1) * op->params["groups"].i; + op->params["out_channels"] = weight.size(0); + op->params["kernel_size"] = Parameter{weight.size(2)}; + op->params["stride"] = convolution->namedInput("stride"); + if (reflection_pad1d) + { + op->params["padding_mode"] = "reflect"; + op->params["padding"] = reflection_pad1d->namedInput("padding"); + std::vector& padding = op->params["padding"].ai; + if (padding.size() == 2) + { + // Conv1d only accepts tuple of one integer + if (padding[0] == padding[1]) + { + padding.resize(1); + } + else if (padding[0] != padding[1]) + { + padding.resize(0); + op->params["padding"].s = "same"; + } + } + } + else if (replication_pad1d) + { + op->params["padding_mode"] = "replicate"; + op->params["padding"] = replication_pad1d->namedInput("padding"); + std::vector& padding = op->params["padding"].ai; + if (padding.size() == 2) + { + // Conv1d only accepts tuple of one integer + if (padding[0] == padding[1]) + { + padding.resize(1); + } + else if (padding[0] != padding[1]) + { + padding.resize(0); + op->params["padding"].s = "same"; + } + } + } + else + { + op->params["padding"] = convolution->namedInput("padding"); + } + op->params["dilation"] = convolution->namedInput("dilation"); + op->params["bias"] = mod.hasattr("bias"); + + op->attrs["weight"] = weight; + if (mod.hasattr("bias")) + { + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Conv1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Conv2d.cpp b/tools/pnnx/src/pass_level1/nn_Conv2d.cpp new file mode 100644 index 000000000000..82dfe24df22c --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Conv2d.cpp @@ -0,0 +1,120 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../pass_level3/fuse_expression.h" + +#include "../utils.h" + +namespace pnnx { + +class Conv2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.conv.Conv2d"; + } + + const char* type_str() const + { + return "nn.Conv2d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // { + // pnnx::Graph pnnx_graph; + // + // pnnx_graph.load(mod, graph); + // + // pnnx::fuse_expression(pnnx_graph); + // + // pnnx_graph.save("tmp.param", "tmp.bin"); + // } + + const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); + const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode"); + const torch::jit::Node* reflection_pad2d = find_node_by_kind(graph, "aten::reflection_pad2d"); + const torch::jit::Node* replication_pad2d = find_node_by_kind(graph, "aten::replication_pad2d"); + + if (convolution_mode) + { + convolution = convolution_mode; + } + + const auto& weight = mod.attr("weight").toTensor(); + + op->params["groups"] = convolution->namedInput("groups"); + op->params["in_channels"] = weight.size(1) * op->params["groups"].i; + op->params["out_channels"] = weight.size(0); + op->params["kernel_size"] = Parameter{weight.size(2), weight.size(3)}; + op->params["stride"] = convolution->namedInput("stride"); + if (reflection_pad2d) + { + op->params["padding_mode"] = "reflect"; + op->params["padding"] = reflection_pad2d->namedInput("padding"); + std::vector& padding = op->params["padding"].ai; + if (padding.size() == 4) + { + // Conv2d only accepts tuple of two integers + if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3]) + { + padding.resize(2); + } + else if (padding[0] == padding[2] && padding[1] == padding[3] && padding[0] != padding[1]) + { + padding.resize(0); + op->params["padding"].s = "same"; + } + } + } + else if (replication_pad2d) + { + op->params["padding_mode"] = "replicate"; + op->params["padding"] = replication_pad2d->namedInput("padding"); + std::vector& padding = op->params["padding"].ai; + if (padding.size() == 4) + { + // Conv2d only accepts tuple of two integers + if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3]) + { + padding.resize(2); + } + else if (padding[0] == padding[2] && padding[1] == padding[3] && padding[0] != padding[1]) + { + padding.resize(0); + op->params["padding"].s = "same"; + } + } + } + else + { + op->params["padding"] = convolution->namedInput("padding"); + } + op->params["dilation"] = convolution->namedInput("dilation"); + op->params["bias"] = mod.hasattr("bias"); + + op->attrs["weight"] = weight; + if (mod.hasattr("bias")) + { + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Conv2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Conv3d.cpp b/tools/pnnx/src/pass_level1/nn_Conv3d.cpp new file mode 100644 index 000000000000..ca48f9bdf7bf --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Conv3d.cpp @@ -0,0 +1,120 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../pass_level3/fuse_expression.h" + +#include "../utils.h" + +namespace pnnx { + +class Conv3d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.conv.Conv3d"; + } + + const char* type_str() const + { + return "nn.Conv3d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // { + // pnnx::Graph pnnx_graph; + // + // pnnx_graph.load(mod, graph); + // + // pnnx::fuse_expression(pnnx_graph); + // + // pnnx_graph.save("tmp.param", "tmp.bin"); + // } + + const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); + const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode"); + // const torch::jit::Node* reflection_pad3d = find_node_by_kind(graph, "aten::reflection_pad3d"); + // const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d"); + + if (convolution_mode) + { + convolution = convolution_mode; + } + + const auto& weight = mod.attr("weight").toTensor(); + + op->params["groups"] = convolution->namedInput("groups"); + op->params["in_channels"] = weight.size(1) * op->params["groups"].i; + op->params["out_channels"] = weight.size(0); + op->params["kernel_size"] = Parameter{weight.size(2), weight.size(3), weight.size(4)}; + op->params["stride"] = convolution->namedInput("stride"); + // if (reflection_pad3d) + // { + // op->params["padding_mode"] = "reflect"; + // op->params["padding"] = reflection_pad3d->namedInput("padding"); + // std::vector& padding = op->params["padding"].ai; + // if (padding.size() == 6) + // { + // // Conv3d only accepts tuple of three integers + // if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5]) + // { + // padding.resize(3); + // } + // else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2]) + // { + // padding.resize(0); + // op->params["padding"].s = "same"; + // } + // } + // } + // else if (replication_pad3d) + // { + // op->params["padding_mode"] = "replicate"; + // op->params["padding"] = replication_pad3d->namedInput("padding"); + // std::vector& padding = op->params["padding"].ai; + // if (padding.size() == 6) + // { + // // Conv3d only accepts tuple of three integers + // if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5]) + // { + // padding.resize(3); + // } + // else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2]) + // { + // padding.resize(0); + // op->params["padding"].s = "same"; + // } + // } + // } + // else + { + op->params["padding"] = convolution->namedInput("padding"); + } + op->params["dilation"] = convolution->namedInput("dilation"); + op->params["bias"] = mod.hasattr("bias"); + + op->attrs["weight"] = weight; + if (mod.hasattr("bias")) + { + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Conv3d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp new file mode 100644 index 000000000000..c6f2ce9b430a --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp @@ -0,0 +1,60 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ConvTranspose1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.conv.ConvTranspose1d"; + } + + const char* type_str() const + { + return "nn.ConvTranspose1d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); + + const auto& weight = mod.attr("weight").toTensor(); + + op->params["groups"] = convolution->namedInput("groups"); + op->params["in_channels"] = weight.size(0); + op->params["out_channels"] = weight.size(1) * op->params["groups"].i; + op->params["kernel_size"] = Parameter{weight.size(2)}; + op->params["stride"] = convolution->namedInput("stride"); + op->params["padding"] = convolution->namedInput("padding"); + op->params["output_padding"] = convolution->namedInput("output_padding"); + op->params["dilation"] = convolution->namedInput("dilation"); + op->params["bias"] = mod.hasattr("bias"); + + op->attrs["weight"] = weight; + if (mod.hasattr("bias")) + { + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ConvTranspose1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp new file mode 100644 index 000000000000..32b55f5d30a3 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp @@ -0,0 +1,60 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ConvTranspose2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.conv.ConvTranspose2d"; + } + + const char* type_str() const + { + return "nn.ConvTranspose2d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); + + const auto& weight = mod.attr("weight").toTensor(); + + op->params["groups"] = convolution->namedInput("groups"); + op->params["in_channels"] = weight.size(0); + op->params["out_channels"] = weight.size(1) * op->params["groups"].i; + op->params["kernel_size"] = Parameter{weight.size(2), weight.size(3)}; + op->params["stride"] = convolution->namedInput("stride"); + op->params["padding"] = convolution->namedInput("padding"); + op->params["output_padding"] = convolution->namedInput("output_padding"); + op->params["dilation"] = convolution->namedInput("dilation"); + op->params["bias"] = mod.hasattr("bias"); + + op->attrs["weight"] = weight; + if (mod.hasattr("bias")) + { + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ConvTranspose2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp new file mode 100644 index 000000000000..1f414efadcf8 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp @@ -0,0 +1,60 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ConvTranspose3d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.conv.ConvTranspose3d"; + } + + const char* type_str() const + { + return "nn.ConvTranspose3d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); + + const auto& weight = mod.attr("weight").toTensor(); + + op->params["groups"] = convolution->namedInput("groups"); + op->params["in_channels"] = weight.size(0); + op->params["out_channels"] = weight.size(1) * op->params["groups"].i; + op->params["kernel_size"] = Parameter{weight.size(2), weight.size(3), weight.size(4)}; + op->params["stride"] = convolution->namedInput("stride"); + op->params["padding"] = convolution->namedInput("padding"); + op->params["output_padding"] = convolution->namedInput("output_padding"); + op->params["dilation"] = convolution->namedInput("dilation"); + op->params["bias"] = mod.hasattr("bias"); + + op->attrs["weight"] = weight; + if (mod.hasattr("bias")) + { + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ConvTranspose3d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Dropout.cpp b/tools/pnnx/src/pass_level1/nn_Dropout.cpp new file mode 100644 index 000000000000..1c18dd4530de --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Dropout.cpp @@ -0,0 +1,37 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Dropout : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.dropout.Dropout"; + } + + const char* type_str() const + { + return "nn.Dropout"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Dropout) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ELU.cpp b/tools/pnnx/src/pass_level1/nn_ELU.cpp new file mode 100644 index 000000000000..a5b309ee3921 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ELU.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ELU : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.ELU"; + } + + const char* type_str() const + { + return "nn.ELU"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* elu = find_node_by_kind(graph, "aten::elu"); + + op->params["alpha"] = elu->namedInput("alpha"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ELU) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Embedding.cpp b/tools/pnnx/src/pass_level1/nn_Embedding.cpp new file mode 100644 index 000000000000..8e1c76befc55 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Embedding.cpp @@ -0,0 +1,53 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Embedding : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.sparse.Embedding"; + } + + const char* type_str() const + { + return "nn.Embedding"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + const torch::jit::Node* embedding = find_node_by_kind(graph, "aten::embedding"); + + const auto& weight = mod.attr("weight").toTensor(); + + op->params["num_embeddings"] = weight.size(0); + op->params["embedding_dim"] = weight.size(1); + + // op->params["padding_idx"] = embedding->namedInput("padding_idx"); + // op->params["scale_grad_by_freq"] = embedding->namedInput("scale_grad_by_freq"); + op->params["sparse"] = embedding->namedInput("sparse"); + + op->attrs["weight"] = weight; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Embedding) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_GELU.cpp b/tools/pnnx/src/pass_level1/nn_GELU.cpp new file mode 100644 index 000000000000..6127ae23af11 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_GELU.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +namespace pnnx { + +class GELU : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.GELU"; + } + + const char* type_str() const + { + return "nn.GELU"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(GELU) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_GRU.cpp b/tools/pnnx/src/pass_level1/nn_GRU.cpp new file mode 100644 index 000000000000..8ed972c84e4a --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_GRU.cpp @@ -0,0 +1,110 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class GRU : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.rnn.GRU"; + } + + const char* type_str() const + { + return "nn.GRU"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // mod.dump(true, true, true); + + // graph->dump(); + + const torch::jit::Node* gru = find_node_by_kind(graph, "aten::gru"); + + const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct"); + if (return_tuple && return_tuple->inputs().size() == 2 && gru->outputs().size() == 2 + && return_tuple->inputs()[0] == gru->outputs()[1] && return_tuple->inputs()[1] == gru->outputs()[0]) + { + // mark the swapped output tuple + // we would restore the fine order in pass_level3/fuse_rnn_unpack + fprintf(stderr, "swapped detected !\n"); + op->params["pnnx_rnn_output_swapped"] = 1; + } + + // for (auto aa : gru->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + + const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); + + op->params["input_size"] = weight_ih_l0.size(1); + op->params["hidden_size"] = weight_ih_l0.size(0) / 3; + op->params["num_layers"] = gru->namedInput("num_layers"); + op->params["bias"] = gru->namedInput("has_biases"); + op->params["batch_first"] = gru->namedInput("batch_first"); + op->params["bidirectional"] = gru->namedInput("bidirectional"); + + const int num_layers = op->params["num_layers"].i; + const bool bias = op->params["bias"].b; + const bool bidirectional = op->params["bidirectional"].b; + + for (int k = 0; k < num_layers; k++) + { + std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k); + std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k); + + op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor(); + op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor(); + + if (bias) + { + std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k); + std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k); + + op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor(); + op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); + } + + if (bidirectional) + { + std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; + std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse"; + + op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor(); + op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor(); + + if (bias) + { + std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse"; + std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse"; + + op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); + op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); + } + } + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(GRU) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_GroupNorm.cpp b/tools/pnnx/src/pass_level1/nn_GroupNorm.cpp new file mode 100644 index 000000000000..40222bdfcacf --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_GroupNorm.cpp @@ -0,0 +1,67 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class GroupNorm : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.normalization.GroupNorm"; + } + + const char* type_str() const + { + return "nn.GroupNorm"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // graph->dump(); + + const torch::jit::Node* gn = find_node_by_kind(graph, "aten::group_norm"); + + // for (auto aa : gn->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + + op->params["num_groups"] = gn->namedInput("num_groups"); + op->params["eps"] = gn->namedInput("eps"); + op->params["affine"] = mod.hasattr("weight") && mod.hasattr("bias"); + + if (mod.hasattr("weight") && mod.hasattr("bias")) + { + const auto& weight = mod.attr("weight").toTensor(); + + op->params["num_channels"] = weight.size(0); + + op->attrs["weight"] = weight; + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + else + { + fprintf(stderr, "Cannot resolve GroupNorm num_channels when affint=False\n"); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(GroupNorm) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Hardshrink.cpp b/tools/pnnx/src/pass_level1/nn_Hardshrink.cpp new file mode 100644 index 000000000000..ee230a9fac2e --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Hardshrink.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Hardshrink : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Hardshrink"; + } + + const char* type_str() const + { + return "nn.Hardshrink"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* hardshrink = find_node_by_kind(graph, "aten::hardshrink"); + + op->params["lambd"] = hardshrink->namedInput("lambd"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Hardshrink) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Hardsigmoid.cpp b/tools/pnnx/src/pass_level1/nn_Hardsigmoid.cpp new file mode 100644 index 000000000000..ba6fe78e925f --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Hardsigmoid.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +namespace pnnx { + +class Hardsigmoid : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Hardsigmoid"; + } + + const char* type_str() const + { + return "nn.Hardsigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Hardsigmoid) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Hardswish.cpp b/tools/pnnx/src/pass_level1/nn_Hardswish.cpp new file mode 100644 index 000000000000..1061832a10a7 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Hardswish.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +namespace pnnx { + +class Hardswish : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Hardswish"; + } + + const char* type_str() const + { + return "nn.Hardswish"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Hardswish) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Hardtanh.cpp b/tools/pnnx/src/pass_level1/nn_Hardtanh.cpp new file mode 100644 index 000000000000..1c4ee37ab2e3 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Hardtanh.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Hardtanh : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Hardtanh"; + } + + const char* type_str() const + { + return "nn.Hardtanh"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* hardtanh = find_node_by_kind(graph, "aten::hardtanh"); + + op->params["min_val"] = hardtanh->namedInput("min_val"); + op->params["max_val"] = hardtanh->namedInput("max_val"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Hardtanh) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_InstanceNorm1d.cpp b/tools/pnnx/src/pass_level1/nn_InstanceNorm1d.cpp new file mode 100644 index 000000000000..8d12739d0fa6 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_InstanceNorm1d.cpp @@ -0,0 +1,73 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class InstanceNorm1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.instancenorm.InstanceNorm1d"; + } + + const char* type_str() const + { + return "nn.InstanceNorm1d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // graph->dump(); + + const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm"); + + // for (auto aa : in->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + + op->params["eps"] = in->namedInput("eps"); + op->params["affine"] = mod.hasattr("weight") && mod.hasattr("bias"); + op->params["track_running_stats"] = mod.hasattr("running_mean") && mod.hasattr("running_var"); + + if (mod.hasattr("weight") && mod.hasattr("bias")) + { + const auto& weight = mod.attr("weight").toTensor(); + + op->params["num_features"] = weight.size(0); + + op->attrs["weight"] = weight; + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + + if (mod.hasattr("running_mean") && mod.hasattr("running_var")) + { + const auto& running_mean = mod.attr("running_mean").toTensor(); + + op->params["num_features"] = running_mean.size(0); + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = mod.attr("running_var").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(InstanceNorm1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_InstanceNorm2d.cpp b/tools/pnnx/src/pass_level1/nn_InstanceNorm2d.cpp new file mode 100644 index 000000000000..b4a4f4e2b5ea --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_InstanceNorm2d.cpp @@ -0,0 +1,73 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class InstanceNorm2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.instancenorm.InstanceNorm2d"; + } + + const char* type_str() const + { + return "nn.InstanceNorm2d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // graph->dump(); + + const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm"); + + // for (auto aa : in->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + + op->params["eps"] = in->namedInput("eps"); + op->params["affine"] = mod.hasattr("weight") && mod.hasattr("bias"); + op->params["track_running_stats"] = mod.hasattr("running_mean") && mod.hasattr("running_var"); + + if (mod.hasattr("weight") && mod.hasattr("bias")) + { + const auto& weight = mod.attr("weight").toTensor(); + + op->params["num_features"] = weight.size(0); + + op->attrs["weight"] = weight; + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + + if (mod.hasattr("running_mean") && mod.hasattr("running_var")) + { + const auto& running_mean = mod.attr("running_mean").toTensor(); + + op->params["num_features"] = running_mean.size(0); + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = mod.attr("running_var").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(InstanceNorm2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_InstanceNorm3d.cpp b/tools/pnnx/src/pass_level1/nn_InstanceNorm3d.cpp new file mode 100644 index 000000000000..9906ffa3527e --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_InstanceNorm3d.cpp @@ -0,0 +1,73 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class InstanceNorm3d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.instancenorm.InstanceNorm3d"; + } + + const char* type_str() const + { + return "nn.InstanceNorm3d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // graph->dump(); + + const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm"); + + // for (auto aa : in->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + + op->params["eps"] = in->namedInput("eps"); + op->params["affine"] = mod.hasattr("weight") && mod.hasattr("bias"); + op->params["track_running_stats"] = mod.hasattr("running_mean") && mod.hasattr("running_var"); + + if (mod.hasattr("weight") && mod.hasattr("bias")) + { + const auto& weight = mod.attr("weight").toTensor(); + + op->params["num_features"] = weight.size(0); + + op->attrs["weight"] = weight; + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + + if (mod.hasattr("running_mean") && mod.hasattr("running_var")) + { + const auto& running_mean = mod.attr("running_mean").toTensor(); + + op->params["num_features"] = running_mean.size(0); + + op->attrs["running_mean"] = running_mean; + op->attrs["running_var"] = mod.attr("running_var").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(InstanceNorm3d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_LPPool1d.cpp b/tools/pnnx/src/pass_level1/nn_LPPool1d.cpp new file mode 100644 index 000000000000..f7b7375769ee --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_LPPool1d.cpp @@ -0,0 +1,56 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class LPPool1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.LPPool1d"; + } + + const char* type_str() const + { + return "nn.LPPool1d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow"); + op->params["norm_type"] = pow->inputs()[1]; + + const torch::jit::Node* avg_pool1d = find_node_by_kind(graph, "aten::avg_pool1d"); + + op->params["kernel_size"] = avg_pool1d->namedInput("kernel_size")->node()->inputs()[0]; + if (avg_pool1d->namedInput("stride")->node()->inputs().size() == 0) + { + op->params["stride"] = op->params["kernel_size"]; + } + else + { + op->params["stride"] = avg_pool1d->namedInput("stride")->node()->inputs()[0]; + } + op->params["ceil_mode"] = avg_pool1d->namedInput("ceil_mode"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(LPPool1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_LPPool2d.cpp b/tools/pnnx/src/pass_level1/nn_LPPool2d.cpp new file mode 100644 index 000000000000..d843704f781a --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_LPPool2d.cpp @@ -0,0 +1,56 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class LPPool2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.LPPool2d"; + } + + const char* type_str() const + { + return "nn.LPPool2d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow"); + op->params["norm_type"] = pow->inputs()[1]; + + const torch::jit::Node* avg_pool2d = find_node_by_kind(graph, "aten::avg_pool2d"); + + op->params["kernel_size"] = avg_pool2d->namedInput("kernel_size"); + if (avg_pool2d->namedInput("stride")->node()->inputs().size() == 0) + { + op->params["stride"] = op->params["kernel_size"]; + } + else + { + op->params["stride"] = avg_pool2d->namedInput("stride"); + } + op->params["ceil_mode"] = avg_pool2d->namedInput("ceil_mode"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(LPPool2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_LSTM.cpp b/tools/pnnx/src/pass_level1/nn_LSTM.cpp new file mode 100644 index 000000000000..ba82232ce6c6 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_LSTM.cpp @@ -0,0 +1,110 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class LSTM : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.rnn.LSTM"; + } + + const char* type_str() const + { + return "nn.LSTM"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // mod.dump(true, true, true); + + // graph->dump(); + + const torch::jit::Node* lstm = find_node_by_kind(graph, "aten::lstm"); + + const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct"); + if (return_tuple && return_tuple->inputs().size() == 3 && lstm->outputs().size() == 3 + && return_tuple->inputs()[0] == lstm->outputs()[1] && return_tuple->inputs()[1] == lstm->outputs()[2] && return_tuple->inputs()[2] == lstm->outputs()[0]) + { + // mark the swapped output tuple + // we would restore the fine order in pass_level3/fuse_rnn_unpack + fprintf(stderr, "swapped detected !\n"); + op->params["pnnx_rnn_output_swapped"] = 1; + } + + // for (auto aa : lstm->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + + const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); + + op->params["input_size"] = weight_ih_l0.size(1); + op->params["hidden_size"] = weight_ih_l0.size(0) / 4; + op->params["num_layers"] = lstm->namedInput("num_layers"); + op->params["bias"] = lstm->namedInput("has_biases"); + op->params["batch_first"] = lstm->namedInput("batch_first"); + op->params["bidirectional"] = lstm->namedInput("bidirectional"); + + const int num_layers = op->params["num_layers"].i; + const bool bias = op->params["bias"].b; + const bool bidirectional = op->params["bidirectional"].b; + + for (int k = 0; k < num_layers; k++) + { + std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k); + std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k); + + op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor(); + op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor(); + + if (bias) + { + std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k); + std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k); + + op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor(); + op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); + } + + if (bidirectional) + { + std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; + std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse"; + + op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor(); + op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor(); + + if (bias) + { + std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse"; + std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse"; + + op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); + op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); + } + } + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(LSTM) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_LayerNorm.cpp b/tools/pnnx/src/pass_level1/nn_LayerNorm.cpp new file mode 100644 index 000000000000..5faa8d8c02f5 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_LayerNorm.cpp @@ -0,0 +1,52 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class LayerNorm : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.normalization.LayerNorm"; + } + + const char* type_str() const + { + return "nn.LayerNorm"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + const torch::jit::Node* ln = find_node_by_kind(graph, "aten::layer_norm"); + + op->params["normalized_shape"] = ln->namedInput("normalized_shape"); + op->params["eps"] = ln->namedInput("eps"); + op->params["elementwise_affine"] = mod.hasattr("weight") && mod.hasattr("bias"); + + if (mod.hasattr("weight") && mod.hasattr("bias")) + { + op->attrs["weight"] = mod.attr("weight").toTensor(); + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(LayerNorm) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_LeakyReLU.cpp b/tools/pnnx/src/pass_level1/nn_LeakyReLU.cpp new file mode 100644 index 000000000000..689f1c66540a --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_LeakyReLU.cpp @@ -0,0 +1,50 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class LeakyReLU : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.LeakyReLU"; + } + + const char* type_str() const + { + return "nn.LeakyReLU"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* leaky_relu = find_node_by_kind(graph, "aten::leaky_relu"); + const torch::jit::Node* leaky_relu_ = find_node_by_kind(graph, "aten::leaky_relu_"); + + if (leaky_relu_) + { + leaky_relu = leaky_relu_; + } + + op->params["negative_slope"] = leaky_relu->namedInput("negative_slope"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(LeakyReLU) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Linear.cpp b/tools/pnnx/src/pass_level1/nn_Linear.cpp new file mode 100644 index 000000000000..2b03b60c3b01 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Linear.cpp @@ -0,0 +1,54 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Linear : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.linear.Linear"; + } + + const char* type_str() const + { + return "nn.Linear"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + const torch::jit::Node* addmm = find_node_by_kind(graph, "aten::addmm"); + + const auto& weight = mod.attr("weight").toTensor(); + + op->params["in_features"] = weight.size(1); + op->params["out_features"] = weight.size(0); + op->params["bias"] = mod.hasattr("bias"); + + op->attrs["weight"] = weight; + if (mod.hasattr("bias")) + { + op->attrs["bias"] = mod.attr("bias").toTensor(); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Linear) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_LocalResponseNorm.cpp b/tools/pnnx/src/pass_level1/nn_LocalResponseNorm.cpp new file mode 100644 index 000000000000..1f88f78f0358 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_LocalResponseNorm.cpp @@ -0,0 +1,59 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class LocalResponseNorm : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.normalization.LocalResponseNorm"; + } + + const char* type_str() const + { + return "nn.LocalResponseNorm"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* avg_pool = find_node_by_kind(graph, "aten::avg_pool2d"); + const torch::jit::Node* avg_pool3d = find_node_by_kind(graph, "aten::avg_pool3d"); + + if (avg_pool3d) + { + avg_pool = avg_pool3d; + } + + op->params["size"] = avg_pool->namedInput("kernel_size")->node()->inputs()[0]; + + const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow"); + op->params["beta"] = pow->inputs()[1]; + + const torch::jit::Node* add = pow->inputs()[0]->node(); + op->params["k"] = add->inputs()[1]; + + const torch::jit::Node* mul = add->inputs()[0]->node(); + op->params["alpha"] = mul->inputs()[1]; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(LocalResponseNorm) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_LogSigmoid.cpp b/tools/pnnx/src/pass_level1/nn_LogSigmoid.cpp new file mode 100644 index 000000000000..d0a1646048f0 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_LogSigmoid.cpp @@ -0,0 +1,37 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class LogSigmoid : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.LogSigmoid"; + } + + const char* type_str() const + { + return "nn.LogSigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(LogSigmoid) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_LogSoftmax.cpp b/tools/pnnx/src/pass_level1/nn_LogSoftmax.cpp new file mode 100644 index 000000000000..e5ecd673cc46 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_LogSoftmax.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class LogSoftmax : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.LogSoftmax"; + } + + const char* type_str() const + { + return "nn.LogSoftmax"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* log_softmax = find_node_by_kind(graph, "aten::log_softmax"); + + op->params["dim"] = log_softmax->namedInput("dim"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(LogSoftmax) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_MaxPool1d.cpp b/tools/pnnx/src/pass_level1/nn_MaxPool1d.cpp new file mode 100644 index 000000000000..5f42ca5b72bb --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_MaxPool1d.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class MaxPool1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.MaxPool1d"; + } + + const char* type_str() const + { + return "nn.MaxPool1d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* max_pool1d = find_node_by_kind(graph, "aten::max_pool1d"); + const torch::jit::Node* max_pool1d_with_indices = find_node_by_kind(graph, "aten::max_pool1d_with_indices"); + + if (max_pool1d_with_indices) + { + max_pool1d = max_pool1d_with_indices; + } + + op->params["kernel_size"] = max_pool1d->namedInput("kernel_size"); + op->params["stride"] = max_pool1d->namedInput("stride"); + op->params["padding"] = max_pool1d->namedInput("padding"); + op->params["dilation"] = max_pool1d->namedInput("dilation"); + op->params["ceil_mode"] = max_pool1d->namedInput("ceil_mode"); + op->params["return_indices"] = max_pool1d_with_indices ? true : false; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(MaxPool1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_MaxPool2d.cpp b/tools/pnnx/src/pass_level1/nn_MaxPool2d.cpp new file mode 100644 index 000000000000..8a806441552b --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_MaxPool2d.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class MaxPool2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.MaxPool2d"; + } + + const char* type_str() const + { + return "nn.MaxPool2d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* max_pool2d = find_node_by_kind(graph, "aten::max_pool2d"); + const torch::jit::Node* max_pool2d_with_indices = find_node_by_kind(graph, "aten::max_pool2d_with_indices"); + + if (max_pool2d_with_indices) + { + max_pool2d = max_pool2d_with_indices; + } + + op->params["kernel_size"] = max_pool2d->namedInput("kernel_size"); + op->params["stride"] = max_pool2d->namedInput("stride"); + op->params["padding"] = max_pool2d->namedInput("padding"); + op->params["dilation"] = max_pool2d->namedInput("dilation"); + op->params["ceil_mode"] = max_pool2d->namedInput("ceil_mode"); + op->params["return_indices"] = max_pool2d_with_indices ? true : false; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(MaxPool2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_MaxPool3d.cpp b/tools/pnnx/src/pass_level1/nn_MaxPool3d.cpp new file mode 100644 index 000000000000..e53e4a7fd616 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_MaxPool3d.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class MaxPool3d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.MaxPool3d"; + } + + const char* type_str() const + { + return "nn.MaxPool3d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* max_pool3d = find_node_by_kind(graph, "aten::max_pool3d"); + const torch::jit::Node* max_pool3d_with_indices = find_node_by_kind(graph, "aten::max_pool3d_with_indices"); + + if (max_pool3d_with_indices) + { + max_pool3d = max_pool3d_with_indices; + } + + op->params["kernel_size"] = max_pool3d->namedInput("kernel_size"); + op->params["stride"] = max_pool3d->namedInput("stride"); + op->params["padding"] = max_pool3d->namedInput("padding"); + op->params["dilation"] = max_pool3d->namedInput("dilation"); + op->params["ceil_mode"] = max_pool3d->namedInput("ceil_mode"); + op->params["return_indices"] = max_pool3d_with_indices ? true : false; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(MaxPool3d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Mish.cpp b/tools/pnnx/src/pass_level1/nn_Mish.cpp new file mode 100644 index 000000000000..c65fb6a896ea --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Mish.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +namespace pnnx { + +class Mish : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Mish"; + } + + const char* type_str() const + { + return "nn.Mish"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Mish) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_MultiheadAttention.cpp b/tools/pnnx/src/pass_level1/nn_MultiheadAttention.cpp new file mode 100644 index 000000000000..5a54ac442db5 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_MultiheadAttention.cpp @@ -0,0 +1,126 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include + +#include "../utils.h" + +namespace pnnx { + +class MultiheadAttention : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.MultiheadAttention"; + } + + const char* type_str() const + { + return "nn.MultiheadAttention"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // mod.dump(false, false, false); + + // graph->dump(); + + const torch::jit::Node* div_num_heads = find_node_by_kind(graph, "aten::div"); + const torch::jit::Node* div_num_heads_18 = find_node_by_kind(graph, "aten::floor_divide"); + if (div_num_heads_18) + { + div_num_heads = div_num_heads_18; + } + + op->params["num_heads"] = (int)div_num_heads->input(1)->node()->t(torch::jit::attr::value).item(); + + const torch::jit::Node* transpose_batch_seq = find_node_by_kind(graph, "aten::transpose"); + + int transpose_dim0 = transpose_batch_seq->input(1)->node()->i(torch::jit::attr::value); + int transpose_dim1 = transpose_batch_seq->input(2)->node()->i(torch::jit::attr::value); + if (transpose_dim0 == 1 && transpose_dim1 == 0) + { + op->params["batch_first"] = true; + } +#if TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 9 + else + { + op->params["batch_first"] = false; + } +#endif + + const torch::jit::Node* add_zero_attn = find_node_by_kind(graph, "aten::zeros"); + if (add_zero_attn) + { + op->params["add_zero_attn"] = true; + } + else + { + op->params["add_zero_attn"] = false; + } + + const auto& in_proj_weight = mod.attr("in_proj_weight").toTensor(); + const auto& out_proj_weight = mod.attr("out_proj").toModule().attr("weight").toTensor(); + + op->params["embed_dim"] = in_proj_weight.size(1); + op->attrs["in_proj_weight"] = in_proj_weight; + op->attrs["out_proj.weight"] = out_proj_weight; + + if (mod.hasattr("in_proj_bias") && mod.attr("out_proj").toModule().hasattr("bias")) + { + // bias=True + const auto& in_proj_bias = mod.attr("in_proj_bias").toTensor(); + const auto& out_proj_bias = mod.attr("out_proj").toModule().attr("bias").toTensor(); + + op->params["bias"] = true; + op->attrs["in_proj_bias"] = in_proj_bias; + op->attrs["out_proj.bias"] = out_proj_bias; + } + else + { + op->params["bias"] = false; +#if TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 8 + // the output projection bias always there no matter bias is False in pytorch 1.8 + // this behavior changes since https://github.com/pytorch/pytorch/commit/58d1b3639bc07f9519de18e5a18e575f260c7eeb + if (mod.attr("out_proj").toModule().hasattr("bias")) + { + const auto& out_proj_bias = mod.attr("out_proj").toModule().attr("bias").toTensor(); + op->attrs["out_proj.bias"] = out_proj_bias; + } +#endif + } + + if (mod.hasattr("bias_k") && mod.hasattr("bias_v")) + { + // add_bias_kv=True + const auto& bias_k = mod.attr("bias_k").toTensor(); + const auto& bias_v = mod.attr("bias_v").toTensor(); + + op->params["add_bias_kv"] = true; + op->attrs["bias_k"] = bias_k; + op->attrs["bias_v"] = bias_v; + } + else + { + op->params["add_bias_kv"] = false; + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(MultiheadAttention) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_PReLU.cpp b/tools/pnnx/src/pass_level1/nn_PReLU.cpp new file mode 100644 index 000000000000..52b3f249760a --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_PReLU.cpp @@ -0,0 +1,46 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class PReLU : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.PReLU"; + } + + const char* type_str() const + { + return "nn.PReLU"; + } + + void write(Operator* op, const std::shared_ptr& /*graph*/, const torch::jit::Module& mod) const + { + const auto& weight = mod.attr("weight").toTensor(); + + op->params["num_parameters"] = weight.size(0); + + op->attrs["weight"] = weight; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(PReLU) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_PixelShuffle.cpp b/tools/pnnx/src/pass_level1/nn_PixelShuffle.cpp new file mode 100644 index 000000000000..f9c1bbc6b65e --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_PixelShuffle.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class PixelShuffle : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pixelshuffle.PixelShuffle"; + } + + const char* type_str() const + { + return "nn.PixelShuffle"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* pixel_shuffle = find_node_by_kind(graph, "aten::pixel_shuffle"); + + op->params["upscale_factor"] = pixel_shuffle->namedInput("upscale_factor"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(PixelShuffle) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_PixelUnshuffle.cpp b/tools/pnnx/src/pass_level1/nn_PixelUnshuffle.cpp new file mode 100644 index 000000000000..73154db9a5e8 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_PixelUnshuffle.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class PixelUnshuffle : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pixelshuffle.PixelUnshuffle"; + } + + const char* type_str() const + { + return "nn.PixelUnshuffle"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* pixel_unshuffle = find_node_by_kind(graph, "aten::pixel_unshuffle"); + + op->params["downscale_factor"] = pixel_unshuffle->namedInput("downscale_factor"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(PixelUnshuffle) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_RNN.cpp b/tools/pnnx/src/pass_level1/nn_RNN.cpp new file mode 100644 index 000000000000..ab64839297da --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_RNN.cpp @@ -0,0 +1,117 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class RNN : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.rnn.RNN"; + } + + const char* type_str() const + { + return "nn.RNN"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // mod.dump(true, true, true); + + // graph->dump(); + + const torch::jit::Node* rnn = find_node_by_kind(graph, "aten::rnn_tanh"); + const torch::jit::Node* rnn_relu = find_node_by_kind(graph, "aten::rnn_relu"); + + if (rnn_relu) + { + rnn = rnn_relu; + } + + const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct"); + if (return_tuple && return_tuple->inputs().size() == 2 && rnn->outputs().size() == 2 + && return_tuple->inputs()[0] == rnn->outputs()[1] && return_tuple->inputs()[1] == rnn->outputs()[0]) + { + // mark the swapped output tuple + // we would restore the fine order in pass_level3/fuse_rnn_unpack + fprintf(stderr, "swapped detected !\n"); + op->params["pnnx_rnn_output_swapped"] = 1; + } + + // for (auto aa : rnn->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + + const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); + + op->params["input_size"] = weight_ih_l0.size(1); + op->params["hidden_size"] = weight_ih_l0.size(0); + op->params["num_layers"] = rnn->namedInput("num_layers"); + op->params["nonlinearity"] = rnn_relu ? "relu" : "tanh"; + op->params["bias"] = rnn->namedInput("has_biases"); + op->params["batch_first"] = rnn->namedInput("batch_first"); + op->params["bidirectional"] = rnn->namedInput("bidirectional"); + + const int num_layers = op->params["num_layers"].i; + const bool bias = op->params["bias"].b; + const bool bidirectional = op->params["bidirectional"].b; + + for (int k = 0; k < num_layers; k++) + { + std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k); + std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k); + + op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor(); + op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor(); + + if (bias) + { + std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k); + std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k); + + op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor(); + op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); + } + + if (bidirectional) + { + std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; + std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse"; + + op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor(); + op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor(); + + if (bias) + { + std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse"; + std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse"; + + op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); + op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); + } + } + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(RNN) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_RReLU.cpp b/tools/pnnx/src/pass_level1/nn_RReLU.cpp new file mode 100644 index 000000000000..538552ad2da1 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_RReLU.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class RReLU : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.RReLU"; + } + + const char* type_str() const + { + return "nn.RReLU"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* rrelu = find_node_by_kind(graph, "aten::rrelu"); + + op->params["lower"] = rrelu->namedInput("lower"); + op->params["upper"] = rrelu->namedInput("upper"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(RReLU) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ReLU.cpp b/tools/pnnx/src/pass_level1/nn_ReLU.cpp new file mode 100644 index 000000000000..bc213c64de51 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ReLU.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +namespace pnnx { + +class ReLU : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.ReLU"; + } + + const char* type_str() const + { + return "nn.ReLU"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ReLU) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ReLU6.cpp b/tools/pnnx/src/pass_level1/nn_ReLU6.cpp new file mode 100644 index 000000000000..69bc36a6fc23 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ReLU6.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +namespace pnnx { + +class ReLU6 : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.ReLU6"; + } + + const char* type_str() const + { + return "nn.ReLU6"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ReLU6) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ReflectionPad1d.cpp b/tools/pnnx/src/pass_level1/nn_ReflectionPad1d.cpp new file mode 100644 index 000000000000..e27f58fa7f1b --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ReflectionPad1d.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ReflectionPad1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.padding.ReflectionPad1d"; + } + + const char* type_str() const + { + return "nn.ReflectionPad1d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* reflection_pad1d = find_node_by_kind(graph, "aten::reflection_pad1d"); + + op->params["padding"] = reflection_pad1d->namedInput("padding"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ReflectionPad1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ReflectionPad2d.cpp b/tools/pnnx/src/pass_level1/nn_ReflectionPad2d.cpp new file mode 100644 index 000000000000..b2398cf195de --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ReflectionPad2d.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ReflectionPad2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.padding.ReflectionPad2d"; + } + + const char* type_str() const + { + return "nn.ReflectionPad2d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* reflection_pad2d = find_node_by_kind(graph, "aten::reflection_pad2d"); + + op->params["padding"] = reflection_pad2d->namedInput("padding"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ReflectionPad2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ReplicationPad1d.cpp b/tools/pnnx/src/pass_level1/nn_ReplicationPad1d.cpp new file mode 100644 index 000000000000..2d849fbebae6 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ReplicationPad1d.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ReplicationPad1d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.padding.ReplicationPad1d"; + } + + const char* type_str() const + { + return "nn.ReplicationPad1d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* replication_pad1d = find_node_by_kind(graph, "aten::replication_pad1d"); + + op->params["padding"] = replication_pad1d->namedInput("padding"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ReplicationPad1d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ReplicationPad2d.cpp b/tools/pnnx/src/pass_level1/nn_ReplicationPad2d.cpp new file mode 100644 index 000000000000..d666d325a217 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ReplicationPad2d.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ReplicationPad2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.padding.ReplicationPad2d"; + } + + const char* type_str() const + { + return "nn.ReplicationPad2d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* replication_pad2d = find_node_by_kind(graph, "aten::replication_pad2d"); + + op->params["padding"] = replication_pad2d->namedInput("padding"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ReplicationPad2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ReplicationPad3d.cpp b/tools/pnnx/src/pass_level1/nn_ReplicationPad3d.cpp new file mode 100644 index 000000000000..c706a375d006 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ReplicationPad3d.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ReplicationPad3d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.padding.ReplicationPad3d"; + } + + const char* type_str() const + { + return "nn.ReplicationPad3d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d"); + + op->params["padding"] = replication_pad3d->namedInput("padding"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ReplicationPad3d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_SELU.cpp b/tools/pnnx/src/pass_level1/nn_SELU.cpp new file mode 100644 index 000000000000..c6ca3d472331 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_SELU.cpp @@ -0,0 +1,37 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class SELU : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.SELU"; + } + + const char* type_str() const + { + return "nn.SELU"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(SELU) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_SiLU.cpp b/tools/pnnx/src/pass_level1/nn_SiLU.cpp new file mode 100644 index 000000000000..816e9aa77669 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_SiLU.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +namespace pnnx { + +class SiLU : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.SiLU"; + } + + const char* type_str() const + { + return "nn.SiLU"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(SiLU) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Sigmoid.cpp b/tools/pnnx/src/pass_level1/nn_Sigmoid.cpp new file mode 100644 index 000000000000..f106c553c024 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Sigmoid.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +namespace pnnx { + +class Sigmoid : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Sigmoid"; + } + + const char* type_str() const + { + return "nn.Sigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Sigmoid) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Softmax.cpp b/tools/pnnx/src/pass_level1/nn_Softmax.cpp new file mode 100644 index 000000000000..f0baa0188682 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Softmax.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Softmax : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Softmax"; + } + + const char* type_str() const + { + return "nn.Softmax"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* softmax = find_node_by_kind(graph, "aten::softmax"); + + op->params["dim"] = softmax->namedInput("dim"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Softmax) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Softmin.cpp b/tools/pnnx/src/pass_level1/nn_Softmin.cpp new file mode 100644 index 000000000000..1a1bec249923 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Softmin.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Softmin : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Softmin"; + } + + const char* type_str() const + { + return "nn.Softmin"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* softmax = find_node_by_kind(graph, "aten::softmax"); + + op->params["dim"] = softmax->namedInput("dim"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Softmin) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Softplus.cpp b/tools/pnnx/src/pass_level1/nn_Softplus.cpp new file mode 100644 index 000000000000..cf470b0e76aa --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Softplus.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Softplus : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Softplus"; + } + + const char* type_str() const + { + return "nn.Softplus"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* softplus = find_node_by_kind(graph, "aten::softplus"); + + op->params["beta"] = softplus->namedInput("beta"); + op->params["threshold"] = softplus->namedInput("threshold"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Softplus) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Softshrink.cpp b/tools/pnnx/src/pass_level1/nn_Softshrink.cpp new file mode 100644 index 000000000000..4c16ef8146a8 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Softshrink.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Softshrink : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Softshrink"; + } + + const char* type_str() const + { + return "nn.Softshrink"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* softshrink = find_node_by_kind(graph, "aten::softshrink"); + + op->params["lambd"] = softshrink->namedInput("lambd"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Softshrink) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Softsign.cpp b/tools/pnnx/src/pass_level1/nn_Softsign.cpp new file mode 100644 index 000000000000..05be02954652 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Softsign.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +namespace pnnx { + +class Softsign : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Softsign"; + } + + const char* type_str() const + { + return "nn.Softsign"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Softsign) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Tanh.cpp b/tools/pnnx/src/pass_level1/nn_Tanh.cpp new file mode 100644 index 000000000000..802b91c35fdb --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Tanh.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +namespace pnnx { + +class Tanh : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Tanh"; + } + + const char* type_str() const + { + return "nn.Tanh"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Tanh) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Tanhshrink.cpp b/tools/pnnx/src/pass_level1/nn_Tanhshrink.cpp new file mode 100644 index 000000000000..78f4f4d13ee3 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Tanhshrink.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +namespace pnnx { + +class Tanhshrink : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Tanhshrink"; + } + + const char* type_str() const + { + return "nn.Tanhshrink"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Tanhshrink) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Threshold.cpp b/tools/pnnx/src/pass_level1/nn_Threshold.cpp new file mode 100644 index 000000000000..ee52f6cbd63a --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Threshold.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Threshold : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.activation.Threshold"; + } + + const char* type_str() const + { + return "nn.Threshold"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* threshold = find_node_by_kind(graph, "aten::threshold"); + + op->params["threshold"] = threshold->namedInput("threshold"); + op->params["value"] = threshold->namedInput("value"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Threshold) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_Upsample.cpp b/tools/pnnx/src/pass_level1/nn_Upsample.cpp new file mode 100644 index 000000000000..471ffb50f039 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_Upsample.cpp @@ -0,0 +1,102 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Upsample : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.upsampling.Upsample"; + } + + const char* type_str() const + { + return "nn.Upsample"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* upsample_nearest1d = find_node_by_kind(graph, "aten::upsample_nearest1d"); + const torch::jit::Node* upsample_linear1d = find_node_by_kind(graph, "aten::upsample_linear1d"); + + const torch::jit::Node* upsample_nearest2d = find_node_by_kind(graph, "aten::upsample_nearest2d"); + const torch::jit::Node* upsample_bilinear2d = find_node_by_kind(graph, "aten::upsample_bilinear2d"); + const torch::jit::Node* upsample_bicubic2d = find_node_by_kind(graph, "aten::upsample_bicubic2d"); + + const torch::jit::Node* upsample_nearest3d = find_node_by_kind(graph, "aten::upsample_nearest3d"); + const torch::jit::Node* upsample_trilinear3d = find_node_by_kind(graph, "aten::upsample_trilinear3d"); + + const torch::jit::Node* upsample = 0; + if (upsample_nearest1d) + { + upsample = upsample_nearest1d; + op->params["mode"] = "nearest"; + } + else if (upsample_linear1d) + { + upsample = upsample_linear1d; + op->params["mode"] = "linear"; + } + else if (upsample_nearest2d) + { + upsample = upsample_nearest2d; + op->params["mode"] = "nearest"; + } + else if (upsample_bilinear2d) + { + upsample = upsample_bilinear2d; + op->params["mode"] = "bilinear"; + } + else if (upsample_bicubic2d) + { + upsample = upsample_bicubic2d; + op->params["mode"] = "bicubic"; + } + else if (upsample_nearest3d) + { + upsample = upsample_nearest3d; + op->params["mode"] = "nearest"; + } + else if (upsample_trilinear3d) + { + upsample = upsample_trilinear3d; + op->params["mode"] = "trilinear"; + } + + if (upsample->hasNamedInput("output_size")) + { + op->params["size"] = upsample->namedInput("output_size"); + } + + if (upsample->hasNamedInput("scale_factors")) + { + op->params["scale_factor"] = upsample->namedInput("scale_factors"); + } + + if (upsample->hasNamedInput("align_corners")) + { + op->params["align_corners"] = upsample->namedInput("align_corners"); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Upsample) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_UpsamplingBilinear2d.cpp b/tools/pnnx/src/pass_level1/nn_UpsamplingBilinear2d.cpp new file mode 100644 index 000000000000..2b0441d44636 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_UpsamplingBilinear2d.cpp @@ -0,0 +1,52 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class UpsamplingBilinear2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.upsampling.UpsamplingBilinear2d"; + } + + const char* type_str() const + { + return "nn.UpsamplingBilinear2d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* upsample = find_node_by_kind(graph, "aten::upsample_bilinear2d"); + + if (upsample->hasNamedInput("output_size")) + { + op->params["size"] = upsample->namedInput("output_size"); + } + + if (upsample->hasNamedInput("scale_factors")) + { + op->params["scale_factor"] = upsample->namedInput("scale_factors"); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(UpsamplingBilinear2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_UpsamplingNearest2d.cpp b/tools/pnnx/src/pass_level1/nn_UpsamplingNearest2d.cpp new file mode 100644 index 000000000000..20f8e32fd4da --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_UpsamplingNearest2d.cpp @@ -0,0 +1,52 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class UpsamplingNearest2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.upsampling.UpsamplingNearest2d"; + } + + const char* type_str() const + { + return "nn.UpsamplingNearest2d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* upsample = find_node_by_kind(graph, "aten::upsample_nearest2d"); + + if (upsample->hasNamedInput("output_size")) + { + op->params["size"] = upsample->namedInput("output_size"); + } + + if (upsample->hasNamedInput("scale_factors")) + { + op->params["scale_factor"] = upsample->namedInput("scale_factors"); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(UpsamplingNearest2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_ZeroPad2d.cpp b/tools/pnnx/src/pass_level1/nn_ZeroPad2d.cpp new file mode 100644 index 000000000000..88922b398908 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_ZeroPad2d.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class ZeroPad2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.padding.ZeroPad2d"; + } + + const char* type_str() const + { + return "nn.ZeroPad2d"; + } + + void write(Operator* op, const std::shared_ptr& graph) const + { + const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); + + op->params["padding"] = constant_pad_nd->namedInput("pad"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(ZeroPad2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_maxunpool2d.cpp b/tools/pnnx/src/pass_level1/nn_maxunpool2d.cpp new file mode 100644 index 000000000000..2a067a344679 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_maxunpool2d.cpp @@ -0,0 +1,80 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../pass_level3/fuse_expression.h" + +#include "../utils.h" + +namespace pnnx { + +class MaxUnpool2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.modules.pooling.MaxUnpool2d"; + } + + const char* type_str() const + { + return "nn.MaxUnpool2d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + graph->dump(); + + { + Graph pnnx_graph; + + pass_level1(mod, graph, pnnx_graph); + + fuse_expression(pnnx_graph); + + Operator* expr_op = pnnx_graph.ops[2]; + + if (expr_op->type == "pnnx.Expression") + { + std::string expr = expr_op->params["expr"].s; + + int stride0; + int stride1; + int kernel_size0; + int kernel_size1; + int padding0; + int padding1; + int nscan = sscanf(expr.c_str(), "(int(sub(add(mul(sub(size(@0,2),1),%d),%d),%d)),int(sub(add(mul(sub(size(@1,3),1),%d),%d),%d)))", &stride0, &kernel_size0, &padding0, &stride1, &kernel_size1, &padding1); + if (nscan == 6) + { + op->params["kernel_size"] = Parameter{kernel_size0, kernel_size1}; + op->params["stride"] = Parameter{stride0, stride1}; + op->params["padding"] = Parameter{padding0 / 2, padding1 / 2}; + } + } + } + + const torch::jit::Node* max_unpool2d = find_node_by_kind(graph, "aten::max_unpool2d"); + + for (auto aa : max_unpool2d->schema().arguments()) + { + fprintf(stderr, "arg %s\n", aa.name().c_str()); + } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(MaxUnpool2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_quantized_Conv2d.cpp b/tools/pnnx/src/pass_level1/nn_quantized_Conv2d.cpp new file mode 100644 index 000000000000..f57dfcd458c4 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_quantized_Conv2d.cpp @@ -0,0 +1,185 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class QuantizedConv2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.quantized.modules.conv.Conv2d"; + } + + const char* type_str() const + { + return "nn.quantized.Conv2d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // graph->dump(); + + const torch::jit::Node* quantized_convolution = find_node_by_kind(graph, "quantized::conv2d"); + + // for (auto aa : quantized_convolution->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + + // torch::jit::Node* packed_params_node = 0; + // for (const auto& n : graph->nodes()) + // { + // if (n->kind() == c10::prim::GetAttr && n->s(torch::jit::attr::name) == "_packed_params") + // { + // packed_params_node = n; + // break; + // } + // } + + // quantized_convolution->namedInput("output_scale"); + + const auto& packed_params = mod.attr("_packed_params").toObject(); + + // auto x = torch::jit::script::Object(packed_params).run_method("__getstate__"); + auto x = torch::jit::script::Object(packed_params).run_method("unpack").toTuple(); + // std::cout << x->elements()[0].toTensor() << std::endl; + // std::cout << x->elements()[0].toTensor().quantizer() << std::endl; + // std::cout << x->elements()[1] << std::endl; + // at::Tensor dequantize() const; + // double q_scale() const; + // int64_t q_zero_point() const; + // at::Tensor q_per_channel_scales() const; + // at::Tensor q_per_channel_zero_points() const; + // int64_t q_per_channel_axis() const; + + // auto quantizer = x->elements()[0].toTensor().quantizer(); + + auto weight = x->elements()[0].toTensor(); + auto bias = x->elements()[1].toTensor(); + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + + if (weight.qscheme() == c10::kPerChannelAffine) + { + op->attrs["weight.q_per_channel_scales"] = weight.q_per_channel_scales(); + op->attrs["weight.q_per_channel_zero_points"] = weight.q_per_channel_zero_points(); + // op->params["weight.q_per_channel_axis"] = weight.q_per_channel_axis(); + } + + op->params["in_channels"] = mod.attr("in_channels").toInt(); + op->params["out_channels"] = mod.attr("out_channels").toInt(); + op->params["kernel_size"] = Parameter{mod.attr("kernel_size").toTuple()->elements()[0].toInt(), mod.attr("kernel_size").toTuple()->elements()[1].toInt()}; + op->params["stride"] = Parameter{mod.attr("stride").toTuple()->elements()[0].toInt(), mod.attr("stride").toTuple()->elements()[1].toInt()}; + op->params["padding"] = Parameter{mod.attr("padding").toTuple()->elements()[0].toInt(), mod.attr("padding").toTuple()->elements()[1].toInt()}; + op->params["dilation"] = Parameter{mod.attr("dilation").toTuple()->elements()[0].toInt(), mod.attr("dilation").toTuple()->elements()[1].toInt()}; + op->params["groups"] = mod.attr("groups").toInt(); + op->params["padding_mode"] = "zeros"; + op->params["bias"] = mod.hasattr("bias"); + + op->params["scale"] = quantized_convolution->namedInput("output_scale"); + op->params["zero_point"] = quantized_convolution->namedInput("output_zero_point"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(QuantizedConv2d) + +class QuantizedConvReLU2d : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d"; + } + + const char* type_str() const + { + return "nn.intrinsic.quantized.ConvReLU2d"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // graph->dump(); + + const torch::jit::Node* quantized_convolution = find_node_by_kind(graph, "quantized::conv2d_relu"); + + // for (auto aa : quantized_convolution->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + + // torch::jit::Node* packed_params_node = 0; + // for (const auto& n : graph->nodes()) + // { + // if (n->kind() == c10::prim::GetAttr && n->s(torch::jit::attr::name) == "_packed_params") + // { + // packed_params_node = n; + // break; + // } + // } + + // quantized_convolution->namedInput("output_scale"); + + const auto& packed_params = mod.attr("_packed_params").toObject(); + + // auto x = torch::jit::script::Object(packed_params).run_method("__getstate__"); + auto x = torch::jit::script::Object(packed_params).run_method("unpack").toTuple(); + // std::cout << x->elements()[0].toTensor() << std::endl; + // std::cout << x->elements()[0].toTensor().quantizer() << std::endl; + // std::cout << x->elements()[1] << std::endl; + // at::Tensor dequantize() const; + // double q_scale() const; + // int64_t q_zero_point() const; + // at::Tensor q_per_channel_scales() const; + // at::Tensor q_per_channel_zero_points() const; + // int64_t q_per_channel_axis() const; + + // auto quantizer = x->elements()[0].toTensor().quantizer(); + + auto weight = x->elements()[0].toTensor(); + auto bias = x->elements()[1].toTensor(); + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + + if (weight.qscheme() == c10::kPerChannelAffine) + { + op->attrs["weight.q_per_channel_scales"] = weight.q_per_channel_scales(); + op->attrs["weight.q_per_channel_zero_points"] = weight.q_per_channel_zero_points(); + // op->params["weight.q_per_channel_axis"] = weight.q_per_channel_axis(); + } + + op->params["in_channels"] = mod.attr("in_channels").toInt(); + op->params["out_channels"] = mod.attr("out_channels").toInt(); + op->params["kernel_size"] = Parameter{mod.attr("kernel_size").toTuple()->elements()[0].toInt(), mod.attr("kernel_size").toTuple()->elements()[1].toInt()}; + op->params["stride"] = Parameter{mod.attr("stride").toTuple()->elements()[0].toInt(), mod.attr("stride").toTuple()->elements()[1].toInt()}; + op->params["padding"] = Parameter{mod.attr("padding").toTuple()->elements()[0].toInt(), mod.attr("padding").toTuple()->elements()[1].toInt()}; + op->params["dilation"] = Parameter{mod.attr("dilation").toTuple()->elements()[0].toInt(), mod.attr("dilation").toTuple()->elements()[1].toInt()}; + op->params["groups"] = mod.attr("groups").toInt(); + op->params["padding_mode"] = "zeros"; + op->params["bias"] = mod.hasattr("bias"); + + op->params["scale"] = quantized_convolution->namedInput("output_scale"); + op->params["zero_point"] = quantized_convolution->namedInput("output_zero_point"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(QuantizedConvReLU2d) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_quantized_DeQuantize.cpp b/tools/pnnx/src/pass_level1/nn_quantized_DeQuantize.cpp new file mode 100644 index 000000000000..15100839df7c --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_quantized_DeQuantize.cpp @@ -0,0 +1,51 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class DeQuantize : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.quantized.modules.DeQuantize"; + } + + const char* type_str() const + { + return "nn.quantized.DeQuantize"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // mod.dump(true, false, false); + + // graph->dump(); + + const torch::jit::Node* dequantize = find_node_by_kind(graph, "aten::dequantize"); + + // for (auto aa : dequantize->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(DeQuantize) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_quantized_Linear.cpp b/tools/pnnx/src/pass_level1/nn_quantized_Linear.cpp new file mode 100644 index 000000000000..3dfcab8abbe6 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_quantized_Linear.cpp @@ -0,0 +1,91 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class QuantizedLinear : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.quantized.modules.linear.Linear"; + } + + const char* type_str() const + { + return "nn.quantized.Linear"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // mod.dump(true, false, false); + + // graph->dump(); + + const torch::jit::Node* quantized_linear = find_node_by_kind(graph, "quantized::linear"); + + // for (auto aa : quantized_linear->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + + // torch::jit::Node* packed_params_node = 0; + // for (const auto& n : graph->nodes()) + // { + // if (n->kind() == c10::prim::GetAttr && n->s(torch::jit::attr::name) == "_packed_params") + // { + // packed_params_node = n; + // break; + // } + // } + + const auto& packed_params = mod.attr("_packed_params").toObject(); + + // for (auto aa : torch::jit::script::Object(packed_params).get_methods()) + // { + // fprintf(stderr, "method %s\n", aa.name().c_str()); + // } + + auto x = torch::jit::script::Object(packed_params).run_method("_weight_bias").toTuple(); + + auto weight = x->elements()[0].toTensor(); + auto bias = x->elements()[1].toTensor(); + + op->attrs["weight"] = weight; + op->attrs["bias"] = bias; + + if (weight.qscheme() == c10::kPerChannelAffine) + { + op->attrs["weight.q_per_channel_scales"] = weight.q_per_channel_scales(); + op->attrs["weight.q_per_channel_zero_points"] = weight.q_per_channel_zero_points(); + // op->params["weight.q_per_channel_axis"] = weight.q_per_channel_axis(); + } + + op->params["in_features"] = weight.size(1); + op->params["out_features"] = weight.size(0); + // op->params["in_features"] = mod.attr("in_features").toInt(); + // op->params["out_features"] = mod.attr("out_features").toInt(); + + op->params["scale"] = quantized_linear->namedInput("Y_scale_i"); + op->params["zero_point"] = quantized_linear->namedInput("Y_zero_point_i"); + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(QuantizedLinear) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level1/nn_quantized_Quantize.cpp b/tools/pnnx/src/pass_level1/nn_quantized_Quantize.cpp new file mode 100644 index 000000000000..f19a02c00eb2 --- /dev/null +++ b/tools/pnnx/src/pass_level1/nn_quantized_Quantize.cpp @@ -0,0 +1,56 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level1.h" + +#include "../utils.h" + +namespace pnnx { + +class Quantize : public FuseModulePass +{ +public: + const char* match_type_str() const + { + return "__torch__.torch.nn.quantized.modules.Quantize"; + } + + const char* type_str() const + { + return "nn.quantized.Quantize"; + } + + void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const + { + // mod.dump(true, false, false); + + // graph->dump(); + + const torch::jit::Node* quantize_per_tensor = find_node_by_kind(graph, "aten::quantize_per_tensor"); + + // for (auto aa : quantize_per_tensor->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } + + // scale, zero_point + op->params["scale"] = quantize_per_tensor->namedInput("scale"); + op->params["zero_point"] = quantize_per_tensor->namedInput("zero_point"); + op->params["dtype"] = "torch.qint8"; + } +}; + +REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Quantize) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2.cpp b/tools/pnnx/src/pass_level2.cpp new file mode 100644 index 000000000000..708611dbd7ec --- /dev/null +++ b/tools/pnnx/src/pass_level2.cpp @@ -0,0 +1,466 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +#include +#include +#include + +namespace pnnx { + +GraphRewriterPass::~GraphRewriterPass() +{ +} + +const char* GraphRewriterPass::name_str() const +{ + return type_str(); +} + +bool GraphRewriterPass::match(const std::map& /*captured_params*/) const +{ + return true; +} + +bool GraphRewriterPass::match(const std::map& captured_params, const std::map& /*captured_attrs*/) const +{ + return match(captured_params); +} + +void GraphRewriterPass::write(Operator* op, const std::map& captured_params) const +{ + for (auto x : captured_params) + { + op->params[x.first] = x.second; + } +} + +void GraphRewriterPass::write(Operator* op, const std::map& captured_params, const std::map& /*captured_attrs*/) const +{ + write(op, captured_params); +} + +static std::map > g_global_pnnx_graph_rewriter_passes; + +GraphRewriterPassRegister::GraphRewriterPassRegister(const GraphRewriterPass* _pass, int priority) + : pass(_pass) +{ + if (g_global_pnnx_graph_rewriter_passes.find(priority) == g_global_pnnx_graph_rewriter_passes.end()) + { + g_global_pnnx_graph_rewriter_passes[priority] = std::vector(); + } + + g_global_pnnx_graph_rewriter_passes[priority].push_back(pass); +} + +GraphRewriterPassRegister::~GraphRewriterPassRegister() +{ + delete pass; +} + +static bool match_parameter(const Parameter& a, const Parameter& b, std::map& captured_params) +{ + if (b.type == 4 && b.s[0] == '%') + { + // captured parameter + captured_params[b.s.substr(1)] = a; + return true; + } + + if (b.type == 4 && b.s == "*") + { + // ignored parameter + return true; + } + + if (a.type != b.type) + return false; + + const int type = a.type; + + if (type == 0) + { + return true; + } + if (type == 1) + { + return a.b == b.b; + } + if (type == 2) + { + return a.i == b.i; + } + if (type == 3) + { + return a.f == b.f; + } + if (type == 4) + { + return a.s == b.s; + } + if (type == 5) + { + if (a.ai.size() != b.ai.size()) + return false; + + for (size_t i = 0; i < a.ai.size(); i++) + { + if (a.ai[i] != b.ai[i]) + return false; + } + + return true; + } + if (type == 6) + { + if (a.af.size() != b.af.size()) + return false; + + for (size_t i = 0; i < a.af.size(); i++) + { + if (a.af[i] != b.af[i]) + return false; + } + + return true; + } + if (type == 7) + { + if (a.as.size() != b.as.size()) + return false; + + for (size_t i = 0; i < a.as.size(); i++) + { + if (a.as[i] != b.as[i]) + return false; + } + + return true; + } + + // unknown + return false; +} + +static bool match_operator(const Operator* a, const Operator* b, std::map& captured_params, std::map& captured_attrs) +{ + if (a->type != b->type) + return false; + + if (a->inputs.size() != b->inputs.size()) + return false; + + if (a->outputs.size() != b->outputs.size()) + return false; + + // match params + if (b->params.size() == 1 && b->params.find("%*") != b->params.end() && b->params.at("%*").type == 4 && b->params.at("%*").s == "%*") + { + for (const auto& p : a->params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + // capture all parameters + captured_params[b->name + '.' + pkey] = pp; + } + } + else + { + if (a->params.size() != b->params.size()) + return false; + + for (const auto& p : a->params) + { + const std::string& akey = p.first; + const Parameter& ap = p.second; + + if (b->params.find(akey) == b->params.end()) + return false; + + if (!match_parameter(ap, b->params.at(akey), captured_params)) + return false; + } + } + + for (const auto& p : a->attrs) + { + const std::string& akey = p.first; + const Attribute& aa = p.second; + + // capture all attributes + captured_attrs[b->name + '.' + akey] = aa; + } + + return true; +} + +static bool match(const Operator* anchor, const Operator* pattern, std::unordered_map& matched_operators, std::unordered_map& matched_inputs, std::map& captured_params, std::map& captured_attrs) +{ + if (!match_operator(anchor, pattern, captured_params, captured_attrs)) + return false; + + for (size_t i = 0; i < pattern->outputs.size(); i++) + { + if (pattern->outputs[i]->consumers.size() == 1 && pattern->outputs[i]->consumers[0]->type == "pnnx.Output") + continue; + + if (anchor->outputs[i]->consumers.size() != pattern->outputs[i]->consumers.size()) + return false; + } + + matched_operators[pattern->name] = anchor; + + // lets match + for (size_t i = 0; i < pattern->inputs.size(); i++) + { + const Operator* anchor2 = anchor->inputs[i]->producer; + const Operator* pattern2 = pattern->inputs[i]->producer; + + if (pattern2->type == "pnnx.Input") + { + if (matched_inputs.find(pattern->inputs[i]->name) == matched_inputs.end()) + { + matched_inputs[pattern->inputs[i]->name] = anchor->inputs[i]; + } + continue; + } + + if (!match(anchor2, pattern2, matched_operators, matched_inputs, captured_params, captured_attrs)) + return false; + } + + return true; +} + +void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opindex) +{ + Graph pattern_graph; + pattern_graph.parse(pass->match_pattern_graph()); + + // collect pattern inputs and outputs order + std::vector pattern_graph_inputs; + std::vector pattern_graph_outputs; + std::vector pattern_graph_output_operators; + for (const auto& x : pattern_graph.ops) + { + if (x->type == "pnnx.Input") + { + for (const auto& y : x->outputs) + pattern_graph_inputs.push_back(y->name); + } + if (x->type == "pnnx.Output") + { + pattern_graph_output_operators.push_back(x); + for (const auto& y : x->inputs) + pattern_graph_outputs.push_back(y->name); + } + } + + std::vector new_ops; + + while (1) + { + const int graph_op_count = (int)graph.ops.size(); + + bool matched = true; + + // lets match from output + std::unordered_map matched_operators; + std::unordered_map matched_inputs; + std::unordered_map matched_outputs; + std::map captured_params; + std::map captured_attrs; + + // pattern match from end to beginning + int q = graph_op_count - 1; + for (; q >= 1; q--) + { + for (const Operator* pattern : pattern_graph_output_operators) + { + for (size_t i = 0; i < pattern->inputs.size(); i++) + { + const Operator* pattern2 = pattern->inputs[i]->producer; + + int j = q; + for (; j >= 0; j--) + { + const Operator* anchor = graph.ops[j]; + + std::unordered_map matched_operators2; + std::unordered_map matched_inputs2; + std::map captured_params2; + std::map captured_attrs2; + if (!match(anchor, pattern2, matched_operators2, matched_inputs2, captured_params2, captured_attrs2)) + continue; + + bool submatch_matched = true; + for (auto x : matched_operators2) + { + // check these matched operators are same with previous matched ones + if (matched_operators.find(x.first) != matched_operators.end()) + { + if (matched_operators[x.first] != x.second) + { + // unmatched two sub-matches + submatch_matched = false; + break; + } + } + else + { + matched_operators[x.first] = x.second; + } + } + + if (!submatch_matched) + continue; + + for (auto x : matched_inputs2) + { + if (matched_inputs.find(x.first) == matched_inputs.end()) + { + matched_inputs[x.first] = x.second; + } + } + for (auto x : captured_params2) + { + captured_params[x.first] = x.second; + } + for (auto x : captured_attrs2) + { + captured_attrs[x.first] = x.second; + } + + // match ! + matched_outputs[pattern->inputs[i]->name] = anchor->outputs[i]; + break; + } + + if (j == -1) + { + matched = false; + break; + } + } + + if (!matched) + break; + } + + if (matched && !pass->match(captured_params, captured_attrs)) + { + matched_operators.clear(); + matched_inputs.clear(); + matched_outputs.clear(); + captured_params.clear(); + captured_attrs.clear(); + continue; + } + + break; + } + + if (!matched) + break; + + // fprintf(stderr, "matched !\n"); + + // lets replace + + // remove all matched_operators + for (auto& _x : matched_operators) + { + // fprintf(stderr, "remove %s\n", _x.second->name.c_str()); + + Operator* x = (Operator*)_x.second; + for (auto& r : x->inputs) + { + r->remove_consumer(x); + } + + x->inputs.clear(); + + for (auto& r : x->outputs) + { + r->producer = 0; + } + + x->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), x)); + + delete _x.second; + } + + // insert new operator before all output consumers + const Operator* cur = 0; + { + int cur_index = graph.ops.size() - 1; + for (auto& o : matched_outputs) + { + for (auto& c : o.second->consumers) + { + int c_index = std::find(graph.ops.begin(), graph.ops.end(), c) - graph.ops.begin(); + cur_index = std::min(cur_index, c_index); + } + } + + cur = graph.ops[cur_index]; + } + + Operator* op = graph.new_operator_before(pass->type_str(), std::string(pass->name_str()), cur); + + for (const auto& k : pattern_graph_inputs) + { + Operand* r = (Operand*)matched_inputs.at(k); + r->consumers.push_back(op); + op->inputs.push_back(r); + + op->inputnames.push_back(k); + } + + for (const auto& k : pattern_graph_outputs) + { + Operand* r = (Operand*)matched_outputs.at(k); + r->producer = op; + op->outputs.push_back(r); + } + + pass->write(op, captured_params, captured_attrs); + + new_ops.push_back(op); + } + + // assign new op name number + for (int i = (int)new_ops.size() - 1; i >= 0; i--) + { + new_ops[i]->name = new_ops[i]->name + "_" + std::to_string(opindex++); + } +} + +void pass_level2(Graph& g) +{ + int opindex = 0; + for (auto x : g_global_pnnx_graph_rewriter_passes) + { + for (auto rewriter : x.second) + { + pnnx_graph_rewrite(g, rewriter, opindex); + } + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2.h b/tools/pnnx/src/pass_level2.h new file mode 100644 index 000000000000..1a0562be939d --- /dev/null +++ b/tools/pnnx/src/pass_level2.h @@ -0,0 +1,59 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_PASS_LEVEL2_H +#define PNNX_PASS_LEVEL2_H + +#include "ir.h" + +namespace pnnx { + +class GraphRewriterPass +{ +public: + virtual ~GraphRewriterPass(); + + virtual const char* match_pattern_graph() const = 0; + + virtual const char* type_str() const = 0; + + virtual const char* name_str() const; + + virtual bool match(const std::map& captured_params) const; + + virtual bool match(const std::map& captured_params, const std::map& captured_attrs) const; + + virtual void write(Operator* op, const std::map& captured_params) const; + + virtual void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const; +}; + +class GraphRewriterPassRegister +{ +public: + GraphRewriterPassRegister(const GraphRewriterPass* pass, int priority); + ~GraphRewriterPassRegister(); + const GraphRewriterPass* pass; +}; + +#define REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(CLASS, PRIORITY) \ + static GraphRewriterPassRegister g_global_pnnx_graphrewriterpass_##CLASS##_register(new CLASS, PRIORITY); + +void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opindex); + +void pass_level2(Graph& g); + +} // namespace pnnx + +#endif // PNNX_PASS_LEVEL2_H diff --git a/tools/pnnx/src/pass_level2/F_adaptive_avg_pool1d.cpp b/tools/pnnx/src/pass_level2/F_adaptive_avg_pool1d.cpp new file mode 100644 index 000000000000..a110bbe36c0f --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_adaptive_avg_pool1d.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_adaptive_avg_pool1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 output_size +aten::adaptive_avg_pool1d op_0 2 1 input output_size out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.adaptive_avg_pool1d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_avg_pool1d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp b/tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp new file mode 100644 index 000000000000..7938929497a2 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_adaptive_avg_pool2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 output_size +aten::adaptive_avg_pool2d op_0 2 1 input output_size out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.adaptive_avg_pool2d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_avg_pool2d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_adaptive_avg_pool3d.cpp b/tools/pnnx/src/pass_level2/F_adaptive_avg_pool3d.cpp new file mode 100644 index 000000000000..8485c36dc079 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_adaptive_avg_pool3d.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_adaptive_avg_pool3d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 output_size +aten::adaptive_avg_pool3d op_0 2 1 input output_size out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.adaptive_avg_pool3d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_avg_pool3d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_adaptive_max_pool1d.cpp b/tools/pnnx/src/pass_level2/F_adaptive_max_pool1d.cpp new file mode 100644 index 000000000000..5fa902c2ce5b --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_adaptive_max_pool1d.cpp @@ -0,0 +1,46 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_adaptive_max_pool1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 output_size +aten::adaptive_max_pool1d op_0 2 2 input output_size out indices +pnnx.Output output 2 0 out indices +)PNNXIR"; + } + + const char* type_str() const + { + return "F.adaptive_max_pool1d"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["return_indices"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_max_pool1d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp b/tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp new file mode 100644 index 000000000000..3a38f329b3a1 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp @@ -0,0 +1,46 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_adaptive_max_pool2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 output_size +aten::adaptive_max_pool2d op_0 2 2 input output_size out indices +pnnx.Output output 2 0 out indices +)PNNXIR"; + } + + const char* type_str() const + { + return "F.adaptive_max_pool2d"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["return_indices"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_max_pool2d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_adaptive_max_pool3d.cpp b/tools/pnnx/src/pass_level2/F_adaptive_max_pool3d.cpp new file mode 100644 index 000000000000..c1ed83d9fd66 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_adaptive_max_pool3d.cpp @@ -0,0 +1,46 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_adaptive_max_pool3d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 output_size +aten::adaptive_max_pool3d op_0 2 2 input output_size out indices +pnnx.Output output 2 0 out indices +)PNNXIR"; + } + + const char* type_str() const + { + return "F.adaptive_max_pool3d"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["return_indices"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_max_pool3d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_affine_grid.cpp b/tools/pnnx/src/pass_level2/F_affine_grid.cpp new file mode 100644 index 000000000000..d766da701229 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_affine_grid.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_affine_grid : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_1 0 1 theta +pnnx.Input input_2 0 1 size +prim::Constant op_0 0 1 align_corners value=%align_corners +aten::affine_grid_generator op_1 3 1 theta size align_corners out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.affine_grid"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_affine_grid, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_avg_pool1d.cpp b/tools/pnnx/src/pass_level2/F_avg_pool1d.cpp new file mode 100644 index 000000000000..16bbc7944743 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_avg_pool1d.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_avg_pool1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 padding +prim::Constant op_0 0 1 ceil_mode value=%ceil_mode +prim::Constant op_1 0 1 count_include_pad value=%count_include_pad +aten::avg_pool1d op_2 6 1 input kernel_size stride padding ceil_mode count_include_pad out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.avg_pool1d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_avg_pool1d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp b/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp new file mode 100644 index 000000000000..6c92b31808a4 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp @@ -0,0 +1,46 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_avg_pool2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 padding +prim::Constant op_0 0 1 ceil_mode value=%ceil_mode +prim::Constant op_1 0 1 count_include_pad value=%count_include_pad +prim::Constant op_2 0 1 divisor_override value=%divisor_override +aten::avg_pool2d op_3 7 1 input kernel_size stride padding ceil_mode count_include_pad divisor_override out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.avg_pool2d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_avg_pool2d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_avg_pool3d.cpp b/tools/pnnx/src/pass_level2/F_avg_pool3d.cpp new file mode 100644 index 000000000000..569d963bfb68 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_avg_pool3d.cpp @@ -0,0 +1,46 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_avg_pool3d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 padding +prim::Constant op_0 0 1 ceil_mode value=%ceil_mode +prim::Constant op_1 0 1 count_include_pad value=%count_include_pad +prim::Constant op_2 0 1 divisor_override value=%divisor_override +aten::avg_pool3d op_3 7 1 input kernel_size stride padding ceil_mode count_include_pad divisor_override out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.avg_pool3d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_avg_pool3d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_batch_norm.cpp b/tools/pnnx/src/pass_level2/F_batch_norm.cpp new file mode 100644 index 000000000000..e922e458a065 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_batch_norm.cpp @@ -0,0 +1,48 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_batch_norm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +11 10 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 running_mean +pnnx.Input input_2 0 1 running_var +pnnx.Input input_3 0 1 weight +pnnx.Input input_4 0 1 bias +prim::Constant op_0 0 1 training value=* +prim::Constant op_1 0 1 momentum value=* +prim::Constant op_2 0 1 eps value=%eps +prim::Constant op_3 0 1 cudnn_enabled value=* +aten::batch_norm op_4 9 1 input weight bias running_mean running_var training momentum eps cudnn_enabled out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.batch_norm"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_celu.cpp b/tools/pnnx/src/pass_level2/F_celu.cpp new file mode 100644 index 000000000000..0e13fb181d68 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_celu.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_celu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 alpha +aten::celu op_0 2 1 input alpha out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.celu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_celu, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_conv1d.cpp b/tools/pnnx/src/pass_level2/F_conv1d.cpp new file mode 100644 index 000000000000..5a8c20650cbc --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_conv1d.cpp @@ -0,0 +1,80 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_conv1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +16 15 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +pnnx.Input input_3 0 1 stride +pnnx.Input input_4 0 1 padding +pnnx.Input input_5 0 1 dilation +pnnx.Input input_6 0 1 groups +prim::Constant op_0 0 1 transposed value=False +prim::Constant op_1 0 1 output_padding_w value=0 +prim::ListConstruct op_2 1 1 output_padding_w output_padding +prim::Constant op_3 0 1 benchmark value=* +prim::Constant op_4 0 1 deterministic value=* +prim::Constant op_5 0 1 cudnn_enabled value=* +prim::Constant op_6 0 1 allow_tf32 value=* +aten::_convolution op_7 13 1 input weight bias stride padding dilation transposed output_padding groups benchmark deterministic cudnn_enabled allow_tf32 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv1d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d, 10) + +class F_conv1d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +pnnx.Input input_3 0 1 stride +pnnx.Input input_4 0 1 padding +pnnx.Input input_5 0 1 dilation +pnnx.Input input_6 0 1 groups +aten::_convolution_mode op_0 7 1 input weight bias stride padding dilation groups out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv1d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_1, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_conv2d.cpp b/tools/pnnx/src/pass_level2/F_conv2d.cpp new file mode 100644 index 000000000000..fca5d3f7e155 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_conv2d.cpp @@ -0,0 +1,81 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_conv2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +17 16 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +pnnx.Input input_3 0 1 stride +pnnx.Input input_4 0 1 padding +pnnx.Input input_5 0 1 dilation +pnnx.Input input_6 0 1 groups +prim::Constant op_0 0 1 transposed value=False +prim::Constant op_1 0 1 output_padding_h value=0 +prim::Constant op_2 0 1 output_padding_w value=0 +prim::ListConstruct op_3 2 1 output_padding_h output_padding_w output_padding +prim::Constant op_4 0 1 benchmark value=* +prim::Constant op_5 0 1 deterministic value=* +prim::Constant op_6 0 1 cudnn_enabled value=* +prim::Constant op_7 0 1 allow_tf32 value=* +aten::_convolution op_8 13 1 input weight bias stride padding dilation transposed output_padding groups benchmark deterministic cudnn_enabled allow_tf32 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv2d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d, 10) + +class F_conv2d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +pnnx.Input input_3 0 1 stride +pnnx.Input input_4 0 1 padding +pnnx.Input input_5 0 1 dilation +pnnx.Input input_6 0 1 groups +aten::_convolution_mode op_0 7 1 input weight bias stride padding dilation groups out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv2d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_1, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_conv3d.cpp b/tools/pnnx/src/pass_level2/F_conv3d.cpp new file mode 100644 index 000000000000..0a6565a8904c --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_conv3d.cpp @@ -0,0 +1,82 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_conv3d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +18 17 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +pnnx.Input input_3 0 1 stride +pnnx.Input input_4 0 1 padding +pnnx.Input input_5 0 1 dilation +pnnx.Input input_6 0 1 groups +prim::Constant op_0 0 1 transposed value=False +prim::Constant op_1 0 1 output_padding_d value=0 +prim::Constant op_2 0 1 output_padding_h value=0 +prim::Constant op_3 0 1 output_padding_w value=0 +prim::ListConstruct op_4 3 1 output_padding_d output_padding_h output_padding_w output_padding +prim::Constant op_5 0 1 benchmark value=* +prim::Constant op_6 0 1 deterministic value=* +prim::Constant op_7 0 1 cudnn_enabled value=* +prim::Constant op_8 0 1 allow_tf32 value=* +aten::_convolution op_9 13 1 input weight bias stride padding dilation transposed output_padding groups benchmark deterministic cudnn_enabled allow_tf32 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv3d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv3d, 10) + +class F_conv3d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +pnnx.Input input_3 0 1 stride +pnnx.Input input_4 0 1 padding +pnnx.Input input_5 0 1 dilation +pnnx.Input input_6 0 1 groups +aten::_convolution_mode op_0 7 1 input weight bias stride padding dilation groups out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv3d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv3d_1, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_conv_transpose1d.cpp b/tools/pnnx/src/pass_level2/F_conv_transpose1d.cpp new file mode 100644 index 000000000000..a7825faacce1 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_conv_transpose1d.cpp @@ -0,0 +1,52 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_conv_transpose1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +15 14 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +pnnx.Input input_3 0 1 stride +pnnx.Input input_4 0 1 padding +pnnx.Input input_5 0 1 dilation +pnnx.Input input_6 0 1 output_padding +pnnx.Input input_7 0 1 groups +prim::Constant op_0 0 1 transposed value=True +prim::Constant op_1 0 1 benchmark value=* +prim::Constant op_2 0 1 deterministic value=* +prim::Constant op_3 0 1 cudnn_enabled value=* +prim::Constant op_4 0 1 allow_tf32 value=* +aten::_convolution op_5 13 1 input weight bias stride padding dilation transposed output_padding groups benchmark deterministic cudnn_enabled allow_tf32 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv_transpose1d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv_transpose1d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_conv_transpose2d.cpp b/tools/pnnx/src/pass_level2/F_conv_transpose2d.cpp new file mode 100644 index 000000000000..249e535329d2 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_conv_transpose2d.cpp @@ -0,0 +1,52 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_conv_transpose2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +15 14 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +pnnx.Input input_3 0 1 stride +pnnx.Input input_4 0 1 padding +pnnx.Input input_5 0 1 dilation +pnnx.Input input_6 0 1 output_padding +pnnx.Input input_7 0 1 groups +prim::Constant op_0 0 1 transposed value=True +prim::Constant op_1 0 1 benchmark value=* +prim::Constant op_2 0 1 deterministic value=* +prim::Constant op_3 0 1 cudnn_enabled value=* +prim::Constant op_4 0 1 allow_tf32 value=* +aten::_convolution op_5 13 1 input weight bias stride padding dilation transposed output_padding groups benchmark deterministic cudnn_enabled allow_tf32 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv_transpose2d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv_transpose2d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_conv_transpose3d.cpp b/tools/pnnx/src/pass_level2/F_conv_transpose3d.cpp new file mode 100644 index 000000000000..e502f140434e --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_conv_transpose3d.cpp @@ -0,0 +1,52 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_conv_transpose3d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +15 14 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +pnnx.Input input_3 0 1 stride +pnnx.Input input_4 0 1 padding +pnnx.Input input_5 0 1 dilation +pnnx.Input input_6 0 1 output_padding +pnnx.Input input_7 0 1 groups +prim::Constant op_0 0 1 transposed value=True +prim::Constant op_1 0 1 benchmark value=* +prim::Constant op_2 0 1 deterministic value=* +prim::Constant op_3 0 1 cudnn_enabled value=* +prim::Constant op_4 0 1 allow_tf32 value=* +aten::_convolution op_5 13 1 input weight bias stride padding dilation transposed output_padding groups benchmark deterministic cudnn_enabled allow_tf32 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv_transpose3d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv_transpose3d, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_elu.cpp b/tools/pnnx/src/pass_level2/F_elu.cpp new file mode 100644 index 000000000000..667d14fa6339 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_elu.cpp @@ -0,0 +1,43 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_elu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 alpha +prim::Constant op_0 0 1 scale value=1 +prim::Constant op_1 0 1 input_scale value=1 +aten::elu op_2 4 1 input alpha scale input_scale out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.elu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_elu, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_gelu.cpp b/tools/pnnx/src/pass_level2/F_gelu.cpp new file mode 100644 index 000000000000..55385c0897ba --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_gelu.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_gelu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::gelu op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.gelu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_gelu, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_grid_sample.cpp b/tools/pnnx/src/pass_level2/F_grid_sample.cpp new file mode 100644 index 000000000000..c023fc9386c3 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_grid_sample.cpp @@ -0,0 +1,63 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_grid_sample : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_1 0 1 input +pnnx.Input input_2 0 1 grid +prim::Constant op_0 0 1 mode value=%mode +prim::Constant op_1 0 1 padding_mode value=%padding_mode +prim::Constant op_2 0 1 align_corners value=%align_corners +aten::grid_sampler op_3 5 1 input grid mode padding_mode align_corners out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.grid_sample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + if (captured_params.at("mode").i == 0) + op->params["mode"] = "bilinear"; + if (captured_params.at("mode").i == 1) + op->params["mode"] = "nearest"; + if (captured_params.at("mode").i == 2) + op->params["mode"] = "bicubic"; + + if (captured_params.at("padding_mode").i == 0) + op->params["padding_mode"] = "zeros"; + if (captured_params.at("padding_mode").i == 1) + op->params["padding_mode"] = "border"; + if (captured_params.at("padding_mode").i == 2) + op->params["padding_mode"] = "reflection"; + + op->params["align_corners"] = captured_params.at("align_corners"); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_grid_sample, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_group_norm.cpp b/tools/pnnx/src/pass_level2/F_group_norm.cpp new file mode 100644 index 000000000000..94608736577f --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_group_norm.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_group_norm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_1 0 1 input +pnnx.Input input_2 0 1 weight +pnnx.Input input_3 0 1 bias +prim::Constant op_0 0 1 num_groups value=%num_groups +prim::Constant op_1 0 1 eps value=%eps +prim::Constant op_2 0 1 cudnn_enabled value=* +aten::group_norm op_3 6 1 input num_groups weight bias eps cudnn_enabled out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.group_norm"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_group_norm, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_hardshrink.cpp b/tools/pnnx/src/pass_level2/F_hardshrink.cpp new file mode 100644 index 000000000000..1907d9471e12 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_hardshrink.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_hardshrink : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 lambd +aten::hardshrink op_0 2 1 input lambd out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardshrink"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardshrink, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_hardsigmoid.cpp b/tools/pnnx/src/pass_level2/F_hardsigmoid.cpp new file mode 100644 index 000000000000..5bfc4c5a824c --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_hardsigmoid.cpp @@ -0,0 +1,143 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_hardsigmoid : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::hardsigmoid op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardsigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardsigmoid, 10) + +class F_hardsigmoid_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::hardsigmoid_ op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardsigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardsigmoid_1, 10) + +class F_hardsigmoid_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +10 9 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 410 value=3.000000e+00 +prim::Constant op_1 0 1 412 value=1 +aten::add_ op_2 3 1 input 410 412 a +prim::Constant op_3 0 1 413 value=0.000000e+00 +prim::Constant op_4 0 1 414 value=6.000000e+00 +aten::clamp_ op_5 3 1 a 413 414 b +prim::Constant op_6 0 1 409 value=6.000000e+00 +aten::div_ op_7 2 1 b 409 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardsigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardsigmoid_2, 9) + +class F_hardsigmoid_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +10 9 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 12 value=3.000000e+00 +prim::Constant op_1 0 1 13 value=1 +aten::add op_2 3 1 input 12 13 a +prim::Constant op_3 0 1 16 value=0.000000e+00 +prim::Constant op_4 0 1 17 value=6.000000e+00 +aten::hardtanh_ op_5 3 1 a 16 17 b +prim::Constant op_6 0 1 19 value=6.000000e+00 +aten::div op_7 2 1 b 19 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardsigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardsigmoid_3, 9) + +class F_hardsigmoid_4 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 12 value=3.000000e+00 +prim::Constant op_1 0 1 13 value=1 +aten::add op_2 3 1 input 12 13 a +aten::relu6_ op_3 1 1 a b +prim::Constant op_4 0 1 19 value=6.000000e+00 +aten::div op_5 2 1 b 19 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardsigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardsigmoid_4, 9) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_hardswish.cpp b/tools/pnnx/src/pass_level2/F_hardswish.cpp new file mode 100644 index 000000000000..c68ec57531c0 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_hardswish.cpp @@ -0,0 +1,176 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_hardswish : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::hardswish op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardswish"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish, 10) + +class F_hardswish_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +11 10 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 392 value=3 +prim::Constant op_1 0 1 393 value=1 +aten::add op_2 3 1 input 392 393 a +prim::Constant op_3 0 1 394 value=0.000000e+00 +prim::Constant op_4 0 1 395 value=6.000000e+00 +aten::hardtanh_ op_5 3 1 a 394 395 b +aten::mul op_6 2 1 input b c +prim::Constant op_7 0 1 391 value=6 +aten::div op_8 2 1 c 391 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardswish"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_1, 8) + +class F_hardswish_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +aten::hardsigmoid op_0 1 1 input a +aten::mul op_1 2 1 input a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardswish"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_2, 9) + +class F_hardswish_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +11 10 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 12 value=3 +prim::Constant op_1 0 1 13 value=1 +aten::add op_2 3 1 input 12 13 a +prim::Constant op_3 0 1 17 value=0.000000e+00 +prim::Constant op_4 0 1 18 value=6.000000e+00 +aten::hardtanh op_5 3 1 a 17 18 b +aten::mul op_6 2 1 input b c +prim::Constant op_7 0 1 22 value=6.000000e+00 +aten::div op_8 2 1 c 22 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardswish"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_3, 8) + +class F_hardswish_4 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +11 10 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 25 value=3.000000e+00 +prim::Constant op_1 0 1 47 value=1 +aten::add op_2 3 1 input 25 47 a +prim::Constant op_3 0 1 48 value=0.000000e+00 +prim::Constant op_4 0 1 49 value=6.000000e+00 +aten::hardtanh_ op_5 3 1 a 48 49 b +prim::Constant op_6 0 1 50 value=6.000000e+00 +aten::div op_7 2 1 b 50 c +aten::mul op_8 2 1 c input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardswish"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_4, 8) + +class F_hardswish_5 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 25 value=3.000000e+00 +prim::Constant op_1 0 1 48 value=1 +aten::add op_2 3 1 input 25 48 a +aten::relu6_ op_3 1 1 a b +prim::Constant op_4 0 1 49 value=6.000000e+00 +aten::div op_5 2 1 b 49 c +aten::mul op_6 2 1 c input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardswish"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_5, 8) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_hardtanh.cpp b/tools/pnnx/src/pass_level2/F_hardtanh.cpp new file mode 100644 index 000000000000..cc72c8a9529e --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_hardtanh.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_hardtanh : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 min_val +pnnx.Input input_2 0 1 max_val +aten::hardtanh op_0 3 1 input min_val max_val out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardtanh"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardtanh, 10) + +class F_hardtanh_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 min_val +pnnx.Input input_2 0 1 max_val +aten::hardtanh_ op_0 3 1 input min_val max_val out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.hardtanh"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardtanh_1, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_instance_norm.cpp b/tools/pnnx/src/pass_level2/F_instance_norm.cpp new file mode 100644 index 000000000000..42419a2adae2 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_instance_norm.cpp @@ -0,0 +1,48 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_instance_norm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +11 10 +pnnx.Input input_1 0 1 input +pnnx.Input input_2 0 1 running_mean +pnnx.Input input_3 0 1 running_var +pnnx.Input input_4 0 1 weight +pnnx.Input input_5 0 1 bias +prim::Constant op_0 0 1 use_input_stats value=True +prim::Constant op_1 0 1 momentum value=* +prim::Constant op_2 0 1 eps value=%eps +prim::Constant op_3 0 1 cudnn_enabled value=* +aten::instance_norm op_4 9 1 input weight bias running_mean running_var use_input_stats momentum eps cudnn_enabled out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.instance_norm"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_instance_norm, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_interpolate.cpp b/tools/pnnx/src/pass_level2/F_interpolate.cpp new file mode 100644 index 000000000000..8ef9eb73424d --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_interpolate.cpp @@ -0,0 +1,946 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_interpolate : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +29 28 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 scale value=%scale +prim::Constant op_1 0 1 5 value=2 +aten::size op_2 2 1 input 5 6 +prim::NumToTensor op_3 1 1 6 7 +prim::Constant op_4 0 1 9 value=6 +prim::Constant op_5 0 1 10 value=False +prim::Constant op_6 0 1 52 value=False +prim::Constant op_7 0 1 12 value=None +aten::to op_8 5 1 7 9 10 52 12 13 +prim::Constant op_9 0 1 51 value=* +prim::Constant op_10 0 1 53 value=6 +prim::Constant op_11 0 1 54 value=False +prim::Constant op_12 0 1 55 value=False +prim::Constant op_13 0 1 56 value=None +aten::to op_14 6 1 scale 51 53 54 55 56 20 +aten::detach op_15 1 1 20 23 +aten::mul op_16 2 1 13 23 24 +prim::Constant op_17 0 1 57 value=6 +prim::Constant op_18 0 1 58 value=False +prim::Constant op_19 0 1 59 value=False +prim::Constant op_20 0 1 60 value=None +aten::to op_21 5 1 24 57 58 59 60 28 +aten::floor op_22 1 1 28 31 +aten::Int op_23 1 1 31 33 +prim::ListConstruct op_24 1 1 33 size +prim::Constant op_25 0 1 scale_factor value=None +aten::upsample_nearest1d op_26 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = captured_params.at("scale"); + op->params["mode"] = "nearest"; + op->params["recompute_scale_factor"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate, 10) + +class F_interpolate_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +30 29 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 scale value=%scale +prim::Constant op_1 0 1 5 value=2 +aten::size op_2 2 1 input 5 6 +prim::NumToTensor op_3 1 1 6 7 +prim::Constant op_4 0 1 9 value=6 +prim::Constant op_5 0 1 10 value=False +prim::Constant op_6 0 1 52 value=False +prim::Constant op_7 0 1 12 value=None +aten::to op_8 5 1 7 9 10 52 12 13 +prim::Constant op_9 0 1 51 value=* +prim::Constant op_10 0 1 53 value=6 +prim::Constant op_11 0 1 54 value=False +prim::Constant op_12 0 1 55 value=False +prim::Constant op_13 0 1 56 value=None +aten::to op_14 6 1 scale 51 53 54 55 56 20 +aten::detach op_15 1 1 20 23 +aten::mul op_16 2 1 13 23 24 +prim::Constant op_17 0 1 57 value=6 +prim::Constant op_18 0 1 58 value=False +prim::Constant op_19 0 1 59 value=False +prim::Constant op_20 0 1 60 value=None +aten::to op_21 5 1 24 57 58 59 60 28 +aten::floor op_22 1 1 28 31 +aten::Int op_23 1 1 31 33 +prim::ListConstruct op_24 1 1 33 size +prim::Constant op_25 0 1 align_corners value=%align_corners +prim::Constant op_26 0 1 scale_factor value=None +aten::upsample_linear1d op_27 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = captured_params.at("scale"); + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "linear"; + op->params["recompute_scale_factor"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_1, 10) + +class F_interpolate_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +54 53 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 scale_w value=%scale_w +prim::Constant op_1 0 1 81 value=* +prim::Constant op_2 0 1 scale_h value=%scale_h +prim::Constant op_3 0 1 12 value=None +prim::Constant op_4 0 1 10 value=False +prim::Constant op_5 0 1 5 value=2 +prim::Constant op_6 0 1 9 value=6 +prim::Constant op_7 0 1 34 value=3 +aten::size op_8 2 1 input 5 6 +prim::NumToTensor op_9 1 1 6 7 +prim::Constant op_10 0 1 83 value=False +aten::to op_11 5 1 7 9 10 83 12 13 +prim::Constant op_12 0 1 84 value=6 +prim::Constant op_13 0 1 85 value=False +prim::Constant op_14 0 1 86 value=False +prim::Constant op_15 0 1 87 value=None +aten::to op_16 6 1 scale_h 81 84 85 86 87 20 +aten::detach op_17 1 1 20 23 +aten::mul op_18 2 1 13 23 24 +prim::Constant op_19 0 1 88 value=6 +prim::Constant op_20 0 1 89 value=False +prim::Constant op_21 0 1 90 value=False +prim::Constant op_22 0 1 91 value=None +aten::to op_23 5 1 24 88 89 90 91 28 +aten::floor op_24 1 1 28 30 +aten::Int op_25 1 1 30 32 +aten::size op_26 2 1 input 34 35 +prim::NumToTensor op_27 1 1 35 36 +prim::Constant op_28 0 1 92 value=6 +prim::Constant op_29 0 1 93 value=False +prim::Constant op_30 0 1 94 value=False +prim::Constant op_31 0 1 95 value=None +aten::to op_32 5 1 36 92 93 94 95 41 +prim::Constant op_33 0 1 96 value=* +prim::Constant op_34 0 1 97 value=6 +prim::Constant op_35 0 1 98 value=False +prim::Constant op_36 0 1 99 value=False +prim::Constant op_37 0 1 100 value=None +aten::to op_38 6 1 scale_w 96 97 98 99 100 48 +aten::detach op_39 1 1 48 51 +aten::mul op_40 2 1 41 51 52 +prim::Constant op_41 0 1 101 value=6 +prim::Constant op_42 0 1 102 value=False +prim::Constant op_43 0 1 103 value=False +prim::Constant op_44 0 1 104 value=None +aten::to op_45 5 1 52 101 102 103 104 56 +aten::floor op_46 1 1 56 60 +aten::Int op_47 1 1 60 62 +prim::ListConstruct op_48 2 1 32 62 size +prim::Constant op_49 0 1 scale_h_none value=None +prim::Constant op_50 0 1 scale_w_none value=None +aten::upsample_nearest2d op_51 4 1 input size scale_h_none scale_w_none out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = Parameter{captured_params.at("scale_h").f, captured_params.at("scale_w").f}; + op->params["mode"] = "nearest"; + op->params["recompute_scale_factor"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_2, 10) + +class F_interpolate_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +55 54 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 scale_w value=%scale_w +prim::Constant op_1 0 1 82 value=* +prim::Constant op_2 0 1 scale_h value=%scale_h +prim::Constant op_3 0 1 12 value=None +prim::Constant op_4 0 1 10 value=False +prim::Constant op_5 0 1 5 value=2 +prim::Constant op_6 0 1 9 value=6 +prim::Constant op_7 0 1 34 value=3 +aten::size op_8 2 1 input 5 6 +prim::NumToTensor op_9 1 1 6 7 +prim::Constant op_10 0 1 84 value=False +aten::to op_11 5 1 7 9 10 84 12 13 +prim::Constant op_12 0 1 85 value=6 +prim::Constant op_13 0 1 86 value=False +prim::Constant op_14 0 1 87 value=False +prim::Constant op_15 0 1 88 value=None +aten::to op_16 6 1 scale_h 82 85 86 87 88 20 +aten::detach op_17 1 1 20 23 +aten::mul op_18 2 1 13 23 24 +prim::Constant op_19 0 1 89 value=6 +prim::Constant op_20 0 1 90 value=False +prim::Constant op_21 0 1 91 value=False +prim::Constant op_22 0 1 92 value=None +aten::to op_23 5 1 24 89 90 91 92 28 +aten::floor op_24 1 1 28 30 +aten::Int op_25 1 1 30 32 +aten::size op_26 2 1 input 34 35 +prim::NumToTensor op_27 1 1 35 36 +prim::Constant op_28 0 1 93 value=6 +prim::Constant op_29 0 1 94 value=False +prim::Constant op_30 0 1 95 value=False +prim::Constant op_31 0 1 96 value=None +aten::to op_32 5 1 36 93 94 95 96 41 +prim::Constant op_33 0 1 97 value=* +prim::Constant op_34 0 1 98 value=6 +prim::Constant op_35 0 1 99 value=False +prim::Constant op_36 0 1 100 value=False +prim::Constant op_37 0 1 101 value=None +aten::to op_38 6 1 scale_w 97 98 99 100 101 48 +aten::detach op_39 1 1 48 51 +aten::mul op_40 2 1 41 51 52 +prim::Constant op_41 0 1 102 value=6 +prim::Constant op_42 0 1 103 value=False +prim::Constant op_43 0 1 104 value=False +prim::Constant op_44 0 1 105 value=None +aten::to op_45 5 1 52 102 103 104 105 56 +aten::floor op_46 1 1 56 60 +aten::Int op_47 1 1 60 62 +prim::ListConstruct op_48 2 1 32 62 size +prim::Constant op_49 0 1 align_corners value=%align_corners +prim::Constant op_50 0 1 scale_h_none value=None +prim::Constant op_51 0 1 scale_w_none value=None +aten::upsample_bilinear2d op_52 5 1 input size align_corners scale_h_none scale_w_none out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = Parameter{captured_params.at("scale_h").f, captured_params.at("scale_w").f}; + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "bilinear"; + op->params["recompute_scale_factor"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_3, 10) + +class F_interpolate_3_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +54 53 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 scale_w value=%scale_w +prim::Constant op_1 0 1 82 value=* +prim::Constant op_2 0 1 scale_h value=%scale_h +prim::Constant op_3 0 1 12 value=None +prim::Constant op_4 0 1 10 value=False +prim::Constant op_5 0 1 5 value=2 +prim::Constant op_6 0 1 9 value=6 +prim::Constant op_7 0 1 34 value=3 +aten::size op_8 2 1 input 5 6 +prim::NumToTensor op_9 1 1 6 7 +prim::Constant op_10 0 1 84 value=False +aten::to op_11 5 1 7 9 10 84 12 13 +prim::Constant op_12 0 1 85 value=6 +prim::Constant op_13 0 1 86 value=False +prim::Constant op_14 0 1 87 value=False +prim::Constant op_15 0 1 88 value=None +aten::to op_16 6 1 scale_h 82 85 86 87 88 20 +aten::detach op_17 1 1 20 23 +aten::mul op_18 2 1 13 23 24 +prim::Constant op_19 0 1 89 value=6 +prim::Constant op_20 0 1 90 value=False +prim::Constant op_21 0 1 91 value=False +prim::Constant op_22 0 1 92 value=None +aten::to op_23 5 1 24 89 90 91 92 28 +aten::floor op_24 1 1 28 30 +aten::Int op_25 1 1 30 32 +aten::size op_26 2 1 input 34 35 +prim::NumToTensor op_27 1 1 35 36 +prim::Constant op_28 0 1 93 value=6 +prim::Constant op_29 0 1 94 value=False +prim::Constant op_30 0 1 95 value=False +prim::Constant op_31 0 1 96 value=None +aten::to op_32 5 1 36 93 94 95 96 41 +prim::Constant op_33 0 1 97 value=* +prim::Constant op_34 0 1 98 value=6 +prim::Constant op_35 0 1 99 value=False +prim::Constant op_36 0 1 100 value=False +prim::Constant op_37 0 1 101 value=None +aten::to op_38 6 1 scale_w 97 98 99 100 101 48 +aten::detach op_39 1 1 48 51 +aten::mul op_40 2 1 41 51 52 +prim::Constant op_41 0 1 102 value=6 +prim::Constant op_42 0 1 103 value=False +prim::Constant op_43 0 1 104 value=False +prim::Constant op_44 0 1 105 value=None +aten::to op_45 5 1 52 102 103 104 105 56 +aten::floor op_46 1 1 56 60 +aten::Int op_47 1 1 60 62 +prim::ListConstruct op_48 2 1 32 62 size +prim::Constant op_49 0 1 align_corners value=%align_corners +prim::Constant op_50 0 1 scale_factor value=None +aten::upsample_bilinear2d op_51 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = Parameter{captured_params.at("scale_h").f, captured_params.at("scale_w").f}; + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "bilinear"; + op->params["recompute_scale_factor"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_3_1, 10) + +class F_interpolate_4 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +55 54 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 scale_w value=%scale_w +prim::Constant op_1 0 1 82 value=* +prim::Constant op_2 0 1 scale_h value=%scale_h +prim::Constant op_3 0 1 12 value=None +prim::Constant op_4 0 1 10 value=False +prim::Constant op_5 0 1 5 value=2 +prim::Constant op_6 0 1 9 value=6 +prim::Constant op_7 0 1 34 value=3 +aten::size op_8 2 1 input 5 6 +prim::NumToTensor op_9 1 1 6 7 +prim::Constant op_10 0 1 84 value=False +aten::to op_11 5 1 7 9 10 84 12 13 +prim::Constant op_12 0 1 85 value=6 +prim::Constant op_13 0 1 86 value=False +prim::Constant op_14 0 1 87 value=False +prim::Constant op_15 0 1 88 value=None +aten::to op_16 6 1 scale_h 82 85 86 87 88 20 +aten::detach op_17 1 1 20 23 +aten::mul op_18 2 1 13 23 24 +prim::Constant op_19 0 1 89 value=6 +prim::Constant op_20 0 1 90 value=False +prim::Constant op_21 0 1 91 value=False +prim::Constant op_22 0 1 92 value=None +aten::to op_23 5 1 24 89 90 91 92 28 +aten::floor op_24 1 1 28 30 +aten::Int op_25 1 1 30 32 +aten::size op_26 2 1 input 34 35 +prim::NumToTensor op_27 1 1 35 36 +prim::Constant op_28 0 1 93 value=6 +prim::Constant op_29 0 1 94 value=False +prim::Constant op_30 0 1 95 value=False +prim::Constant op_31 0 1 96 value=None +aten::to op_32 5 1 36 93 94 95 96 41 +prim::Constant op_33 0 1 97 value=* +prim::Constant op_34 0 1 98 value=6 +prim::Constant op_35 0 1 99 value=False +prim::Constant op_36 0 1 100 value=False +prim::Constant op_37 0 1 101 value=None +aten::to op_38 6 1 scale_w 97 98 99 100 101 48 +aten::detach op_39 1 1 48 51 +aten::mul op_40 2 1 41 51 52 +prim::Constant op_41 0 1 102 value=6 +prim::Constant op_42 0 1 103 value=False +prim::Constant op_43 0 1 104 value=False +prim::Constant op_44 0 1 105 value=None +aten::to op_45 5 1 52 102 103 104 105 56 +aten::floor op_46 1 1 56 60 +aten::Int op_47 1 1 60 62 +prim::ListConstruct op_48 2 1 32 62 size +prim::Constant op_49 0 1 align_corners value=%align_corners +prim::Constant op_50 0 1 scale_h_none value=None +prim::Constant op_51 0 1 scale_w_none value=None +aten::upsample_bicubic2d op_52 5 1 input size align_corners scale_h_none scale_w_none out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = Parameter{captured_params.at("scale_h").f, captured_params.at("scale_w").f}; + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "bicubic"; + op->params["recompute_scale_factor"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_4, 10) + +class F_interpolate_4_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +54 53 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 scale_w value=%scale_w +prim::Constant op_1 0 1 82 value=* +prim::Constant op_2 0 1 scale_h value=%scale_h +prim::Constant op_3 0 1 12 value=None +prim::Constant op_4 0 1 10 value=False +prim::Constant op_5 0 1 5 value=2 +prim::Constant op_6 0 1 9 value=6 +prim::Constant op_7 0 1 34 value=3 +aten::size op_8 2 1 input 5 6 +prim::NumToTensor op_9 1 1 6 7 +prim::Constant op_10 0 1 84 value=False +aten::to op_11 5 1 7 9 10 84 12 13 +prim::Constant op_12 0 1 85 value=6 +prim::Constant op_13 0 1 86 value=False +prim::Constant op_14 0 1 87 value=False +prim::Constant op_15 0 1 88 value=None +aten::to op_16 6 1 scale_h 82 85 86 87 88 20 +aten::detach op_17 1 1 20 23 +aten::mul op_18 2 1 13 23 24 +prim::Constant op_19 0 1 89 value=6 +prim::Constant op_20 0 1 90 value=False +prim::Constant op_21 0 1 91 value=False +prim::Constant op_22 0 1 92 value=None +aten::to op_23 5 1 24 89 90 91 92 28 +aten::floor op_24 1 1 28 30 +aten::Int op_25 1 1 30 32 +aten::size op_26 2 1 input 34 35 +prim::NumToTensor op_27 1 1 35 36 +prim::Constant op_28 0 1 93 value=6 +prim::Constant op_29 0 1 94 value=False +prim::Constant op_30 0 1 95 value=False +prim::Constant op_31 0 1 96 value=None +aten::to op_32 5 1 36 93 94 95 96 41 +prim::Constant op_33 0 1 97 value=* +prim::Constant op_34 0 1 98 value=6 +prim::Constant op_35 0 1 99 value=False +prim::Constant op_36 0 1 100 value=False +prim::Constant op_37 0 1 101 value=None +aten::to op_38 6 1 scale_w 97 98 99 100 101 48 +aten::detach op_39 1 1 48 51 +aten::mul op_40 2 1 41 51 52 +prim::Constant op_41 0 1 102 value=6 +prim::Constant op_42 0 1 103 value=False +prim::Constant op_43 0 1 104 value=False +prim::Constant op_44 0 1 105 value=None +aten::to op_45 5 1 52 102 103 104 105 56 +aten::floor op_46 1 1 56 60 +aten::Int op_47 1 1 60 62 +prim::ListConstruct op_48 2 1 32 62 size +prim::Constant op_49 0 1 align_corners value=%align_corners +prim::Constant op_50 0 1 scale_factor value=None +aten::upsample_bicubic2d op_51 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = Parameter{captured_params.at("scale_h").f, captured_params.at("scale_w").f}; + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "bicubic"; + op->params["recompute_scale_factor"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_4_1, 10) + +class F_interpolate_5 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +79 78 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 scale_w value=%scale_w +prim::Constant op_1 0 1 scale_h value=%scale_h +prim::Constant op_2 0 1 108 value=* +prim::Constant op_3 0 1 scale_d value=%scale_d +prim::Constant op_4 0 1 12 value=None +prim::Constant op_5 0 1 10 value=False +prim::Constant op_6 0 1 5 value=2 +prim::Constant op_7 0 1 9 value=6 +prim::Constant op_8 0 1 34 value=3 +prim::Constant op_9 0 1 62 value=4 +aten::size op_10 2 1 input 5 6 +prim::NumToTensor op_11 1 1 6 7 +prim::Constant op_12 0 1 111 value=False +aten::to op_13 5 1 7 9 10 111 12 13 +prim::Constant op_14 0 1 112 value=6 +prim::Constant op_15 0 1 113 value=False +prim::Constant op_16 0 1 114 value=False +prim::Constant op_17 0 1 115 value=None +aten::to op_18 6 1 scale_d 108 112 113 114 115 20 +aten::detach op_19 1 1 20 23 +aten::mul op_20 2 1 13 23 24 +prim::Constant op_21 0 1 116 value=6 +prim::Constant op_22 0 1 117 value=False +prim::Constant op_23 0 1 118 value=False +prim::Constant op_24 0 1 119 value=None +aten::to op_25 5 1 24 116 117 118 119 28 +aten::floor op_26 1 1 28 30 +aten::Int op_27 1 1 30 32 +aten::size op_28 2 1 input 34 35 +prim::NumToTensor op_29 1 1 35 36 +prim::Constant op_30 0 1 120 value=6 +prim::Constant op_31 0 1 121 value=False +prim::Constant op_32 0 1 122 value=False +prim::Constant op_33 0 1 123 value=None +aten::to op_34 5 1 36 120 121 122 123 41 +prim::Constant op_35 0 1 124 value=* +prim::Constant op_36 0 1 125 value=6 +prim::Constant op_37 0 1 126 value=False +prim::Constant op_38 0 1 127 value=False +prim::Constant op_39 0 1 128 value=None +aten::to op_40 6 1 scale_h 124 125 126 127 128 48 +aten::detach op_41 1 1 48 51 +aten::mul op_42 2 1 41 51 52 +prim::Constant op_43 0 1 129 value=6 +prim::Constant op_44 0 1 130 value=False +prim::Constant op_45 0 1 131 value=False +prim::Constant op_46 0 1 132 value=None +aten::to op_47 5 1 52 129 130 131 132 56 +aten::floor op_48 1 1 56 58 +aten::Int op_49 1 1 58 60 +aten::size op_50 2 1 input 62 63 +prim::NumToTensor op_51 1 1 63 64 +prim::Constant op_52 0 1 133 value=6 +prim::Constant op_53 0 1 134 value=False +prim::Constant op_54 0 1 135 value=False +prim::Constant op_55 0 1 136 value=None +aten::to op_56 5 1 64 133 134 135 136 69 +prim::Constant op_57 0 1 137 value=* +prim::Constant op_58 0 1 138 value=6 +prim::Constant op_59 0 1 139 value=False +prim::Constant op_60 0 1 140 value=False +prim::Constant op_61 0 1 141 value=None +aten::to op_62 6 1 scale_w 137 138 139 140 141 76 +aten::detach op_63 1 1 76 79 +aten::mul op_64 2 1 69 79 80 +prim::Constant op_65 0 1 142 value=6 +prim::Constant op_66 0 1 143 value=False +prim::Constant op_67 0 1 144 value=False +prim::Constant op_68 0 1 145 value=None +aten::to op_69 5 1 80 142 143 144 145 84 +aten::floor op_70 1 1 84 89 +aten::Int op_71 1 1 89 91 +prim::ListConstruct op_72 3 1 32 60 91 size +prim::Constant op_73 0 1 scale_d_none value=None +prim::Constant op_74 0 1 scale_h_none value=None +prim::Constant op_75 0 1 scale_w_none value=None +aten::upsample_nearest3d op_76 5 1 input size scale_d_none scale_h_none scale_w_none out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = Parameter{captured_params.at("scale_d").f, captured_params.at("scale_h").f, captured_params.at("scale_w").f}; + op->params["mode"] = "nearest"; + op->params["recompute_scale_factor"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_5, 10) + +class F_interpolate_5_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +77 76 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 scale_w value=%scale_w +prim::Constant op_1 0 1 scale_h value=%scale_h +prim::Constant op_2 0 1 108 value=* +prim::Constant op_3 0 1 scale_d value=%scale_d +prim::Constant op_4 0 1 12 value=None +prim::Constant op_5 0 1 10 value=False +prim::Constant op_6 0 1 5 value=2 +prim::Constant op_7 0 1 9 value=6 +prim::Constant op_8 0 1 34 value=3 +prim::Constant op_9 0 1 62 value=4 +aten::size op_10 2 1 input 5 6 +prim::NumToTensor op_11 1 1 6 7 +prim::Constant op_12 0 1 111 value=False +aten::to op_13 5 1 7 9 10 111 12 13 +prim::Constant op_14 0 1 112 value=6 +prim::Constant op_15 0 1 113 value=False +prim::Constant op_16 0 1 114 value=False +prim::Constant op_17 0 1 115 value=None +aten::to op_18 6 1 scale_d 108 112 113 114 115 20 +aten::detach op_19 1 1 20 23 +aten::mul op_20 2 1 13 23 24 +prim::Constant op_21 0 1 116 value=6 +prim::Constant op_22 0 1 117 value=False +prim::Constant op_23 0 1 118 value=False +prim::Constant op_24 0 1 119 value=None +aten::to op_25 5 1 24 116 117 118 119 28 +aten::floor op_26 1 1 28 30 +aten::Int op_27 1 1 30 32 +aten::size op_28 2 1 input 34 35 +prim::NumToTensor op_29 1 1 35 36 +prim::Constant op_30 0 1 120 value=6 +prim::Constant op_31 0 1 121 value=False +prim::Constant op_32 0 1 122 value=False +prim::Constant op_33 0 1 123 value=None +aten::to op_34 5 1 36 120 121 122 123 41 +prim::Constant op_35 0 1 124 value=* +prim::Constant op_36 0 1 125 value=6 +prim::Constant op_37 0 1 126 value=False +prim::Constant op_38 0 1 127 value=False +prim::Constant op_39 0 1 128 value=None +aten::to op_40 6 1 scale_h 124 125 126 127 128 48 +aten::detach op_41 1 1 48 51 +aten::mul op_42 2 1 41 51 52 +prim::Constant op_43 0 1 129 value=6 +prim::Constant op_44 0 1 130 value=False +prim::Constant op_45 0 1 131 value=False +prim::Constant op_46 0 1 132 value=None +aten::to op_47 5 1 52 129 130 131 132 56 +aten::floor op_48 1 1 56 58 +aten::Int op_49 1 1 58 60 +aten::size op_50 2 1 input 62 63 +prim::NumToTensor op_51 1 1 63 64 +prim::Constant op_52 0 1 133 value=6 +prim::Constant op_53 0 1 134 value=False +prim::Constant op_54 0 1 135 value=False +prim::Constant op_55 0 1 136 value=None +aten::to op_56 5 1 64 133 134 135 136 69 +prim::Constant op_57 0 1 137 value=* +prim::Constant op_58 0 1 138 value=6 +prim::Constant op_59 0 1 139 value=False +prim::Constant op_60 0 1 140 value=False +prim::Constant op_61 0 1 141 value=None +aten::to op_62 6 1 scale_w 137 138 139 140 141 76 +aten::detach op_63 1 1 76 79 +aten::mul op_64 2 1 69 79 80 +prim::Constant op_65 0 1 142 value=6 +prim::Constant op_66 0 1 143 value=False +prim::Constant op_67 0 1 144 value=False +prim::Constant op_68 0 1 145 value=None +aten::to op_69 5 1 80 142 143 144 145 84 +aten::floor op_70 1 1 84 89 +aten::Int op_71 1 1 89 91 +prim::ListConstruct op_72 3 1 32 60 91 size +prim::Constant op_73 0 1 scale_factor value=None +aten::upsample_nearest3d op_74 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = Parameter{captured_params.at("scale_d").f, captured_params.at("scale_h").f, captured_params.at("scale_w").f}; + op->params["mode"] = "nearest"; + op->params["recompute_scale_factor"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_5_1, 10) + +class F_interpolate_6 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +80 79 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 scale_w value=%scale_w +prim::Constant op_1 0 1 scale_h value=%scale_h +prim::Constant op_2 0 1 113 value=* +prim::Constant op_3 0 1 scale_d value=%scale_d +prim::Constant op_4 0 1 12 value=None +prim::Constant op_5 0 1 10 value=False +prim::Constant op_6 0 1 5 value=2 +prim::Constant op_7 0 1 9 value=6 +prim::Constant op_8 0 1 34 value=3 +prim::Constant op_9 0 1 62 value=4 +aten::size op_10 2 1 input 5 6 +prim::NumToTensor op_11 1 1 6 7 +prim::Constant op_12 0 1 116 value=False +aten::to op_13 5 1 7 9 10 116 12 13 +prim::Constant op_14 0 1 117 value=6 +prim::Constant op_15 0 1 118 value=False +prim::Constant op_16 0 1 119 value=False +prim::Constant op_17 0 1 120 value=None +aten::to op_18 6 1 scale_d 113 117 118 119 120 20 +aten::detach op_19 1 1 20 23 +aten::mul op_20 2 1 13 23 24 +prim::Constant op_21 0 1 121 value=6 +prim::Constant op_22 0 1 122 value=False +prim::Constant op_23 0 1 123 value=False +prim::Constant op_24 0 1 124 value=None +aten::to op_25 5 1 24 121 122 123 124 28 +aten::floor op_26 1 1 28 30 +aten::Int op_27 1 1 30 32 +aten::size op_28 2 1 input 34 35 +prim::NumToTensor op_29 1 1 35 36 +prim::Constant op_30 0 1 125 value=6 +prim::Constant op_31 0 1 126 value=False +prim::Constant op_32 0 1 127 value=False +prim::Constant op_33 0 1 128 value=None +aten::to op_34 5 1 36 125 126 127 128 41 +prim::Constant op_35 0 1 129 value=* +prim::Constant op_36 0 1 130 value=6 +prim::Constant op_37 0 1 131 value=False +prim::Constant op_38 0 1 132 value=False +prim::Constant op_39 0 1 133 value=None +aten::to op_40 6 1 scale_h 129 130 131 132 133 48 +aten::detach op_41 1 1 48 51 +aten::mul op_42 2 1 41 51 52 +prim::Constant op_43 0 1 134 value=6 +prim::Constant op_44 0 1 135 value=False +prim::Constant op_45 0 1 136 value=False +prim::Constant op_46 0 1 137 value=None +aten::to op_47 5 1 52 134 135 136 137 56 +aten::floor op_48 1 1 56 58 +aten::Int op_49 1 1 58 60 +aten::size op_50 2 1 input 62 63 +prim::NumToTensor op_51 1 1 63 64 +prim::Constant op_52 0 1 138 value=6 +prim::Constant op_53 0 1 139 value=False +prim::Constant op_54 0 1 140 value=False +prim::Constant op_55 0 1 141 value=None +aten::to op_56 5 1 64 138 139 140 141 69 +prim::Constant op_57 0 1 142 value=* +prim::Constant op_58 0 1 143 value=6 +prim::Constant op_59 0 1 144 value=False +prim::Constant op_60 0 1 145 value=False +prim::Constant op_61 0 1 146 value=None +aten::to op_62 6 1 scale_w 142 143 144 145 146 76 +aten::detach op_63 1 1 76 79 +aten::mul op_64 2 1 69 79 80 +prim::Constant op_65 0 1 147 value=6 +prim::Constant op_66 0 1 148 value=False +prim::Constant op_67 0 1 149 value=False +prim::Constant op_68 0 1 150 value=None +aten::to op_69 5 1 80 147 148 149 150 84 +aten::floor op_70 1 1 84 89 +aten::Int op_71 1 1 89 91 +prim::ListConstruct op_72 3 1 32 60 91 size +prim::Constant op_73 0 1 align_corners value=%align_corners +prim::Constant op_74 0 1 scale_d_none value=None +prim::Constant op_75 0 1 scale_h_none value=None +prim::Constant op_76 0 1 scale_w_none value=None +aten::upsample_trilinear3d op_77 6 1 input size align_corners scale_d_none scale_h_none scale_w_none out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = Parameter{captured_params.at("scale_d").f, captured_params.at("scale_h").f, captured_params.at("scale_w").f}; + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "trilinear"; + op->params["recompute_scale_factor"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_6, 10) + +class F_interpolate_6_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +78 77 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 scale_w value=%scale_w +prim::Constant op_1 0 1 scale_h value=%scale_h +prim::Constant op_2 0 1 113 value=* +prim::Constant op_3 0 1 scale_d value=%scale_d +prim::Constant op_4 0 1 12 value=None +prim::Constant op_5 0 1 10 value=False +prim::Constant op_6 0 1 5 value=2 +prim::Constant op_7 0 1 9 value=6 +prim::Constant op_8 0 1 34 value=3 +prim::Constant op_9 0 1 62 value=4 +aten::size op_10 2 1 input 5 6 +prim::NumToTensor op_11 1 1 6 7 +prim::Constant op_12 0 1 116 value=False +aten::to op_13 5 1 7 9 10 116 12 13 +prim::Constant op_14 0 1 117 value=6 +prim::Constant op_15 0 1 118 value=False +prim::Constant op_16 0 1 119 value=False +prim::Constant op_17 0 1 120 value=None +aten::to op_18 6 1 scale_d 113 117 118 119 120 20 +aten::detach op_19 1 1 20 23 +aten::mul op_20 2 1 13 23 24 +prim::Constant op_21 0 1 121 value=6 +prim::Constant op_22 0 1 122 value=False +prim::Constant op_23 0 1 123 value=False +prim::Constant op_24 0 1 124 value=None +aten::to op_25 5 1 24 121 122 123 124 28 +aten::floor op_26 1 1 28 30 +aten::Int op_27 1 1 30 32 +aten::size op_28 2 1 input 34 35 +prim::NumToTensor op_29 1 1 35 36 +prim::Constant op_30 0 1 125 value=6 +prim::Constant op_31 0 1 126 value=False +prim::Constant op_32 0 1 127 value=False +prim::Constant op_33 0 1 128 value=None +aten::to op_34 5 1 36 125 126 127 128 41 +prim::Constant op_35 0 1 129 value=* +prim::Constant op_36 0 1 130 value=6 +prim::Constant op_37 0 1 131 value=False +prim::Constant op_38 0 1 132 value=False +prim::Constant op_39 0 1 133 value=None +aten::to op_40 6 1 scale_h 129 130 131 132 133 48 +aten::detach op_41 1 1 48 51 +aten::mul op_42 2 1 41 51 52 +prim::Constant op_43 0 1 134 value=6 +prim::Constant op_44 0 1 135 value=False +prim::Constant op_45 0 1 136 value=False +prim::Constant op_46 0 1 137 value=None +aten::to op_47 5 1 52 134 135 136 137 56 +aten::floor op_48 1 1 56 58 +aten::Int op_49 1 1 58 60 +aten::size op_50 2 1 input 62 63 +prim::NumToTensor op_51 1 1 63 64 +prim::Constant op_52 0 1 138 value=6 +prim::Constant op_53 0 1 139 value=False +prim::Constant op_54 0 1 140 value=False +prim::Constant op_55 0 1 141 value=None +aten::to op_56 5 1 64 138 139 140 141 69 +prim::Constant op_57 0 1 142 value=* +prim::Constant op_58 0 1 143 value=6 +prim::Constant op_59 0 1 144 value=False +prim::Constant op_60 0 1 145 value=False +prim::Constant op_61 0 1 146 value=None +aten::to op_62 6 1 scale_w 142 143 144 145 146 76 +aten::detach op_63 1 1 76 79 +aten::mul op_64 2 1 69 79 80 +prim::Constant op_65 0 1 147 value=6 +prim::Constant op_66 0 1 148 value=False +prim::Constant op_67 0 1 149 value=False +prim::Constant op_68 0 1 150 value=None +aten::to op_69 5 1 80 147 148 149 150 84 +aten::floor op_70 1 1 84 89 +aten::Int op_71 1 1 89 91 +prim::ListConstruct op_72 3 1 32 60 91 size +prim::Constant op_73 0 1 align_corners value=%align_corners +prim::Constant op_74 0 1 scale_factor value=None +aten::upsample_trilinear3d op_75 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["scale_factor"] = Parameter{captured_params.at("scale_d").f, captured_params.at("scale_h").f, captured_params.at("scale_w").f}; + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "trilinear"; + op->params["recompute_scale_factor"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_interpolate_6_1, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_layer_norm.cpp b/tools/pnnx/src/pass_level2/F_layer_norm.cpp new file mode 100644 index 000000000000..ff914e2ca3f3 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_layer_norm.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_layer_norm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +pnnx.Input input_3 0 1 normalized_shape +prim::Constant op_0 0 1 eps value=%eps +prim::Constant op_1 0 1 cudnn_enabled value=* +aten::layer_norm op_2 6 1 input normalized_shape weight bias eps cudnn_enabled out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.layer_norm"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_layer_norm, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_leaky_relu.cpp b/tools/pnnx/src/pass_level2/F_leaky_relu.cpp new file mode 100644 index 000000000000..8fefe33c7004 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_leaky_relu.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_leaky_relu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 negative_slope +aten::leaky_relu op_0 2 1 input negative_slope out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.leaky_relu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_leaky_relu, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_linear.cpp b/tools/pnnx/src/pass_level2/F_linear.cpp new file mode 100644 index 000000000000..a5931bb3733d --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_linear.cpp @@ -0,0 +1,122 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_linear : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +aten::linear op_0 3 1 input weight bias out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.linear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_linear, 10) + +class F_linear_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +aten::t op_0 1 1 weight 9 +aten::matmul op_1 2 1 input 9 a +prim::Constant op_2 0 1 19 value=1 +aten::add_ op_3 3 1 a bias 19 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.linear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_linear_1, 9) + +class F_linear_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +aten::t op_0 1 1 weight 9 +aten::matmul op_1 2 1 input 9 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.linear"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["bias"] = Parameter(); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_linear_2, 10) + +class F_linear_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +aten::t op_0 1 1 weight 14 +prim::Constant op_1 0 1 15 value=1 +prim::Constant op_2 0 1 30 value=1 +aten::addmm op_3 5 1 bias input 14 15 30 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.linear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_linear_3, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_local_response_norm.cpp b/tools/pnnx/src/pass_level2/F_local_response_norm.cpp new file mode 100644 index 000000000000..55af608e62c3 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_local_response_norm.cpp @@ -0,0 +1,299 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_local_response_norm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +35 34 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 k value=%k +prim::Constant op_1 0 1 alpha value=%alpha +prim::Constant op_2 0 1 24 value=None +prim::Constant op_3 0 1 23 value=True +prim::Constant op_4 0 1 22 value=False +prim::Constant op_5 0 1 7 value=1 +prim::Constant op_6 0 1 10 value=0 +prim::Constant op_7 0 1 size value=%size +prim::Constant op_8 0 1 beta value=%beta +aten::mul op_9 2 1 input input 6 +aten::unsqueeze op_10 2 1 6 7 input.1 +prim::Constant op_11 0 1 52 value=0 +prim::Constant op_12 0 1 53 value=* +prim::Constant op_13 0 1 54 value=* +prim::ListConstruct op_14 4 1 10 52 53 54 11 +prim::Constant op_15 0 1 55 value=%padzero +aten::constant_pad_nd op_16 3 1 input.1 11 55 div.1 +prim::Constant op_17 0 1 56 value=1 +prim::ListConstruct op_18 2 1 size 56 16 +prim::Constant op_19 0 1 57 value=1 +prim::Constant op_20 0 1 58 value=1 +prim::ListConstruct op_21 2 1 57 58 17 +prim::Constant op_22 0 1 59 value=0 +prim::Constant op_23 0 1 60 value=0 +prim::ListConstruct op_24 2 1 59 60 18 +aten::avg_pool2d op_25 7 1 div.1 16 17 18 22 23 24 25 +prim::Constant op_26 0 1 61 value=1 +aten::squeeze op_27 2 1 25 61 div0.1 +aten::mul op_28 2 1 div0.1 alpha 30 +prim::Constant op_29 0 1 62 value=1 +aten::add op_30 3 1 30 k 62 33 +aten::pow op_31 2 1 33 beta div1.1 +aten::div op_32 2 1 input div1.1 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.local_response_norm"; + } + + bool match_captured_params(const std::map& captured_params) const + { + if (captured_params.at("padzero").type == 2) + return captured_params.at("padzero").i == 0; + + if (captured_params.at("padzero").type == 3) + return captured_params.at("padzero").f == 0.f; + + return false; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["size"] = captured_params.at("size"); + op->params["alpha"] = captured_params.at("alpha"); + op->params["beta"] = captured_params.at("beta"); + op->params["k"] = captured_params.at("k"); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_local_response_norm, 8) + +class F_local_response_norm_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +65 64 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 k value=%k +prim::Constant op_1 0 1 alpha value=%alpha +prim::Constant op_2 0 1 66 value=None +prim::Constant op_3 0 1 65 value=True +prim::Constant op_4 0 1 64 value=False +prim::Constant op_5 0 1 7 value=1 +prim::Constant op_6 0 1 10 value=0 +prim::Constant op_7 0 1 29 value=2 +prim::Constant op_8 0 1 39 value=3 +prim::Constant op_9 0 1 49 value=-1 +prim::Constant op_10 0 1 size value=%size +prim::Constant op_11 0 1 beta value=%beta +aten::mul op_12 2 1 input input 6 +aten::unsqueeze op_13 2 1 6 7 div.1 +aten::size op_14 2 1 input 10 11 +prim::NumToTensor op_15 1 1 11 12 +aten::Int op_16 1 1 12 15 +aten::Int op_17 1 1 12 18 +prim::Constant op_18 0 1 101 value=1 +aten::size op_19 2 1 input 101 20 +prim::NumToTensor op_20 1 1 20 21 +aten::Int op_21 1 1 21 24 +aten::Int op_22 1 1 21 27 +aten::size op_23 2 1 input 29 30 +prim::NumToTensor op_24 1 1 30 31 +aten::Int op_25 1 1 31 34 +aten::Int op_26 1 1 31 37 +aten::size op_27 2 1 input 39 40 +prim::NumToTensor op_28 1 1 40 41 +aten::Int op_29 1 1 41 44 +prim::Constant op_30 0 1 102 value=1 +prim::ListConstruct op_31 5 1 18 102 27 37 49 50 +aten::view op_32 2 1 div.1 50 input.1 +prim::Constant op_33 0 1 103 value=0 +prim::Constant op_34 0 1 104 value=0 +prim::Constant op_35 0 1 105 value=0 +prim::Constant op_36 0 1 106 value=0 +prim::Constant op_37 0 1 107 value=* +prim::Constant op_38 0 1 108 value=* +prim::ListConstruct op_39 6 1 103 104 105 106 107 108 53 +prim::Constant op_40 0 1 109 value=%padzero +aten::constant_pad_nd op_41 3 1 input.1 53 109 div0.1 +prim::Constant op_42 0 1 110 value=1 +prim::Constant op_43 0 1 111 value=1 +prim::ListConstruct op_44 3 1 size 110 111 58 +prim::Constant op_45 0 1 112 value=1 +prim::Constant op_46 0 1 113 value=1 +prim::Constant op_47 0 1 114 value=1 +prim::ListConstruct op_48 3 1 112 113 114 59 +prim::Constant op_49 0 1 115 value=0 +prim::Constant op_50 0 1 116 value=0 +prim::Constant op_51 0 1 117 value=0 +prim::ListConstruct op_52 3 1 115 116 117 60 +aten::avg_pool3d op_53 7 1 div0.1 58 59 60 64 65 66 67 +prim::Constant op_54 0 1 118 value=1 +aten::squeeze op_55 2 1 67 118 div1.1 +prim::ListConstruct op_56 4 1 15 24 34 44 75 +aten::view op_57 2 1 div1.1 75 div2.1 +aten::mul op_58 2 1 div2.1 alpha 79 +prim::Constant op_59 0 1 119 value=1 +aten::add op_60 3 1 79 k 119 82 +aten::pow op_61 2 1 82 beta div3.1 +aten::div op_62 2 1 input div3.1 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.local_response_norm"; + } + + bool match_captured_params(const std::map& captured_params) const + { + if (captured_params.at("padzero").type == 2) + return captured_params.at("padzero").i == 0; + + if (captured_params.at("padzero").type == 3) + return captured_params.at("padzero").f == 0.f; + + return false; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["size"] = captured_params.at("size"); + op->params["alpha"] = captured_params.at("alpha"); + op->params["beta"] = captured_params.at("beta"); + op->params["k"] = captured_params.at("k"); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_local_response_norm_1, 8) + +class F_local_response_norm_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +69 68 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 k value=%k +prim::Constant op_1 0 1 alpha value=%alpha +prim::Constant op_2 0 1 73 value=None +prim::Constant op_3 0 1 72 value=True +prim::Constant op_4 0 1 71 value=False +prim::Constant op_5 0 1 7 value=1 +prim::Constant op_6 0 1 10 value=0 +prim::Constant op_7 0 1 29 value=2 +prim::Constant op_8 0 1 39 value=3 +prim::Constant op_9 0 1 46 value=4 +prim::Constant op_10 0 1 56 value=-1 +prim::Constant op_11 0 1 size value=%size +prim::Constant op_12 0 1 beta value=%beta +aten::mul op_13 2 1 input input 6 +aten::unsqueeze op_14 2 1 6 7 div.1 +aten::size op_15 2 1 input 10 11 +prim::NumToTensor op_16 1 1 11 12 +aten::Int op_17 1 1 12 15 +aten::Int op_18 1 1 12 18 +prim::Constant op_19 0 1 109 value=1 +aten::size op_20 2 1 input 109 20 +prim::NumToTensor op_21 1 1 20 21 +aten::Int op_22 1 1 21 24 +aten::Int op_23 1 1 21 27 +aten::size op_24 2 1 input 29 30 +prim::NumToTensor op_25 1 1 30 31 +aten::Int op_26 1 1 31 34 +aten::Int op_27 1 1 31 37 +aten::size op_28 2 1 input 39 40 +prim::NumToTensor op_29 1 1 40 41 +aten::Int op_30 1 1 41 44 +aten::size op_31 2 1 input 46 47 +prim::NumToTensor op_32 1 1 47 48 +aten::Int op_33 1 1 48 51 +prim::Constant op_34 0 1 110 value=1 +prim::ListConstruct op_35 5 1 18 110 27 37 56 57 +aten::view op_36 2 1 div.1 57 input.1 +prim::Constant op_37 0 1 111 value=0 +prim::Constant op_38 0 1 112 value=0 +prim::Constant op_39 0 1 113 value=0 +prim::Constant op_40 0 1 114 value=0 +prim::Constant op_41 0 1 115 value=* +prim::Constant op_42 0 1 116 value=* +prim::ListConstruct op_43 6 1 111 112 113 114 115 116 60 +prim::Constant op_44 0 1 117 value=%padzero +aten::constant_pad_nd op_45 3 1 input.1 60 117 div0.1 +prim::Constant op_46 0 1 118 value=1 +prim::Constant op_47 0 1 119 value=1 +prim::ListConstruct op_48 3 1 size 118 119 65 +prim::Constant op_49 0 1 120 value=1 +prim::Constant op_50 0 1 121 value=1 +prim::Constant op_51 0 1 122 value=1 +prim::ListConstruct op_52 3 1 120 121 122 66 +prim::Constant op_53 0 1 123 value=0 +prim::Constant op_54 0 1 124 value=0 +prim::Constant op_55 0 1 125 value=0 +prim::ListConstruct op_56 3 1 123 124 125 67 +aten::avg_pool3d op_57 7 1 div0.1 65 66 67 71 72 73 74 +prim::Constant op_58 0 1 126 value=1 +aten::squeeze op_59 2 1 74 126 div1.1 +prim::ListConstruct op_60 5 1 15 24 34 44 51 83 +aten::view op_61 2 1 div1.1 83 div2.1 +aten::mul op_62 2 1 div2.1 alpha 87 +prim::Constant op_63 0 1 127 value=1 +aten::add op_64 3 1 87 k 127 90 +aten::pow op_65 2 1 90 beta div3.1 +aten::div op_66 2 1 input div3.1 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.local_response_norm"; + } + + bool match_captured_params(const std::map& captured_params) const + { + if (captured_params.at("padzero").type == 2) + return captured_params.at("padzero").i == 0; + + if (captured_params.at("padzero").type == 3) + return captured_params.at("padzero").f == 0.f; + + return false; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["size"] = captured_params.at("size"); + op->params["alpha"] = captured_params.at("alpha"); + op->params["beta"] = captured_params.at("beta"); + op->params["k"] = captured_params.at("k"); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_local_response_norm_2, 8) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_log_softmax.cpp b/tools/pnnx/src/pass_level2/F_log_softmax.cpp new file mode 100644 index 000000000000..dd44a1c06c89 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_log_softmax.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_log_softmax : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +prim::Constant op_0 0 1 dtype value=None +aten::log_softmax op_1 3 1 input dim dtype out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.log_softmax"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_log_softmax, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_logsigmoid.cpp b/tools/pnnx/src/pass_level2/F_logsigmoid.cpp new file mode 100644 index 000000000000..e35670686a0e --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_logsigmoid.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_logsigmoid : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::log_sigmoid op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.logsigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_logsigmoid, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_lp_pool1d.cpp b/tools/pnnx/src/pass_level2/F_lp_pool1d.cpp new file mode 100644 index 000000000000..e88b1cbff4a5 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_lp_pool1d.cpp @@ -0,0 +1,94 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_lp_pool1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +20 19 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 norm_type +prim::ListConstruct op_0 1 1 kernel_size kernel_size_tuple +aten::pow op_1 2 1 input norm_type 4 +prim::Constant op_2 0 1 padding_w value=0 +prim::ListConstruct op_3 1 1 padding_w padding +prim::Constant op_4 0 1 ceil_mode value=%ceil_mode +prim::Constant op_5 0 1 count_include_pad value=True +aten::avg_pool1d op_6 6 1 4 kernel_size_tuple stride padding ceil_mode count_include_pad out.1 +aten::sign op_7 1 1 out.1 14 +aten::abs op_8 1 1 out.1 input.1 +aten::relu op_9 1 1 input.1 19 +aten::mul op_10 2 1 14 19 20 +prim::Constant op_11 0 1 21 value=* +aten::mul op_12 2 1 20 21 22 +prim::Constant op_13 0 1 24 value=* +aten::pow op_14 2 1 22 24 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.lp_pool1d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_lp_pool1d, 7) + +class F_lp_pool1d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +19 18 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 norm_type +aten::pow op_0 2 1 input norm_type 4 +prim::Constant op_1 0 1 padding_w value=0 +prim::ListConstruct op_2 1 1 padding_w padding +prim::Constant op_3 0 1 ceil_mode value=%ceil_mode +prim::Constant op_4 0 1 count_include_pad value=True +aten::avg_pool1d op_5 6 1 4 kernel_size stride padding ceil_mode count_include_pad out.1 +aten::sign op_6 1 1 out.1 14 +aten::abs op_7 1 1 out.1 input.1 +aten::relu op_8 1 1 input.1 19 +aten::mul op_9 2 1 14 19 20 +prim::Constant op_10 0 1 21 value=* +aten::mul op_11 2 1 20 21 22 +prim::Constant op_12 0 1 24 value=* +aten::pow op_13 2 1 22 24 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.lp_pool1d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_lp_pool1d_1, 8) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_lp_pool2d.cpp b/tools/pnnx/src/pass_level2/F_lp_pool2d.cpp new file mode 100644 index 000000000000..4f2ed1f4a956 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_lp_pool2d.cpp @@ -0,0 +1,58 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_lp_pool2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +21 20 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 norm_type +aten::pow op_0 2 1 input norm_type 4 +prim::Constant op_1 0 1 padding_h value=0 +prim::Constant op_2 0 1 padding_w value=0 +prim::ListConstruct op_3 2 1 padding_h padding_w padding +prim::Constant op_4 0 1 ceil_mode value=%ceil_mode +prim::Constant op_5 0 1 count_include_pad value=True +prim::Constant op_6 0 1 divisor_override value=None +aten::avg_pool2d op_7 7 1 4 kernel_size stride padding ceil_mode count_include_pad divisor_override out.1 +aten::sign op_8 1 1 out.1 14 +aten::abs op_9 1 1 out.1 input.1 +aten::relu op_10 1 1 input.1 19 +aten::mul op_11 2 1 14 19 20 +prim::Constant op_12 0 1 21 value=* +aten::mul op_13 2 1 20 21 22 +prim::Constant op_14 0 1 24 value=* +aten::pow op_15 2 1 22 24 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.lp_pool2d"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_lp_pool2d, 8) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_max_pool1d.cpp b/tools/pnnx/src/pass_level2/F_max_pool1d.cpp new file mode 100644 index 000000000000..77d4d91feb51 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_max_pool1d.cpp @@ -0,0 +1,83 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_max_pool1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 padding +pnnx.Input input_4 0 1 dilation +prim::Constant op_0 0 1 ceil_mode value=%ceil_mode +aten::max_pool1d op_1 6 1 input kernel_size stride padding dilation ceil_mode out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.max_pool1d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["ceil_mode"] = captured_params.at("ceil_mode"); + op->params["return_indices"] = false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool1d, 10) + +class F_max_pool1d_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 8 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 padding +pnnx.Input input_4 0 1 dilation +prim::Constant op_0 0 1 ceil_mode value=%ceil_mode +aten::max_pool1d_with_indices op_1 6 2 input kernel_size stride padding dilation ceil_mode out indices +pnnx.Output output 2 0 out indices +)PNNXIR"; + } + + const char* type_str() const + { + return "F.max_pool1d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["ceil_mode"] = captured_params.at("ceil_mode"); + op->params["return_indices"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool1d_2, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_max_pool2d.cpp b/tools/pnnx/src/pass_level2/F_max_pool2d.cpp new file mode 100644 index 000000000000..4b08f577a8bf --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_max_pool2d.cpp @@ -0,0 +1,83 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_max_pool2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 padding +pnnx.Input input_4 0 1 dilation +prim::Constant op_0 0 1 ceil_mode value=%ceil_mode +aten::max_pool2d op_1 6 1 input kernel_size stride padding dilation ceil_mode out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.max_pool2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["ceil_mode"] = captured_params.at("ceil_mode"); + op->params["return_indices"] = false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool2d, 10) + +class F_max_pool2d_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 8 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 padding +pnnx.Input input_4 0 1 dilation +prim::Constant op_0 0 1 ceil_mode value=%ceil_mode +aten::max_pool2d_with_indices op_1 6 2 input kernel_size stride padding dilation ceil_mode out indices +pnnx.Output output 2 0 out indices +)PNNXIR"; + } + + const char* type_str() const + { + return "F.max_pool2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["ceil_mode"] = captured_params.at("ceil_mode"); + op->params["return_indices"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool2d_2, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_max_pool3d.cpp b/tools/pnnx/src/pass_level2/F_max_pool3d.cpp new file mode 100644 index 000000000000..afbaa615ca25 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_max_pool3d.cpp @@ -0,0 +1,83 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_max_pool3d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 padding +pnnx.Input input_4 0 1 dilation +prim::Constant op_0 0 1 ceil_mode value=%ceil_mode +aten::max_pool3d op_1 6 1 input kernel_size stride padding dilation ceil_mode out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.max_pool3d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["ceil_mode"] = captured_params.at("ceil_mode"); + op->params["return_indices"] = false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool3d, 10) + +class F_max_pool3d_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 8 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 kernel_size +pnnx.Input input_2 0 1 stride +pnnx.Input input_3 0 1 padding +pnnx.Input input_4 0 1 dilation +prim::Constant op_0 0 1 ceil_mode value=%ceil_mode +aten::max_pool3d_with_indices op_1 6 2 input kernel_size stride padding dilation ceil_mode out indices +pnnx.Output output 2 0 out indices +)PNNXIR"; + } + + const char* type_str() const + { + return "F.max_pool3d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["ceil_mode"] = captured_params.at("ceil_mode"); + op->params["return_indices"] = true; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool3d_2, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_mish.cpp b/tools/pnnx/src/pass_level2/F_mish.cpp new file mode 100644 index 000000000000..1a083ba85d9a --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_mish.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_mish : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::mish op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.mish"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_mish, 10) + +class F_mish_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 11 value=1 +prim::Constant op_1 0 1 12 value=20 +aten::softplus op_2 3 1 input 11 12 a +aten::tanh op_3 1 1 a b +aten::mul op_4 2 1 input b out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.mish"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_mish_1, 9) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_normalize.cpp b/tools/pnnx/src/pass_level2/F_normalize.cpp new file mode 100644 index 000000000000..9717282dd0f6 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_normalize.cpp @@ -0,0 +1,48 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_normalize : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +11 10 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 keepdim value=True +prim::Constant op_1 0 1 p value=%p +prim::Constant op_2 0 1 dim value=%dim +prim::Constant op_3 0 1 eps value=%eps +prim::ListConstruct op_4 1 1 dim dims +aten::norm op_5 4 1 input p dims keepdim 9 +aten::clamp_min op_6 2 1 9 eps 11 +aten::expand_as op_7 2 1 11 input denorm +aten::div op_8 2 1 input denorm out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.normalize"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_normalize, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_pad.cpp b/tools/pnnx/src/pass_level2/F_pad.cpp new file mode 100644 index 000000000000..cccae31a9159 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_pad.cpp @@ -0,0 +1,182 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_pad : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 pad +pnnx.Input input_2 0 1 value +aten::constant_pad_nd op_0 3 1 input pad value out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pad"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["mode"] = "constant"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad, 10) + +class F_pad_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 pad +aten::reflection_pad1d op_0 2 1 input pad out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pad"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["mode"] = "reflect"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_1, 10) + +class F_pad_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 pad +aten::replication_pad1d op_0 2 1 input pad out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pad"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["mode"] = "replicate"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_2, 10) + +class F_pad_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 pad +aten::reflection_pad2d op_0 2 1 input pad out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pad"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["mode"] = "reflect"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_3, 10) + +class F_pad_4 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 pad +aten::replication_pad2d op_0 2 1 input pad out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pad"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["mode"] = "replicate"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_4, 10) + +class F_pad_6 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 pad +aten::replication_pad3d op_0 2 1 input pad out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pad"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["mode"] = "replicate"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_6, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_pixel_shuffle.cpp b/tools/pnnx/src/pass_level2/F_pixel_shuffle.cpp new file mode 100644 index 000000000000..ec1493ba561d --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_pixel_shuffle.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_pixel_shuffle : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 upscale_factor +aten::pixel_shuffle op_0 2 1 input upscale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pixel_shuffle"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pixel_shuffle, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_pixel_unshuffle.cpp b/tools/pnnx/src/pass_level2/F_pixel_unshuffle.cpp new file mode 100644 index 000000000000..38682a8c9d04 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_pixel_unshuffle.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_pixel_unshuffle : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 downscale_factor +aten::pixel_unshuffle op_0 2 1 input downscale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pixel_unshuffle"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pixel_unshuffle, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_prelu.cpp b/tools/pnnx/src/pass_level2/F_prelu.cpp new file mode 100644 index 000000000000..96fe71417801 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_prelu.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_prelu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +aten::prelu op_0 2 1 input weight out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.prelu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_prelu, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_relu.cpp b/tools/pnnx/src/pass_level2/F_relu.cpp new file mode 100644 index 000000000000..b84ea31308f7 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_relu.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_relu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::relu op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.relu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_relu, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_relu6.cpp b/tools/pnnx/src/pass_level2/F_relu6.cpp new file mode 100644 index 000000000000..d2b3379998e7 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_relu6.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_relu6 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::relu6 op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.relu6"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_relu6, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_rrelu.cpp b/tools/pnnx/src/pass_level2/F_rrelu.cpp new file mode 100644 index 000000000000..21a226f57f52 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_rrelu.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_rrelu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 lower +pnnx.Input input_2 0 1 upper +prim::Constant op_0 0 1 training value=False +prim::Constant op_1 0 1 generator value=None +aten::rrelu op_2 5 1 input lower upper training generator out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.rrelu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_rrelu, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_selu.cpp b/tools/pnnx/src/pass_level2/F_selu.cpp new file mode 100644 index 000000000000..592c3dd8ed77 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_selu.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_selu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::selu op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.selu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_selu, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_sigmoid.cpp b/tools/pnnx/src/pass_level2/F_sigmoid.cpp new file mode 100644 index 000000000000..98df8197968b --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_sigmoid.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_sigmoid : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::sigmoid op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.sigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_sigmoid, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_silu.cpp b/tools/pnnx/src/pass_level2/F_silu.cpp new file mode 100644 index 000000000000..523cca2fdbc2 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_silu.cpp @@ -0,0 +1,62 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_silu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::silu op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.silu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_silu, 10) + +class F_silu_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +aten::sigmoid op_0 1 1 input 166 +aten::mul op_1 2 1 input 166 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.silu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_silu_1, 9) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_softmax.cpp b/tools/pnnx/src/pass_level2/F_softmax.cpp new file mode 100644 index 000000000000..af34a58957c8 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_softmax.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_softmax : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +prim::Constant op_0 0 1 dtype value=None +aten::softmax op_1 3 1 input dim dtype out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softmax"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmax, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_softmin.cpp b/tools/pnnx/src/pass_level2/F_softmin.cpp new file mode 100644 index 000000000000..b4106a3fceaf --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_softmin.cpp @@ -0,0 +1,43 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_softmin : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +aten::neg op_0 1 1 input 6 +prim::Constant op_1 0 1 dtype value=None +aten::softmax op_2 3 1 6 dim dtype out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softmin"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmin, 9) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_softplus.cpp b/tools/pnnx/src/pass_level2/F_softplus.cpp new file mode 100644 index 000000000000..c6a5279b4140 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_softplus.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_softplus : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 beta +pnnx.Input input_2 0 1 threshold +aten::softplus op_0 3 1 input beta threshold out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softplus"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softplus, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_softshrink.cpp b/tools/pnnx/src/pass_level2/F_softshrink.cpp new file mode 100644 index 000000000000..286990bf2c57 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_softshrink.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_softshrink : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 lambd +aten::softshrink op_0 2 1 input lambd out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softshrink"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softshrink, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_softsign.cpp b/tools/pnnx/src/pass_level2/F_softsign.cpp new file mode 100644 index 000000000000..4ec8ae9e520d --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_softsign.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_softsign : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 8 value=1 +prim::Constant op_1 0 1 7 value=1 +aten::abs op_2 1 1 input 6 +aten::add op_3 3 1 6 7 8 9 +aten::div op_4 2 1 input 9 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softsign"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softsign, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_tanh.cpp b/tools/pnnx/src/pass_level2/F_tanh.cpp new file mode 100644 index 000000000000..73c90fa99cd2 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_tanh.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_tanh : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::tanh op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.tanh"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_tanh, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_tanhshrink.cpp b/tools/pnnx/src/pass_level2/F_tanhshrink.cpp new file mode 100644 index 000000000000..d8d6c311fcd8 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_tanhshrink.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_tanhshrink : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 8 value=1 +aten::tanh op_1 1 1 input 7 +aten::sub op_2 3 1 input 7 8 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.tanhshrink"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_tanhshrink, 9) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_threshold.cpp b/tools/pnnx/src/pass_level2/F_threshold.cpp new file mode 100644 index 000000000000..a9407fed58da --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_threshold.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_threshold : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 threshold +pnnx.Input input_2 0 1 value +aten::threshold op_0 3 1 input threshold value out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.threshold"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_threshold, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_upsample.cpp b/tools/pnnx/src/pass_level2/F_upsample.cpp new file mode 100644 index 000000000000..8739928912d8 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_upsample.cpp @@ -0,0 +1,409 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_upsample : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 scale_factor value=None +aten::upsample_nearest1d op_1 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["mode"] = "nearest"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample, 11) + +class F_upsample_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 scale_factor +prim::Constant op_0 0 1 size value=None +aten::upsample_nearest1d op_1 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["mode"] = "nearest"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_1, 11) + +class F_upsample_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 align_corners value=%align_corners +prim::Constant op_1 0 1 scale_factor value=None +aten::upsample_linear1d op_2 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "linear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_2, 11) + +class F_upsample_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 scale_factor +prim::Constant op_0 0 1 size value=None +prim::Constant op_1 0 1 align_corners value=%align_corners +aten::upsample_linear1d op_2 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "linear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_3, 11) + +class F_upsample_4 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 align_corners value=%align_corners +prim::Constant op_1 0 1 scale_h value=None +prim::Constant op_2 0 1 scale_w value=None +aten::upsample_bilinear2d op_3 5 1 input size align_corners scale_h scale_w out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "bilinear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_4, 11) + +class F_upsample_4_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 align_corners value=%align_corners +prim::Constant op_1 0 1 scale_factor value=None +aten::upsample_bilinear2d op_2 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "bilinear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_4_1, 11) + +class F_upsample_5 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 scale_factor +prim::Constant op_0 0 1 size value=None +prim::Constant op_1 0 1 align_corners value=%align_corners +aten::upsample_bilinear2d op_2 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "bilinear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_5, 11) + +class F_upsample_6 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 align_corners value=%align_corners +prim::Constant op_1 0 1 scale_h value=None +prim::Constant op_2 0 1 scale_w value=None +aten::upsample_bicubic2d op_3 5 1 input size align_corners scale_h scale_w out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "bicubic"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_6, 11) + +class F_upsample_6_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 align_corners value=%align_corners +prim::Constant op_1 0 1 scale_factor value=None +aten::upsample_bicubic2d op_2 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "bicubic"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_6_1, 11) + +class F_upsample_7 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 scale_factor +prim::Constant op_0 0 1 size value=None +prim::Constant op_1 0 1 align_corners value=%align_corners +aten::upsample_bicubic2d op_2 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "bicubic"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_7, 11) + +class F_upsample_8 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 align_corners value=%align_corners +prim::Constant op_1 0 1 scale_d value=None +prim::Constant op_2 0 1 scale_h value=None +prim::Constant op_3 0 1 scale_w value=None +aten::upsample_trilinear3d op_4 6 1 input size align_corners scale_d scale_h scale_w out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "trilinear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_8, 11) + +class F_upsample_8_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 align_corners value=%align_corners +prim::Constant op_1 0 1 scale_factor value=None +aten::upsample_trilinear3d op_2 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "trilinear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_8_1, 11) + +class F_upsample_9 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 scale_factor +prim::Constant op_0 0 1 size value=None +prim::Constant op_1 0 1 align_corners value=%align_corners +aten::upsample_trilinear3d op_2 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["align_corners"] = captured_params.at("align_corners"); + op->params["mode"] = "trilinear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_9, 11) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_upsample_bilinear.cpp b/tools/pnnx/src/pass_level2/F_upsample_bilinear.cpp new file mode 100644 index 000000000000..ab62a49200a0 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_upsample_bilinear.cpp @@ -0,0 +1,92 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_upsample_bilinear : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 align_corners value=1 +prim::Constant op_1 0 1 scale_h value=None +prim::Constant op_2 0 1 scale_w value=None +aten::upsample_bilinear2d op_3 5 1 input size align_corners scale_h scale_w out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample_bilinear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_bilinear, 10) + +class F_upsample_bilinear_1_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 align_corners value=1 +prim::Constant op_1 0 1 scale_factor value=None +aten::upsample_bilinear2d op_2 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample_bilinear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_bilinear_1_1, 10) + +class F_upsample_bilinear_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 scale_factor +prim::Constant op_0 0 1 size value=None +prim::Constant op_1 0 1 align_corners value=1 +aten::upsample_bilinear2d op_2 4 1 input size align_corners scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample_bilinear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_bilinear_1, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_upsample_nearest.cpp b/tools/pnnx/src/pass_level2/F_upsample_nearest.cpp new file mode 100644 index 000000000000..c544e8065bbd --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_upsample_nearest.cpp @@ -0,0 +1,137 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_upsample_nearest : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 scale_h value=None +prim::Constant op_1 0 1 scale_w value=None +aten::upsample_nearest2d op_2 4 1 input size scale_h scale_w out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample_nearest"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_nearest, 10) + +class F_upsample_nearest_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 scale_factor +prim::Constant op_0 0 1 size value=None +aten::upsample_nearest2d op_1 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample_nearest"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_nearest_1, 10) + +class F_upsample_nearest_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 scale_d value=None +prim::Constant op_1 0 1 scale_h value=None +prim::Constant op_2 0 1 scale_w value=None +aten::upsample_nearest3d op_3 5 1 input size scale_d scale_h scale_w out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample_nearest"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_nearest_2, 10) + +class F_upsample_nearest_2_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 scale_factor value=None +aten::upsample_nearest3d op_1 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample_nearest"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_nearest_2_1, 10) + +class F_upsample_nearest_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 scale_factor +prim::Constant op_0 0 1 size value=None +aten::upsample_nearest3d op_1 3 1 input size scale_factor out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.upsample_nearest"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_upsample_nearest_3, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_contiguous.cpp b/tools/pnnx/src/pass_level2/Tensor_contiguous.cpp new file mode 100644 index 000000000000..6248ee8fb20a --- /dev/null +++ b/tools/pnnx/src/pass_level2/Tensor_contiguous.cpp @@ -0,0 +1,51 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class Tensor_contiguous : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 memory_format value=%memory_format +aten::contiguous op_1 2 1 input memory_format out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.contiguous"; + } + + void write(Operator* op, const std::map& captured_params) const + { + if (captured_params.at("memory_format").i == 0) + op->params["memory_format"] = "torch.contiguous_format"; + if (captured_params.at("memory_format").i == 1) + op->params["memory_format"] = "torch.preserve_format"; + if (captured_params.at("memory_format").i == 2) + op->params["memory_format"] = "torch.channels_last"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_contiguous, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_new_empty.cpp b/tools/pnnx/src/pass_level2/Tensor_new_empty.cpp new file mode 100644 index 000000000000..7cddaa486980 --- /dev/null +++ b/tools/pnnx/src/pass_level2/Tensor_new_empty.cpp @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class Tensor_new_empty : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 size +prim::Constant op_0 0 1 dtype value=* +prim::Constant op_1 0 1 layout value=* +prim::Constant op_2 0 1 device value=* +prim::Constant op_3 0 1 pin_memory value=* +aten::new_empty op_4 6 1 input size dtype layout device pin_memory out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.new_empty"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_new_empty, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_repeat.cpp b/tools/pnnx/src/pass_level2/Tensor_repeat.cpp new file mode 100644 index 000000000000..c2f44cb1a927 --- /dev/null +++ b/tools/pnnx/src/pass_level2/Tensor_repeat.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class Tensor_repeat : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 sizes +aten::repeat op_0 2 1 input sizes out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.repeat"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_repeat, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_reshape.cpp b/tools/pnnx/src/pass_level2/Tensor_reshape.cpp new file mode 100644 index 000000000000..cb5ee3d3fd19 --- /dev/null +++ b/tools/pnnx/src/pass_level2/Tensor_reshape.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class Tensor_reshape : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 shape +aten::reshape op_0 2 1 input shape out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.reshape"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_select.cpp b/tools/pnnx/src/pass_level2/Tensor_select.cpp new file mode 100644 index 000000000000..3ab8a147bb04 --- /dev/null +++ b/tools/pnnx/src/pass_level2/Tensor_select.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class Tensor_select : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +pnnx.Input input_2 0 1 index +aten::select op_0 3 1 input dim index out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.select"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_select, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_slice.cpp b/tools/pnnx/src/pass_level2/Tensor_slice.cpp new file mode 100644 index 000000000000..3ab352ae2e33 --- /dev/null +++ b/tools/pnnx/src/pass_level2/Tensor_slice.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class Tensor_slice : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +pnnx.Input input_2 0 1 start +pnnx.Input input_3 0 1 end +pnnx.Input input_4 0 1 step +aten::slice op_0 5 1 input dim start end step out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.slice"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_slice, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_view.cpp b/tools/pnnx/src/pass_level2/Tensor_view.cpp new file mode 100644 index 000000000000..c15809849996 --- /dev/null +++ b/tools/pnnx/src/pass_level2/Tensor_view.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class Tensor_view : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 shape +aten::view op_0 2 1 input shape out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.view"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_view, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/nn_quantized_FloatFunctional.cpp b/tools/pnnx/src/pass_level2/nn_quantized_FloatFunctional.cpp new file mode 100644 index 000000000000..3368de745002 --- /dev/null +++ b/tools/pnnx/src/pass_level2/nn_quantized_FloatFunctional.cpp @@ -0,0 +1,43 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class nn_quantized_FloatFunctional_cat : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 tensors +pnnx.Input input_1 0 1 dim +pnnx.Input input_2 0 1 scale +pnnx.Input input_3 0 1 zero_point +quantized::cat op_0 4 1 tensors dim scale zero_point out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.quantized.cat"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_quantized_FloatFunctional_cat, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_cat.cpp b/tools/pnnx/src/pass_level2/torch_cat.cpp new file mode 100644 index 000000000000..b4d3b5e87d69 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_cat.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_cat : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 tensors +pnnx.Input input_1 0 1 dim +aten::cat op_0 2 1 tensors dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.cat"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_cat, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_chunk.cpp b/tools/pnnx/src/pass_level2/torch_chunk.cpp new file mode 100644 index 000000000000..11e23715d5d2 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_chunk.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_chunk : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 chunks +pnnx.Input input_2 0 1 dim +aten::chunk op_0 3 1 input chunks dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.chunk"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_chunk, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_clamp.cpp b/tools/pnnx/src/pass_level2/torch_clamp.cpp new file mode 100644 index 000000000000..310cbd72b14a --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_clamp.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_clamp : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 min +pnnx.Input input_2 0 1 max +aten::clamp op_0 3 1 input min max out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.clamp"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clamp, 20) + +class torch_clamp_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 min +pnnx.Input input_2 0 1 max +aten::clamp_ op_0 3 1 input min max out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.clamp"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clamp_1, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_flatten.cpp b/tools/pnnx/src/pass_level2/torch_flatten.cpp new file mode 100644 index 000000000000..8760a0fabf45 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_flatten.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_flatten : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 start_dim +pnnx.Input input_2 0 1 end_dim +aten::flatten op_0 3 1 input start_dim end_dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.flatten"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flatten, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_mean.cpp b/tools/pnnx/src/pass_level2/torch_mean.cpp new file mode 100644 index 000000000000..9176394c32b5 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_mean.cpp @@ -0,0 +1,43 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_mean : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +prim::Constant op_0 0 1 keepdim value=%keepdim +prim::Constant op_1 0 1 dtype value=* +aten::mean op_2 4 1 input dim keepdim dtype out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.mean"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_permute.cpp b/tools/pnnx/src/pass_level2/torch_permute.cpp new file mode 100644 index 000000000000..424d260cb229 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_permute.cpp @@ -0,0 +1,47 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +#include + +namespace pnnx { + +class torch_permute : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dims +aten::permute op_0 2 1 input dims out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { +#if TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 9 + return "torch.permute"; +#else + return "Tensor.permute"; +#endif + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_permute, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_split.cpp b/tools/pnnx/src/pass_level2/torch_split.cpp new file mode 100644 index 000000000000..565fd3bcf99f --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_split.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_split : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 tensor +pnnx.Input input_1 0 1 split_size_or_sections +pnnx.Input input_2 0 1 dim +aten::split op_0 3 1 tensor split_size_or_sections dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.split"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_split, 20) + +class torch_split_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 tensor +pnnx.Input input_1 0 1 split_size_or_sections +pnnx.Input input_2 0 1 dim +aten::split_with_sizes op_0 3 1 tensor split_size_or_sections dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.split"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_split_1, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_squeeze.cpp b/tools/pnnx/src/pass_level2/torch_squeeze.cpp new file mode 100644 index 000000000000..95289b6ff806 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_squeeze.cpp @@ -0,0 +1,62 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_squeeze : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +aten::squeeze op_0 2 1 input dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.squeeze"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_squeeze, 20) + +class torch_squeeze_0 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +aten::squeeze op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.squeeze"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_squeeze_0, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_sum.cpp b/tools/pnnx/src/pass_level2/torch_sum.cpp new file mode 100644 index 000000000000..c8418bb42f6a --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_sum.cpp @@ -0,0 +1,43 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_sum : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +prim::Constant op_0 0 1 keepdim value=%keepdim +prim::Constant op_1 0 1 dtype value=* +aten::sum op_2 4 1 input dim keepdim dtype out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.sum"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_sum, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_transpose.cpp b/tools/pnnx/src/pass_level2/torch_transpose.cpp new file mode 100644 index 000000000000..4442320d2077 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_transpose.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_transpose : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim0 +pnnx.Input input_2 0 1 dim1 +aten::transpose op_0 3 1 input dim0 dim1 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.transpose"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_transpose, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp b/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp new file mode 100644 index 000000000000..9acffa1d041c --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_unsqueeze : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +aten::unsqueeze op_0 2 1 input dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.unsqueeze"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unsqueeze, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3.cpp b/tools/pnnx/src/pass_level3.cpp new file mode 100644 index 000000000000..c2d44c61e227 --- /dev/null +++ b/tools/pnnx/src/pass_level3.cpp @@ -0,0 +1,52 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level3.h" + +#include "pass_level3/eliminate_tuple_pair.h" +#include "pass_level3/expand_quantization_modules.h" +#include "pass_level3/fuse_attribute_expression.h" +#include "pass_level3/fuse_cat_tensors.h" +#include "pass_level3/fuse_chunk_split_unpack.h" +#include "pass_level3/fuse_expression.h" +#include "pass_level3/fuse_rnn_unpack.h" + +// #include "pass_level4/canonicalize.h" +// #include "pass_level4/fuse_custom_op.h" +// #include "pass_level4/dead_code_elimination.h" + +namespace pnnx { + +void pass_level3(Graph& g) +{ + fuse_cat_tensors(g); + + fuse_chunk_split_unpack(g); + + fuse_rnn_unpack(g); + + expand_quantization_modules(g); + + fuse_attribute_expression(g); + + eliminate_tuple_pair(g); + + fuse_expression(g); + + // dead_code_elimination(g); + + // canonicalize(g); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3.h b/tools/pnnx/src/pass_level3.h new file mode 100644 index 000000000000..23c7dd8a8571 --- /dev/null +++ b/tools/pnnx/src/pass_level3.h @@ -0,0 +1,26 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_PASS_LEVEL3_H +#define PNNX_PASS_LEVEL3_H + +#include "ir.h" + +namespace pnnx { + +void pass_level3(Graph& g); + +} // namespace pnnx + +#endif // PNNX_PASS_LEVEL3_H diff --git a/tools/pnnx/src/pass_level3/eliminate_tuple_pair.cpp b/tools/pnnx/src/pass_level3/eliminate_tuple_pair.cpp new file mode 100644 index 000000000000..013538f65ff2 --- /dev/null +++ b/tools/pnnx/src/pass_level3/eliminate_tuple_pair.cpp @@ -0,0 +1,96 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "eliminate_tuple_pair.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +void eliminate_tuple_pair(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "prim::TupleConstruct") + continue; + + if (op->outputs[0]->consumers.size() != 1) + continue; + + Operator* op2 = op->outputs[0]->consumers[0]; + if (op2->type != "prim::TupleUnpack") + continue; + + if (op->inputs.size() != op2->outputs.size()) + continue; + + matched = true; + + const size_t count = op->inputs.size(); + + for (size_t j = 0; j < count; j++) + { + op->inputs[j]->remove_consumer(op); + + for (auto& x : op2->outputs[j]->consumers) + { + op->inputs[j]->consumers.push_back(x); + + for (size_t k = 0; k < x->inputs.size(); k++) + { + if (x->inputs[k] == op2->outputs[j]) + x->inputs[k] = op->inputs[j]; + } + } + + op2->outputs[j]->producer = 0; + op2->outputs[j]->consumers.clear(); + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op2->outputs[j])); + delete op2->outputs[j]; + } + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op->outputs[0])); + delete op->outputs[0]; + + op->inputs.clear(); + op->outputs.clear(); + + op2->inputs.clear(); + op2->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op)); + + delete op; + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + + delete op2; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/eliminate_tuple_pair.h b/tools/pnnx/src/pass_level3/eliminate_tuple_pair.h new file mode 100644 index 000000000000..df70eda27a68 --- /dev/null +++ b/tools/pnnx/src/pass_level3/eliminate_tuple_pair.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void eliminate_tuple_pair(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/expand_quantization_modules.cpp b/tools/pnnx/src/pass_level3/expand_quantization_modules.cpp new file mode 100644 index 000000000000..d42070806bdc --- /dev/null +++ b/tools/pnnx/src/pass_level3/expand_quantization_modules.cpp @@ -0,0 +1,73 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "expand_quantization_modules.h" +#include +#include "pass_level2.h" + +namespace pnnx { + +void expand_quantization_modules(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "nn.intrinsic.quantized.ConvReLU2d") + continue; + + matched = true; + + // expand to nn.quantized.Conv2d + nn.ReLU + op->type = "nn.quantized.Conv2d"; + + // insert new operator before all output consumers + const Operator* cur = 0; + { + int cur_index = graph.ops.size() - 1; + for (auto& c : op->outputs[0]->consumers) + { + int c_index = std::find(graph.ops.begin(), graph.ops.end(), c) - graph.ops.begin(); + cur_index = std::min(cur_index, c_index); + } + + cur = graph.ops[cur_index]; + } + + Operator* op_relu = graph.new_operator_before("nn.ReLU", op->name + "_relu", cur); + + Operand* r0 = graph.new_operand(op->name + "_norelu"); + + r0->producer = op; + r0->consumers.push_back(op_relu); + + op_relu->inputs.push_back(r0); + op_relu->outputs.push_back(op->outputs[0]); + op_relu->outputs[0]->producer = op_relu; + + op->outputs[0] = r0; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/expand_quantization_modules.h b/tools/pnnx/src/pass_level3/expand_quantization_modules.h new file mode 100644 index 000000000000..a57cbc8a20f4 --- /dev/null +++ b/tools/pnnx/src/pass_level3/expand_quantization_modules.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void expand_quantization_modules(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_attribute_expression.cpp b/tools/pnnx/src/pass_level3/fuse_attribute_expression.cpp new file mode 100644 index 000000000000..be409705fb5a --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_attribute_expression.cpp @@ -0,0 +1,197 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_attribute_expression.h" +#include +#include +#include "pass_level2.h" + +namespace pnnx { + +void fuse_attribute_expression(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "pnnx.Attribute") + continue; + + if (op->outputs.size() != 1) + continue; + + if (op->outputs[0]->consumers.size() != 1) + continue; + + Operator* op2 = op->outputs[0]->consumers[0]; + Operator* op3 = 0; + Operator* op4 = 0; + + float y = 0.f; + float z = 0.f; + + if (op2->type == "aten::add" || op2->type == "aten::sub") + { + if (op2->inputs[0] != op->outputs[0]) + continue; + + op3 = op2->inputs[1]->producer; + if (op3->type != "prim::Constant") + continue; + + if (op3->params["value"].type == 2) + { + y = op3->params["value"].i; + } + else if (op3->params["value"].type == 3) + { + y = op3->params["value"].f; + } + else + { + // not a scalar + continue; + } + + op4 = op2->inputs[2]->producer; + if (op4->type != "prim::Constant") + continue; + + if (op4->params["value"].type == 2) + { + z = op4->params["value"].i; + } + else if (op4->params["value"].type == 3) + { + z = op4->params["value"].f; + } + else + { + // not a scalar + continue; + } + } + else if (op2->type == "aten::mul" || op2->type == "aten::div" || op2->type == "aten::pow") + { + if (op2->inputs[0] != op->outputs[0]) + continue; + + op3 = op2->inputs[1]->producer; + if (op3->type != "prim::Constant") + continue; + + if (op3->params["value"].type == 2) + { + y = op3->params["value"].i; + } + else if (op3->params["value"].type == 3) + { + y = op3->params["value"].f; + } + else + { + // not a scalar + continue; + } + } + else + { + // todo more operator type + continue; + } + + matched = true; + + // apply mul + { + auto it = op->attrs.begin(); + std::string attr_key = it->first; + const Attribute& attr = it->second; + + float* weight = (float*)attr.data.data(); + const int weight_size = attr.data.size() / sizeof(float); + + if (op2->type == "aten::add") + { + for (int i = 0; i < weight_size; i++) + weight[i] += y * z; + } + else if (op2->type == "aten::sub") + { + for (int i = 0; i < weight_size; i++) + weight[i] -= y * z; + } + else if (op2->type == "aten::mul") + { + for (int i = 0; i < weight_size; i++) + weight[i] *= y; + } + else if (op2->type == "aten::div") + { + for (int i = 0; i < weight_size; i++) + weight[i] /= y; + } + else if (op2->type == "aten::pow") + { + for (int i = 0; i < weight_size; i++) + weight[i] = (float)pow(weight[i], y); + } + + op->attrs[attr_key] = attr; + } + + op2->outputs[0]->producer = op; + + for (auto& x : op2->inputs) + { + x->producer = 0; + x->remove_consumer(op2); + } + + op->outputs = op2->outputs; + + op2->inputs.clear(); + op2->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + + delete op2; + + if (op3 && op3->outputs[0]->consumers.empty()) + { + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op3)); + + delete op3; + } + + if (op4 && op4->outputs[0]->consumers.empty()) + { + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op4)); + + delete op4; + } + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_attribute_expression.h b/tools/pnnx/src/pass_level3/fuse_attribute_expression.h new file mode 100644 index 000000000000..348542d769f8 --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_attribute_expression.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_attribute_expression(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_cat_tensors.cpp b/tools/pnnx/src/pass_level3/fuse_cat_tensors.cpp new file mode 100644 index 000000000000..ee14d4f7d913 --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_cat_tensors.cpp @@ -0,0 +1,82 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_cat_tensors.h" +#include +#include "pass_level2.h" + +namespace pnnx { + +void fuse_cat_tensors(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "torch.cat") + continue; + + if (op->inputs.size() < 1) + continue; + + if (op->inputs[0]->consumers.size() != 1) + continue; + + Operator* op2 = op->inputs[0]->producer; + if (op2->type != "prim::ListConstruct") + continue; + + matched = true; + + op->inputs[0]->producer = 0; + op->inputs[0]->remove_consumer(op); + + std::vector new_inputs; + std::vector new_inputnames(op2->inputs.size()); + for (auto& x : op2->inputs) + { + x->remove_consumer(op2); + x->consumers.push_back(op); + new_inputs.push_back(x); + } + + for (size_t j = 1; j < op->inputs.size(); j++) + { + new_inputs.push_back(op->inputs[j]); + new_inputnames.push_back(op->inputnames[j]); + } + + op->inputs = new_inputs; + op->inputnames = new_inputnames; + + op2->inputs.clear(); + op2->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + + delete op2; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_cat_tensors.h b/tools/pnnx/src/pass_level3/fuse_cat_tensors.h new file mode 100644 index 000000000000..fd85cd405347 --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_cat_tensors.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_cat_tensors(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.cpp b/tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.cpp new file mode 100644 index 000000000000..a020f4dd506a --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.cpp @@ -0,0 +1,71 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_chunk_split_unpack.h" +#include +#include "pass_level2.h" + +namespace pnnx { + +void fuse_chunk_split_unpack(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "torch.chunk" && op->type != "torch.split") + continue; + + if (op->outputs.size() != 1) + continue; + + if (op->outputs[0]->consumers.size() != 1) + continue; + + Operator* op2 = op->outputs[0]->consumers[0]; + if (op2->type != "prim::ListUnpack") + continue; + + matched = true; + + op->outputs[0]->producer = 0; + op->outputs[0]->remove_consumer(op2); + + for (auto& x : op2->outputs) + { + x->producer = op; + } + + op->outputs = op2->outputs; + + op2->inputs.clear(); + op2->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + + delete op2; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.h b/tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.h new file mode 100644 index 000000000000..06949e6f0d74 --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_chunk_split_unpack(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_expression.cpp b/tools/pnnx/src/pass_level3/fuse_expression.cpp new file mode 100644 index 000000000000..30973ffbc9fd --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_expression.cpp @@ -0,0 +1,392 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_expression.h" + +#include + +namespace pnnx { + +static bool operand_maybe_tensor(Operand* operand) +{ + Operator* op = operand->producer; + + if (op->type == "prim::Constant") + { + const Parameter& param = op->params["value"]; + if (param.type == 0 || param.type == 1 || param.type == 2 || param.type == 3 || param.type == 4) + { + return false; + } + else + { + return true; + } + } + + if (op->type == "prim::NumToTensor") + { + return operand_maybe_tensor(op->inputs[0]); + } + + if (op->type == "prim::ListConstruct") + { + return false; + } + + if (op->type == "aten::size") + { + return false; + } + + if (op->type == "aten::Int") + { + return operand_maybe_tensor(op->inputs[0]); + } + + if (op->type == "aten::to" || op->type == "aten::detach") + { + return operand_maybe_tensor(op->inputs[0]); + } + + if (op->type == "aten::floor_divide" || op->type == "aten::mul" || op->type == "aten::div" || op->type == "aten::div_" || op->type == "aten::pow") + { + return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]); + } + + if (op->type == "aten::add" || op->type == "aten::add_" || op->type == "aten::sub" || op->type == "aten::rsub") + { + return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]) || operand_maybe_tensor(op->inputs[2]); + } + + if (op->type == "aten::sqrt" || op->type == "aten::rsqrt" || op->type == "aten::neg") + { + return operand_maybe_tensor(op->inputs[0]); + } + + return true; +} + +static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector& inputs, bool checksubgraph = true) +{ + // fprintf(stderr, "fuse_expression %s\n", operand->name.c_str()); + + Operator* op = operand->producer; + + if (checksubgraph && operand_maybe_tensor(operand)) + { + if (op->outputs.size() > 1 || op->outputs[0]->consumers.size() > 1) + { + auto it = std::find(inputs.begin(), inputs.end(), operand); + if (it == inputs.end()) + { + // tensor + char tmp[32]; + sprintf(tmp, "@%d", (int)inputs.size()); + expr += tmp; + + inputs.push_back(operand); + } + else + { + // tensor + char tmp[32]; + sprintf(tmp, "@%d", (int)(it - inputs.begin())); + expr += tmp; + } + + return; + } + } + + if (op->type == "prim::Constant") + { + const Parameter& param = op->params["value"]; + // fprintf(stderr, "fuse_expression prim::Constant %d\n", param.type); + if (param.type == 0) + { + expr += "None"; + } + else if (param.type == 1) + { + expr += param.b ? "True" : "False"; + } + else if (param.type == 2) + { + char tmp[32]; + sprintf(tmp, "%d", param.i); + expr += tmp; + } + else if (param.type == 3) + { + char tmp[32]; + sprintf(tmp, "%e", param.f); + expr += tmp; + } + else if (param.type == 4) + { + expr += param.s; + } + else + { + auto it = std::find(inputs.begin(), inputs.end(), operand); + if (it == inputs.end()) + { + // tensor + char tmp[32]; + sprintf(tmp, "@%d", (int)inputs.size()); + expr += tmp; + + inputs.push_back(operand); + } + else + { + // tensor + char tmp[32]; + sprintf(tmp, "@%d", (int)(it - inputs.begin())); + expr += tmp; + } + } + } + else if (op->type == "prim::NumToTensor") + { + fuse_expression(graph, op->inputs[0], expr, inputs); + } + else if (op->type == "prim::ListConstruct") + { + expr += "["; + for (int i = 0; i < (int)op->inputs.size() - 1; i++) + { + fuse_expression(graph, op->inputs[i], expr, inputs); + expr += ","; + } + if (op->inputs.size() > 0) + { + fuse_expression(graph, op->inputs[op->inputs.size() - 1], expr, inputs); + } + expr += "]"; + } + else if (op->type == "aten::size") + { + expr += "size("; + fuse_expression(graph, op->inputs[0], expr, inputs); + expr += ","; + fuse_expression(graph, op->inputs[1], expr, inputs); + expr += ")"; + } + else if (op->type == "aten::Int") + { + expr += "int("; + fuse_expression(graph, op->inputs[0], expr, inputs); + expr += ")"; + } + else if (op->type == "aten::to" || op->type == "aten::detach") + { + fuse_expression(graph, op->inputs[0], expr, inputs); + } + else if (op->type == "aten::floor_divide" || op->type == "aten::mul" || op->type == "aten::div" || op->type == "aten::div_" || op->type == "aten::pow") + { + std::string mathop = op->type.substr(6); + if (mathop == "div_") + mathop = "div"; + + expr += mathop; + expr += "("; + fuse_expression(graph, op->inputs[0], expr, inputs); + expr += ","; + fuse_expression(graph, op->inputs[1], expr, inputs); + expr += ")"; + } + else if (op->type == "aten::add" || op->type == "aten::add_" || op->type == "aten::sub") + { + std::string mathop = op->type.substr(6); + if (mathop == "add_") + mathop = "add"; + + expr += mathop; + expr += "("; + fuse_expression(graph, op->inputs[0], expr, inputs); + expr += ","; + + std::string expr1; + std::string expr2; + fuse_expression(graph, op->inputs[1], expr1, inputs); + fuse_expression(graph, op->inputs[2], expr2, inputs); + + if (expr2 == "1") + { + expr += expr1; + } + else + { + expr += ","; + expr += "mul("; + expr += expr1; + expr += ","; + expr += expr2; + expr += ")"; + } + + expr += ")"; + } + else if (op->type == "aten::rsub") + { + expr += "sub("; + std::string expr1; + std::string expr2; + fuse_expression(graph, op->inputs[1], expr1, inputs); + fuse_expression(graph, op->inputs[2], expr2, inputs); + + if (expr2 == "1") + { + expr += expr1; + } + else + { + expr += ","; + expr += "mul("; + expr += expr1; + expr += ","; + expr += expr2; + expr += ")"; + } + + expr += ","; + fuse_expression(graph, op->inputs[0], expr, inputs); + expr += ")"; + } + else if (op->type == "aten::sqrt") + { + expr += "sqrt("; + fuse_expression(graph, op->inputs[0], expr, inputs); + expr += ")"; + } + else if (op->type == "aten::rsqrt") + { + expr += "rsqrt("; + fuse_expression(graph, op->inputs[0], expr, inputs); + expr += ")"; + } + else if (op->type == "aten::neg") + { + expr += "neg("; + fuse_expression(graph, op->inputs[0], expr, inputs); + expr += ")"; + } + else + { + auto it = std::find(inputs.begin(), inputs.end(), operand); + if (it == inputs.end()) + { + // tensor + char tmp[32]; + sprintf(tmp, "@%d", (int)inputs.size()); + expr += tmp; + + inputs.push_back(operand); + } + else + { + // tensor + char tmp[32]; + sprintf(tmp, "@%d", (int)(it - inputs.begin())); + expr += tmp; + } + } +} + +void fuse_expression(Graph& graph) +{ + int pnnx_expr_index = 0; + + for (;;) + { + bool need_fuse = false; + + // build expression via reverse order + for (int i = (int)graph.ops.size() - 1; i >= 0; i--) + { + Operator* op = graph.ops[i]; + + if (op->type == "prim::Constant") + { + need_fuse = true; + } + if (op->type == "prim::NumToTensor") + { + need_fuse = true; + } + if (op->type == "prim::ListConstruct") + { + need_fuse = true; + } + if (op->type == "aten::size") + { + need_fuse = true; + } + if (op->type == "aten::Int") + { + need_fuse = true; + } + if (op->type == "aten::to" || op->type == "aten::detach") + { + need_fuse = true; + } + if (op->type == "aten::floor_divide" || op->type == "aten::add" || op->type == "aten::add_" || op->type == "aten::sub" || op->type == "aten::mul" || op->type == "aten::div" || op->type == "aten::div_" || op->type == "aten::sqrt" || op->type == "aten::rsub" || op->type == "aten::rsqrt" || op->type == "aten::neg" || op->type == "aten::pow") + { + need_fuse = true; + } + + if (need_fuse) + { + std::string expr; + std::vector inputs; + fuse_expression(graph, op->outputs[0], expr, inputs, false); + // fprintf(stderr, "expr = %s\n", expr.c_str()); + + // lets rewrite graph + char name[32]; + sprintf(name, "pnnx_expr_%d", pnnx_expr_index++); + + op->type = "pnnx.Expression"; + op->name = name; + + op->params.clear(); + op->attrs.clear(); + + op->params["expr"] = expr; + + // fix input output + for (Operand* operand : op->inputs) + { + operand->consumers.erase(std::find(operand->consumers.begin(), operand->consumers.end(), op)); + } + + op->inputs = inputs; + + for (Operand* operand : op->inputs) + { + operand->consumers.push_back(op); + } + + break; + } + } + + if (!need_fuse) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_expression.h b/tools/pnnx/src/pass_level3/fuse_expression.h new file mode 100644 index 000000000000..2bb8615f68e8 --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_expression.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_expression(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_rnn_unpack.cpp b/tools/pnnx/src/pass_level3/fuse_rnn_unpack.cpp new file mode 100644 index 000000000000..42ba42234395 --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_rnn_unpack.cpp @@ -0,0 +1,86 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_rnn_unpack.h" +#include +#include "pass_level2.h" + +namespace pnnx { + +void fuse_rnn_unpack(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "nn.RNN" && op->type != "nn.LSTM" && op->type != "nn.GRU") + continue; + + if (op->outputs.size() != 1) + continue; + + if (op->outputs[0]->consumers.size() != 1) + continue; + + Operator* op2 = op->outputs[0]->consumers[0]; + if (op2->type != "prim::TupleUnpack") + continue; + + matched = true; + + op->outputs[0]->producer = 0; + op->outputs[0]->remove_consumer(op2); + + for (auto& x : op2->outputs) + { + x->producer = op; + } + + op->outputs = op2->outputs; + + // outputs may be swapped, fix the ugly order + if (op->params.find("pnnx_rnn_output_swapped") != op->params.end() && op->params.at("pnnx_rnn_output_swapped").i == 1) + { + op->params.erase("pnnx_rnn_output_swapped"); + if (op->type == "nn.RNN" || op->type == "nn.GRU") + { + std::swap(op->outputs[0], op->outputs[1]); + } + if (op->type == "nn.LSTM") + { + std::swap(op->outputs[0], op->outputs[2]); + std::swap(op->outputs[1], op->outputs[2]); + } + } + + op2->inputs.clear(); + op2->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); + + delete op2; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_rnn_unpack.h b/tools/pnnx/src/pass_level3/fuse_rnn_unpack.h new file mode 100644 index 000000000000..79f1f65df5a9 --- /dev/null +++ b/tools/pnnx/src/pass_level3/fuse_rnn_unpack.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_rnn_unpack(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level4.cpp b/tools/pnnx/src/pass_level4.cpp new file mode 100644 index 000000000000..8ebcb4bfee61 --- /dev/null +++ b/tools/pnnx/src/pass_level4.cpp @@ -0,0 +1,32 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level4.h" + +#include "pass_level4/canonicalize.h" +#include "pass_level4/fuse_custom_op.h" +#include "pass_level4/dead_code_elimination.h" + +namespace pnnx { + +void pass_level4(Graph& g) +{ + fuse_custom_op(g); + + dead_code_elimination(g); + + canonicalize(g); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level4.h b/tools/pnnx/src/pass_level4.h new file mode 100644 index 000000000000..2f6daf73bea0 --- /dev/null +++ b/tools/pnnx/src/pass_level4.h @@ -0,0 +1,26 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_PASS_LEVEL4_H +#define PNNX_PASS_LEVEL4_H + +#include "ir.h" + +namespace pnnx { + +void pass_level4(Graph& g); + +} // namespace pnnx + +#endif // PNNX_PASS_LEVEL4_H diff --git a/tools/pnnx/src/pass_level4/canonicalize.cpp b/tools/pnnx/src/pass_level4/canonicalize.cpp new file mode 100644 index 000000000000..65017e2fb3f8 --- /dev/null +++ b/tools/pnnx/src/pass_level4/canonicalize.cpp @@ -0,0 +1,34 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "canonicalize.h" + +namespace pnnx { + +void canonicalize(Graph& graph) +{ + int i = 0; + + for (Operator* op : graph.ops) + { + for (Operand* operand : op->outputs) + { + operand->name = std::to_string(i); + + i++; + } + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level4/canonicalize.h b/tools/pnnx/src/pass_level4/canonicalize.h new file mode 100644 index 000000000000..e65f19e1c3b9 --- /dev/null +++ b/tools/pnnx/src/pass_level4/canonicalize.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void canonicalize(Graph& graph); + +} diff --git a/tools/pnnx/src/pass_level4/dead_code_elimination.cpp b/tools/pnnx/src/pass_level4/dead_code_elimination.cpp new file mode 100644 index 000000000000..800bd6715729 --- /dev/null +++ b/tools/pnnx/src/pass_level4/dead_code_elimination.cpp @@ -0,0 +1,99 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "dead_code_elimination.h" + +namespace pnnx { + +void dead_code_elimination(Graph& graph) +{ + // dead op elimination + for (;;) + { + bool need_eliminate = false; + + for (int i = (int)graph.ops.size() - 1; i >= 0; i--) + { + Operator* op = graph.ops[i]; + + if (op->type == "pnnx.Output") + continue; + + int consumers = 0; + for (const Operand* operand : op->outputs) + { + consumers += (int)operand->consumers.size(); + } + + if (consumers == 0) + { + need_eliminate = true; + + // fprintf(stderr, "delete %s %s\n", op->type.c_str(), op->name.c_str()); + + for (Operand* operand : op->inputs) + { + operand->remove_consumer(op); + } + + op->inputs.clear(); + + for (Operand* operand : op->outputs) + { + operand->producer = 0; + } + + op->outputs.clear(); + + graph.ops.erase(graph.ops.begin() + i); + delete op; + + break; + } + } + + if (!need_eliminate) + break; + } + + // dead operand elimination + for (;;) + { + bool need_eliminate = false; + + for (int i = (int)graph.operands.size() - 1; i >= 0; i--) + { + Operand* operand = graph.operands[i]; + + int consumers = (int)operand->consumers.size(); + + if (operand->producer == 0 && consumers == 0) + { + need_eliminate = true; + + // fprintf(stderr, "delete operand %s\n", operand->name.c_str()); + + graph.operands.erase(graph.operands.begin() + i); + delete operand; + + break; + } + } + + if (!need_eliminate) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level4/dead_code_elimination.h b/tools/pnnx/src/pass_level4/dead_code_elimination.h new file mode 100644 index 000000000000..145b40904e5b --- /dev/null +++ b/tools/pnnx/src/pass_level4/dead_code_elimination.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void dead_code_elimination(Graph& graph); + +} diff --git a/tools/pnnx/src/pass_level4/fuse_custom_op.cpp b/tools/pnnx/src/pass_level4/fuse_custom_op.cpp new file mode 100644 index 000000000000..1b999312ba52 --- /dev/null +++ b/tools/pnnx/src/pass_level4/fuse_custom_op.cpp @@ -0,0 +1,94 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_custom_op.h" + +#include + +namespace pnnx { + +void fuse_custom_op(Graph& graph) +{ + std::set custom_ops; + + for (;;) + { + bool need_fuse = false; + + // fuse in reverse order + for (int i = (int)graph.ops.size() - 1; i >= 0; i--) + { + Operator* op = graph.ops[i]; + + if (op->type.find("::") == std::string::npos) + continue; + + std::string op_type_namespace = op->type.substr(0, op->type.find_first_of(':')); + + if (op_type_namespace == "aten" || op_type_namespace == "prim") + continue; + + custom_ops.insert(op->type); + + std::string op_type_name = op->type.substr(op->type.find_last_of(':') + 1); + + need_fuse = true; + + op->type = std::string("pnnx.custom_op.") + op_type_namespace + '.' + op_type_name; + + std::vector new_inputs; + std::vector new_inputnames; + for (size_t j = 0; j < op->inputs.size(); j++) + { + Operator* arg = op->inputs[j]->producer; + + if (!arg->inputs.empty()) + { + new_inputs.push_back(op->inputs[j]); + new_inputnames.push_back(std::string("arg_") + std::to_string(j)); + continue; + } + + if (arg->type == "prim::Constant") + { + op->params[std::string("arg_") + std::to_string(j)] = arg->params["value"]; + op->inputs[j]->remove_consumer(op); + } + else if (arg->type == "pnnx.Expression") + { + op->params[std::string("arg_") + std::to_string(j)] = Parameter::parse_from_string(arg->params["expr"].s); + op->inputs[j]->remove_consumer(op); + } + else + { + new_inputs.push_back(op->inputs[j]); + new_inputnames.push_back(std::string("arg_") + std::to_string(j)); + } + } + + op->inputs = new_inputs; + op->inputnames = new_inputnames; + } + + if (!need_fuse) + break; + } + + for (auto x : custom_ops) + { + fprintf(stderr, "custom op = %s\n", x.c_str()); + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level4/fuse_custom_op.h b/tools/pnnx/src/pass_level4/fuse_custom_op.h new file mode 100644 index 000000000000..e1668c817e93 --- /dev/null +++ b/tools/pnnx/src/pass_level4/fuse_custom_op.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_custom_op(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp new file mode 100644 index 000000000000..54d338218d96 --- /dev/null +++ b/tools/pnnx/src/pass_level5.cpp @@ -0,0 +1,59 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level5.h" + +#include "pass_level5/eliminate_slice.h" +#include "pass_level5/eliminate_view_reshape.h" +#include "pass_level5/eval_expression.h" +#include "pass_level5/fuse_channel_shuffle.h" +#include "pass_level5/fuse_constant_expression.h" +#include "pass_level5/fuse_conv2d_batchnorm2d.h" +#include "pass_level5/fuse_convtranspose2d_batchnorm2d.h" +#include "pass_level5/fuse_contiguous_view.h" +#include "pass_level5/fuse_linear_batchnorm1d.h" +#include "pass_level5/fuse_slice_indices.h" +#include "pass_level4/dead_code_elimination.h" +#include "pass_level4/canonicalize.h" + +namespace pnnx { + +void pass_level5(Graph& g) +{ + eval_expression(g); + + fuse_constant_expression(g); + + eliminate_slice(g); + + fuse_slice_indices(g); + + fuse_conv2d_batchnorm2d(g); + + fuse_convtranspose2d_batchnorm2d(g); + + fuse_linear_batchnorm1d(g); + + fuse_contiguous_view(g); + + eliminate_view_reshape(g); + + fuse_channel_shuffle(g); + + dead_code_elimination(g); + + canonicalize(g); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5.h b/tools/pnnx/src/pass_level5.h new file mode 100644 index 000000000000..14228516d07d --- /dev/null +++ b/tools/pnnx/src/pass_level5.h @@ -0,0 +1,26 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_PASS_LEVEL5_H +#define PNNX_PASS_LEVEL5_H + +#include "ir.h" + +namespace pnnx { + +void pass_level5(Graph& g); + +} // namespace pnnx + +#endif // PNNX_PASS_LEVEL5_H diff --git a/tools/pnnx/src/pass_level5/eliminate_slice.cpp b/tools/pnnx/src/pass_level5/eliminate_slice.cpp new file mode 100644 index 000000000000..afe8ae073563 --- /dev/null +++ b/tools/pnnx/src/pass_level5/eliminate_slice.cpp @@ -0,0 +1,86 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "eliminate_slice.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +void eliminate_slice(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "Tensor.slice") + continue; + + if (op->inputs.size() != 1) + continue; + + int start = op->params.at("start").i; + int end = op->params.at("end").i; + int step = op->params.at("step").i; + + if (start == 0 && end == -1 && step == 1) + { + // delete noop-like slice + matched = true; + + for (auto& x : op->inputs) + { + x->remove_consumer(op); + } + + Operand* slice_out = op->outputs[0]; + + for (auto& x : slice_out->consumers) + { + for (size_t j = 0; j < x->inputs.size(); j++) + { + if (x->inputs[j] == slice_out) + x->inputs[j] = op->inputs[0]; + } + + op->inputs[0]->consumers.push_back(x); + } + + slice_out->producer = 0; + slice_out->consumers.clear(); + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), slice_out)); + delete slice_out; + + op->inputs.clear(); + op->outputs.clear(); + + graph.ops.erase(graph.ops.begin() + i); + delete op; + + break; + } + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_slice.h b/tools/pnnx/src/pass_level5/eliminate_slice.h new file mode 100644 index 000000000000..a90ed96f4e98 --- /dev/null +++ b/tools/pnnx/src/pass_level5/eliminate_slice.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void eliminate_slice(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp b/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp new file mode 100644 index 000000000000..342b4048caf8 --- /dev/null +++ b/tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp @@ -0,0 +1,83 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "eliminate_view_reshape.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +void eliminate_view_reshape(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "Tensor.view" && op->type != "Tensor.reshape") + continue; + + const std::vector& input_shape = op->inputs[0]->shape; + const std::vector& output_shape = op->outputs[0]->shape; + if (input_shape != output_shape) + continue; + + if (input_shape.empty()) + continue; + + matched = true; + + for (auto& x : op->inputs) + { + x->remove_consumer(op); + } + + Operand* op_out = op->outputs[0]; + + for (auto& x : op_out->consumers) + { + for (size_t j = 0; j < x->inputs.size(); j++) + { + if (x->inputs[j] == op_out) + x->inputs[j] = op->inputs[0]; + } + + op->inputs[0]->consumers.push_back(x); + } + + op_out->producer = 0; + op_out->consumers.clear(); + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op_out)); + delete op_out; + + op->inputs.clear(); + op->outputs.clear(); + + graph.ops.erase(graph.ops.begin() + i); + delete op; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_view_reshape.h b/tools/pnnx/src/pass_level5/eliminate_view_reshape.h new file mode 100644 index 000000000000..e3996354484e --- /dev/null +++ b/tools/pnnx/src/pass_level5/eliminate_view_reshape.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void eliminate_view_reshape(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eval_expression.cpp b/tools/pnnx/src/pass_level5/eval_expression.cpp new file mode 100644 index 000000000000..0ee32e172334 --- /dev/null +++ b/tools/pnnx/src/pass_level5/eval_expression.cpp @@ -0,0 +1,373 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "eval_expression.h" + +#include + +#include +#include +#include +#include +#include +#include + +namespace pnnx { + +static bool token_is_argument(const std::string& t) +{ + if (t[0] != '@' || t.size() < 2) + return false; + + for (size_t i = 1; i < t.size(); i++) + { + if (t[i] < '0' || t[i] > '9') + return false; + } + + return true; +} + +static bool token_is_literal(const std::string& t) +{ + std::istringstream iss(t); + float f; + iss >> std::noskipws >> f; + return iss.eof() && !iss.fail(); + + // for (size_t i = 0; i < t.size(); i++) + // { + // if (i == 0 && t[i] == '-') + // continue; + // + // if (t[i] < '0' || t[i] > '9') + // { + // if (t[i] != '.' && t[i] != 'e') + // return false; + // } + // } + // + // return true; +} + +static std::string eval_expression(const Operator* op) +{ + std::string expr = op->params.at("expr").s; + + // fprintf(stderr, "eval_expression %s\n", expr.c_str()); + + // split into tokens + std::vector tokens; + { + std::string t; + for (size_t i = 0; i < expr.size(); i++) + { + char ch = expr[i]; + + if (ch == '[') // list + { + t += ch; + tokens.push_back(t); + t.clear(); + } + else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') + { + if (!t.empty()) + { + tokens.push_back(t); + t.clear(); + } + } + else + { + t += ch; + } + } + + if (!t.empty()) + { + tokens.push_back(t); + } + } + + // scan and stack + std::stack exprstack; + for (int i = (int)tokens.size() - 1; i >= 0; i--) + { + const std::string& t = tokens[i]; + + if (t == "size") + { + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + if (token_is_argument(a) && token_is_literal(b)) + { + int input_index = std::stoi(a.substr(1)); + if (op->inputs[input_index]->shape.empty()) + { + std::string r = std::string("size(") + a + "," + b + ")"; + exprstack.push(r); + } + else + { + int bi = std::stoi(b); + int r = op->inputs[input_index]->shape[bi]; + exprstack.push(std::to_string(r)); + } + } + else + { + std::string r = std::string("size(") + a + "," + b + ")"; + exprstack.push(r); + } + } + else if (t == "int" || t == "sqrt" || t == "rsqrt" || t == "neg") + { + std::string a = exprstack.top(); + exprstack.pop(); + + if (token_is_literal(a)) + { + float af = std::stof(a); + + if (t == "int") + { + int r = int(af); + exprstack.push(std::to_string(r)); + } + if (t == "sqrt") + { + float r = sqrt(af); + exprstack.push(std::to_string(r)); + } + if (t == "rsqrt") + { + float r = 1.f / sqrt(af); + exprstack.push(std::to_string(r)); + } + if (t == "neg") + { + float r = -af; + exprstack.push(std::to_string(r)); + } + } + else + { + std::string r = t + "(" + a + ")"; + exprstack.push(r); + } + } + else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "pow") + { + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + if (token_is_literal(a) && token_is_literal(b)) + { + float af = std::stof(a); + float bf = std::stof(b); + + if (t == "add") + { + float r = af + bf; + exprstack.push(std::to_string(r)); + } + if (t == "sub") + { + float r = af - bf; + exprstack.push(std::to_string(r)); + } + if (t == "mul") + { + float r = af * bf; + exprstack.push(std::to_string(r)); + } + if (t == "div") + { + float r = af / bf; + exprstack.push(std::to_string(r)); + } + if (t == "floor_divide") + { + int r = (int)af / (int)bf; + exprstack.push(std::to_string(r)); + } + if (t == "pow") + { + float r = pow(af, bf); + exprstack.push(std::to_string(r)); + } + } + else + { + std::string r = t + "(" + a + "," + b + ")"; + exprstack.push(r); + } + } + else if (t == "[") // list + { + std::vector elements; + while (!exprstack.empty()) + { + std::string a = exprstack.top(); + exprstack.pop(); + + elements.push_back(a); + } + + std::string r = "["; + for (int j = 0; j < (int)elements.size() - 1; j++) + { + r += elements[j]; + if (j + 1 != (int)elements.size()) + r += ","; + } + if (!elements.empty()) + { + r += elements[elements.size() - 1]; + } + r += "]"; + + exprstack.push(r); + } + else if (t[0] == '@') + { + exprstack.push(t); + } + else + { + // literal + exprstack.push(t); + } + } + + std::string r = exprstack.top(); + exprstack.pop(); + + // fprintf(stderr, "eval_expression return %s\n", r.c_str()); + + return r; +} + +static std::string canonicalize_arguments(const Operator* op, std::vector& inputs) +{ + std::string expr = op->params.at("expr").s; + + // split into tokens + std::vector tokens; + { + std::string t; + for (size_t i = 0; i < expr.size(); i++) + { + char ch = expr[i]; + + if (ch == '[') // list + { + t += ch; + tokens.push_back(t); + t.clear(); + } + else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') + { + if (!t.empty()) + { + tokens.push_back(t); + t.clear(); + } + + t += ch; + tokens.push_back(t); + t.clear(); + } + else + { + t += ch; + } + } + + if (!t.empty()) + { + tokens.push_back(t); + } + } + + std::string r; + for (auto t : tokens) + { + if (t[0] == '@') + { + int input_index = std::stoi(t.substr(1)); + Operand* operand = op->inputs[input_index]; + + int new_input_index; + + auto it = std::find(inputs.begin(), inputs.end(), operand); + if (it == inputs.end()) + { + new_input_index = inputs.size(); + inputs.push_back(operand); + } + else + { + new_input_index = it - inputs.begin(); + } + r += std::string("@") + std::to_string(new_input_index); + } + else + { + r += t; + } + } + + // fprintf(stderr, "canonicalize_arguments return %s\n", r.c_str()); + + return r; +} + +void eval_expression(Graph& graph) +{ + for (Operator* op : graph.ops) + { + if (op->type != "pnnx.Expression") + continue; + + std::string expr_eval = eval_expression(op); + + op->params["expr"] = expr_eval; + + std::vector inputs; + std::string expr_canonicalize = canonicalize_arguments(op, inputs); + + op->params["expr"] = expr_canonicalize; + + for (auto r : op->inputs) + { + r->remove_consumer(op); + } + + for (auto r : inputs) + { + r->consumers.push_back(op); + } + + op->inputs = inputs; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eval_expression.h b/tools/pnnx/src/pass_level5/eval_expression.h new file mode 100644 index 000000000000..149ef82ce797 --- /dev/null +++ b/tools/pnnx/src/pass_level5/eval_expression.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void eval_expression(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp new file mode 100644 index 000000000000..062353945d16 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.cpp @@ -0,0 +1,133 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_channel_shuffle.h" + +#include "pass_level2.h" + +namespace pnnx { + +class fuse_channel_shuffle_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +pnnx.Expression op_0 1 1 input 13 expr=%expr +Tensor.view op_1 2 1 input 13 14 +pnnx.Expression op_2 1 1 input 15 expr=%expr2 +torch.transpose op_3 1 1 14 16 dim0=1 dim1=2 +Tensor.reshape op_4 2 1 16 15 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.ChannelShuffle"; + } + + const char* name_str() const + { + return "channelshuffle"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + const std::string& expr = captured_params.at("expr").s; + const std::string& expr2 = captured_params.at("expr2").s; + + if (expr2 != "[int(size(@0,0)),-1,int(size(@0,2)),int(size(@0,3))]") + return false; + + int groups; + int nscan = sscanf(expr.c_str(), "[int(size(@0,0)),2,int(floor_divide(size(@0,1),%d)),int(size(@0,2)),int(size(@0,3))]", &groups); + if (nscan != 1) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& expr = captured_params.at("expr").s; + + int groups; + sscanf(expr.c_str(), "[int(size(@0,0)),2,int(floor_divide(size(@0,1),%d)),int(size(@0,2)),int(size(@0,3))]", &groups); + + op->params["groups"] = groups; + } +}; + +class fuse_channel_shuffle_pass_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +Tensor.view op_0 1 1 input 13 shape=%shape +torch.transpose op_1 1 1 13 14 dim0=1 dim1=2 +Tensor.reshape op_2 1 1 14 out shape=%shape2 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.ChannelShuffle"; + } + + const char* name_str() const + { + return "channelshuffle"; + } + + bool match_captured_params_attrs(const std::map& captured_params) const + { + // (1,2,58,28,28) + // (1,-1,28,28) + const std::vector& shape = captured_params.at("shape").ai; + const std::vector& shape2 = captured_params.at("shape2").ai; + + if (shape[0] != 1 || shape2[0] != 1 || shape2[1] != -1 || shape2[2] != shape[3] || shape2[3] != shape[4]) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& shape = captured_params.at("shape").ai; + + int groups = shape[1]; + + op->params["groups"] = groups; + } +}; + +void fuse_channel_shuffle(Graph& graph) +{ + fuse_channel_shuffle_pass a; + fuse_channel_shuffle_pass_1 b; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_channel_shuffle.h b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.h new file mode 100644 index 000000000000..3257f9cda975 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_channel_shuffle.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_channel_shuffle(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_constant_expression.cpp b/tools/pnnx/src/pass_level5/fuse_constant_expression.cpp new file mode 100644 index 000000000000..7f04b6458af8 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_constant_expression.cpp @@ -0,0 +1,115 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_constant_expression.h" + +#include +#include +#include "pass_level2.h" + +namespace pnnx { + +void fuse_constant_expression(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "pnnx.Expression") + continue; + + if (op->inputs.size() != 0) + { + // dynamic expression + continue; + } + + Operand* expr_output = op->outputs[0]; + + std::vector new_consumers; + for (auto x : expr_output->consumers) + { + if (x->inputnames.empty()) + { + // x is not a function + new_consumers.push_back(x); + continue; + } + } + + if (new_consumers == expr_output->consumers) + continue; + + matched = true; + + Parameter ep = Parameter::parse_from_string(op->params.at("expr").s); + + for (auto& x : expr_output->consumers) + { + if (x->inputnames.empty()) + { + // x is not a function + continue; + } + + std::vector new_inputs; + std::vector new_inputnames; + for (size_t j = 0; j < x->inputs.size(); j++) + { + if (x->inputs[j] == expr_output) + { + // fuse constant + x->params[x->inputnames[j]] = ep; + } + else + { + new_inputs.push_back(x->inputs[j]); + new_inputnames.push_back(x->inputnames[j]); + } + } + + x->inputs = new_inputs; + x->inputnames = new_inputnames; + } + + expr_output->consumers = new_consumers; + + if (expr_output->consumers.empty()) + { + // delete expression and expr_output + + expr_output->producer = 0; + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), expr_output)); + delete expr_output; + + op->inputs.clear(); + op->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op)); + delete op; + } + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_constant_expression.h b/tools/pnnx/src/pass_level5/fuse_constant_expression.h new file mode 100644 index 000000000000..bb6a5937bd7a --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_constant_expression.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_constant_expression(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_contiguous_view.cpp b/tools/pnnx/src/pass_level5/fuse_contiguous_view.cpp new file mode 100644 index 000000000000..e42a72c92273 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_contiguous_view.cpp @@ -0,0 +1,82 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_contiguous_view.h" + +#include "pass_level2.h" + +namespace pnnx { + +class fuse_contiguous_view_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Tensor.contiguous op_0 1 1 input a memory_format=* +Tensor.view op_1 1 1 a out shape=%shape +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.reshape"; + } + + const char* name_str() const + { + return "view_shape"; + } +}; + +class fuse_contiguous_view_pass_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_1 0 1 input +pnnx.Input input_2 0 1 shape +Tensor.contiguous op_0 1 1 input a memory_format=* +Tensor.view op_1 2 1 a shape out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.reshape"; + } + + const char* name_str() const + { + return "view_shape"; + } +}; + +void fuse_contiguous_view(Graph& graph) +{ + fuse_contiguous_view_pass a; + fuse_contiguous_view_pass_1 b; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_contiguous_view.h b/tools/pnnx/src/pass_level5/fuse_contiguous_view.h new file mode 100644 index 000000000000..33612c867a9c --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_contiguous_view.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_contiguous_view(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.cpp b/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.cpp new file mode 100644 index 000000000000..049b03f36cfa --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.cpp @@ -0,0 +1,136 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_conv2d_batchnorm2d.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_conv2d_batchnorm2d_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.Conv2d op_0 1 1 input a in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias +nn.BatchNorm2d op_1 1 1 a out num_features=%num_features eps=%eps affine=%affine @running_mean @running_var @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Conv2d"; + } + + const char* name_str() const + { + return "convbn2d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["stride"] = captured_params.at("stride"); + op->params["padding"] = captured_params.at("padding"); + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = true; + + // resolve merged conv2d weight and bias + int channels = captured_params.at("num_features").i; + float bn_eps = captured_params.at("eps").f; + bool has_bn_affine = captured_params.at("affine").b; + bool has_conv_bias = captured_params.at("bias").b; + + const float* bn_running_mean = (const float*)captured_attrs.at("op_1.running_mean").data.data(); + const float* bn_running_var = (const float*)captured_attrs.at("op_1.running_var").data.data(); + const float* bn_weight = has_bn_affine ? (const float*)captured_attrs.at("op_1.weight").data.data() : 0; + const float* bn_bias = has_bn_affine ? (const float*)captured_attrs.at("op_1.bias").data.data() : 0; + + // a = bias - slope * mean / sqrt(var + eps) + // b = slope / sqrt(var + eps) + // value = value * b + a + + std::vector a(channels); + std::vector b(channels); + for (int i = 0; i < channels; i++) + { + double sqrt_var = sqrt(bn_running_var[i] + bn_eps); + + if (has_bn_affine) + { + a[i] = bn_bias[i] - bn_weight[i] * bn_running_mean[i] / sqrt_var; + b[i] = bn_weight[i] / sqrt_var; + } + else + { + a[i] = -bn_running_mean[i] / sqrt_var; + b[i] = 1.f / sqrt_var; + } + } + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (has_conv_bias) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + else + { + // init bias as zero + op->attrs["bias"] = Attribute(); + op->attrs["bias"].type = 1; + op->attrs["bias"].shape = {channels}; + + op->attrs["bias"].data.resize(channels * sizeof(float)); + memset(op->attrs["bias"].data.data(), 0, channels * sizeof(float)); + } + + float* conv_weight = (float*)op->attrs["weight"].data.data(); + float* conv_bias = (float*)op->attrs["bias"].data.data(); + + const int outch = captured_params.at("out_channels").i; + const int weight_per_outch = op->attrs["weight"].data.size() / sizeof(float) / outch; + + for (int i = 0; i < channels; i++) + { + float* conv_weight_outch = conv_weight + weight_per_outch * i; + for (int j = 0; j < weight_per_outch; j++) + { + conv_weight_outch[j] *= b[i]; + } + + conv_bias[i] = conv_bias[i] * b[i] + a[i]; + } + } +}; + +void fuse_conv2d_batchnorm2d(Graph& graph) +{ + fuse_conv2d_batchnorm2d_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.h b/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.h new file mode 100644 index 000000000000..829aa5a6da2e --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_conv2d_batchnorm2d.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_conv2d_batchnorm2d(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.cpp b/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.cpp new file mode 100644 index 000000000000..1be27d597024 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.cpp @@ -0,0 +1,154 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_convtranspose2d_batchnorm2d.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_convtranspose2d_batchnorm2d_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.ConvTranspose2d op_0 1 1 input a in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride output_padding=%output_padding padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias +nn.BatchNorm2d op_1 1 1 a out num_features=%num_features eps=%eps affine=%affine @running_mean @running_var @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.ConvTranspose2d"; + } + + const char* name_str() const + { + return "convtransposebn2d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["in_channels"] = captured_params.at("in_channels"); + op->params["out_channels"] = captured_params.at("out_channels"); + op->params["kernel_size"] = captured_params.at("kernel_size"); + op->params["stride"] = captured_params.at("stride"); + op->params["output_padding"] = captured_params.at("output_padding"); + op->params["padding"] = captured_params.at("padding"); + op->params["dilation"] = captured_params.at("dilation"); + op->params["groups"] = captured_params.at("groups"); + op->params["bias"] = true; + + // resolve merged convtranspose2d weight and bias + int channels = captured_params.at("num_features").i; + float bn_eps = captured_params.at("eps").f; + bool has_bn_affine = captured_params.at("affine").b; + bool has_convtranspose_bias = captured_params.at("bias").b; + + const float* bn_running_mean = (const float*)captured_attrs.at("op_1.running_mean").data.data(); + const float* bn_running_var = (const float*)captured_attrs.at("op_1.running_var").data.data(); + const float* bn_weight = has_bn_affine ? (const float*)captured_attrs.at("op_1.weight").data.data() : 0; + const float* bn_bias = has_bn_affine ? (const float*)captured_attrs.at("op_1.bias").data.data() : 0; + + // a = bias - slope * mean / sqrt(var + eps) + // b = slope / sqrt(var + eps) + // value = value * b + a + + std::vector a(channels); + std::vector b(channels); + for (int i = 0; i < channels; i++) + { + double sqrt_var = sqrt(bn_running_var[i] + bn_eps); + + if (has_bn_affine) + { + a[i] = bn_bias[i] - bn_weight[i] * bn_running_mean[i] / sqrt_var; + b[i] = bn_weight[i] / sqrt_var; + } + else + { + a[i] = -bn_running_mean[i] / sqrt_var; + b[i] = 1.f / sqrt_var; + } + } + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (has_convtranspose_bias) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + else + { + // init bias as zero + op->attrs["bias"] = Attribute(); + op->attrs["bias"].type = 1; + op->attrs["bias"].shape = {channels}; + + op->attrs["bias"].data.resize(channels * sizeof(float)); + memset(op->attrs["bias"].data.data(), 0, channels * sizeof(float)); + } + + float* conv_weight = (float*)op->attrs["weight"].data.data(); + float* conv_bias = (float*)op->attrs["bias"].data.data(); + + // group-inch/group-outch/group-kh-kw + const int inch = captured_params.at("in_channels").i; + const int outch = captured_params.at("out_channels").i; + const int groups = captured_params.at("groups").i; + const int kh = captured_params.at("kernel_size").ai[0]; + const int kw = captured_params.at("kernel_size").ai[1]; + + const int outch_g = outch / groups; + const int inch_g = inch / groups; + const int maxk = kh * kw; + + for (int g = 0; g < groups; g++) + { + float* wg = conv_weight + g * inch_g * outch_g * maxk; + for (int i = 0; i < inch_g; i++) + { + for (int j = 0; j < outch_g; j++) + { + for (int k = 0; k < maxk; k++) + { + wg[(i * outch_g + j) * maxk + k] *= b[g * outch_g + j]; + } + } + } + } + + for (int i = 0; i < channels; i++) + { + conv_bias[i] = conv_bias[i] * b[i] + a[i]; + } + } +}; + +void fuse_convtranspose2d_batchnorm2d(Graph& graph) +{ + fuse_convtranspose2d_batchnorm2d_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.h b/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.h new file mode 100644 index 000000000000..854b72a250ff --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_convtranspose2d_batchnorm2d.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_convtranspose2d_batchnorm2d(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.cpp b/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.cpp new file mode 100644 index 000000000000..f014437351a9 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.cpp @@ -0,0 +1,130 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_linear_batchnorm1d.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_linear_batchnorm1d_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +nn.Linear op_0 1 1 input a in_features=%in_features out_features=%out_features bias=%bias @weight @bias +nn.BatchNorm1d op_1 1 1 a out num_features=%num_features eps=%eps affine=%affine @running_mean @running_var @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.Linear"; + } + + const char* name_str() const + { + return "linearbn1d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["in_features"] = captured_params.at("in_features"); + op->params["out_features"] = captured_params.at("out_features"); + op->params["bias"] = true; + + // resolve merged linear weight and bias + int channels = captured_params.at("out_features").i; + float bn_eps = captured_params.at("eps").f; + bool has_bn_affine = captured_params.at("affine").b; + bool has_conv_bias = captured_params.at("bias").b; + + const float* bn_running_mean = (const float*)captured_attrs.at("op_1.running_mean").data.data(); + const float* bn_running_var = (const float*)captured_attrs.at("op_1.running_var").data.data(); + const float* bn_weight = has_bn_affine ? (const float*)captured_attrs.at("op_1.weight").data.data() : 0; + const float* bn_bias = has_bn_affine ? (const float*)captured_attrs.at("op_1.bias").data.data() : 0; + + // a = bias - slope * mean / sqrt(var + eps) + // b = slope / sqrt(var + eps) + // value = value * b + a + + std::vector a(channels); + std::vector b(channels); + for (int i = 0; i < channels; i++) + { + double sqrt_var = sqrt(bn_running_var[i] + bn_eps); + + if (has_bn_affine) + { + a[i] = bn_bias[i] - bn_weight[i] * bn_running_mean[i] / sqrt_var; + b[i] = bn_weight[i] / sqrt_var; + } + else + { + a[i] = -bn_running_mean[i] / sqrt_var; + b[i] = 1.f / sqrt_var; + } + } + + op->attrs["weight"] = captured_attrs.at("op_0.weight"); + + if (has_conv_bias) + { + op->attrs["bias"] = captured_attrs.at("op_0.bias"); + } + else + { + // init bias as zero + op->attrs["bias"] = Attribute(); + op->attrs["bias"].type = 1; + op->attrs["bias"].shape = {channels}; + + op->attrs["bias"].data.resize(channels * sizeof(float)); + memset(op->attrs["bias"].data.data(), 0, channels * sizeof(float)); + } + + float* conv_weight = (float*)op->attrs["weight"].data.data(); + float* conv_bias = (float*)op->attrs["bias"].data.data(); + + const int weight_per_outch = op->params["in_features"].i; + + for (int i = 0; i < channels; i++) + { + float* conv_weight_outch = conv_weight + weight_per_outch * i; + for (int j = 0; j < weight_per_outch; j++) + { + conv_weight_outch[j] *= b[i]; + } + + conv_bias[i] = conv_bias[i] * b[i] + a[i]; + } + } +}; + +void fuse_linear_batchnorm1d(Graph& graph) +{ + fuse_linear_batchnorm1d_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.h b/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.h new file mode 100644 index 000000000000..b04e03332eb3 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_linear_batchnorm1d.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_linear_batchnorm1d(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_slice_indices.cpp b/tools/pnnx/src/pass_level5/fuse_slice_indices.cpp new file mode 100644 index 000000000000..22be6e5b46f8 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_slice_indices.cpp @@ -0,0 +1,323 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_slice_indices.h" + +#include +#include "pass_level2.h" + +namespace pnnx { + +class fuse_slice_indices_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +Tensor.slice op_0 1 1 input a dim=%dim0 end=%end0 start=%start0 step=%step0 +Tensor.slice op_1 1 1 a b dim=%dim1 end=%end1 start=%start1 step=%step1 +Tensor.slice op_2 1 1 b c dim=%dim2 end=%end2 start=%start2 step=%step2 +Tensor.slice op_3 1 1 c d dim=%dim3 end=%end3 start=%start3 step=%step3 +Tensor.slice op_4 1 1 d out dim=%dim4 end=%end4 start=%start4 step=%step4 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.slice"; + } + + const char* name_str() const + { + return "slice"; + } + + bool match(const std::map& captured_params) const + { + int dim0 = captured_params.at("dim0").i; + int dim1 = captured_params.at("dim1").i; + int dim2 = captured_params.at("dim2").i; + int dim3 = captured_params.at("dim3").i; + int dim4 = captured_params.at("dim4").i; + + return dim0 < dim1 && dim1 < dim2 && dim2 < dim3 && dim3 < dim4; + } + + void write(Operator* op, const std::map& captured_params) const + { + int dim0 = captured_params.at("dim0").i; + int dim1 = captured_params.at("dim1").i; + int dim2 = captured_params.at("dim2").i; + int dim3 = captured_params.at("dim3").i; + int dim4 = captured_params.at("dim4").i; + + int start0 = captured_params.at("start0").i; + int start1 = captured_params.at("start1").i; + int start2 = captured_params.at("start2").i; + int start3 = captured_params.at("start3").i; + int start4 = captured_params.at("start4").i; + + int end0 = captured_params.at("end0").i; + int end1 = captured_params.at("end1").i; + int end2 = captured_params.at("end2").i; + int end3 = captured_params.at("end3").i; + int end4 = captured_params.at("end4").i; + + int step0 = captured_params.at("step0").i; + int step1 = captured_params.at("step1").i; + int step2 = captured_params.at("step2").i; + int step3 = captured_params.at("step3").i; + int step4 = captured_params.at("step4").i; + + op->params["dims"] = Parameter{dim0, dim1, dim2, dim3, dim4}; + op->params["starts"] = Parameter{start0, start1, start2, start3, start4}; + op->params["ends"] = Parameter{end0, end1, end2, end3, end4}; + op->params["steps"] = Parameter{step0, step1, step2, step3, step4}; + } +}; + +class fuse_slice_indices_pass_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +Tensor.slice op_0 1 1 input a dim=%dim0 end=%end0 start=%start0 step=%step0 +Tensor.slice op_1 1 1 a b dim=%dim1 end=%end1 start=%start1 step=%step1 +Tensor.slice op_2 1 1 b c dim=%dim2 end=%end2 start=%start2 step=%step2 +Tensor.slice op_3 1 1 c out dim=%dim3 end=%end3 start=%start3 step=%step3 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.slice"; + } + + const char* name_str() const + { + return "slice"; + } + + bool match(const std::map& captured_params) const + { + int dim0 = captured_params.at("dim0").i; + int dim1 = captured_params.at("dim1").i; + int dim2 = captured_params.at("dim2").i; + int dim3 = captured_params.at("dim3").i; + + return dim0 < dim1 && dim1 < dim2 && dim2 < dim3; + } + + void write(Operator* op, const std::map& captured_params) const + { + int dim0 = captured_params.at("dim0").i; + int dim1 = captured_params.at("dim1").i; + int dim2 = captured_params.at("dim2").i; + int dim3 = captured_params.at("dim3").i; + + int start0 = captured_params.at("start0").i; + int start1 = captured_params.at("start1").i; + int start2 = captured_params.at("start2").i; + int start3 = captured_params.at("start3").i; + + int end0 = captured_params.at("end0").i; + int end1 = captured_params.at("end1").i; + int end2 = captured_params.at("end2").i; + int end3 = captured_params.at("end3").i; + + int step0 = captured_params.at("step0").i; + int step1 = captured_params.at("step1").i; + int step2 = captured_params.at("step2").i; + int step3 = captured_params.at("step3").i; + + op->params["dims"] = Parameter{dim0, dim1, dim2, dim3}; + op->params["starts"] = Parameter{start0, start1, start2, start3}; + op->params["ends"] = Parameter{end0, end1, end2, end3}; + op->params["steps"] = Parameter{step0, step1, step2, step3}; + } +}; + +class fuse_slice_indices_pass_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +Tensor.slice op_0 1 1 input a dim=%dim0 end=%end0 start=%start0 step=%step0 +Tensor.slice op_1 1 1 a b dim=%dim1 end=%end1 start=%start1 step=%step1 +Tensor.slice op_2 1 1 b out dim=%dim2 end=%end2 start=%start2 step=%step2 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.slice"; + } + + const char* name_str() const + { + return "slice"; + } + + bool match(const std::map& captured_params) const + { + int dim0 = captured_params.at("dim0").i; + int dim1 = captured_params.at("dim1").i; + int dim2 = captured_params.at("dim2").i; + + return dim0 < dim1 && dim1 < dim2; + } + + void write(Operator* op, const std::map& captured_params) const + { + int dim0 = captured_params.at("dim0").i; + int dim1 = captured_params.at("dim1").i; + int dim2 = captured_params.at("dim2").i; + + int start0 = captured_params.at("start0").i; + int start1 = captured_params.at("start1").i; + int start2 = captured_params.at("start2").i; + + int end0 = captured_params.at("end0").i; + int end1 = captured_params.at("end1").i; + int end2 = captured_params.at("end2").i; + + int step0 = captured_params.at("step0").i; + int step1 = captured_params.at("step1").i; + int step2 = captured_params.at("step2").i; + + op->params["dims"] = Parameter{dim0, dim1, dim2}; + op->params["starts"] = Parameter{start0, start1, start2}; + op->params["ends"] = Parameter{end0, end1, end2}; + op->params["steps"] = Parameter{step0, step1, step2}; + } +}; + +class fuse_slice_indices_pass_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Tensor.slice op_0 1 1 input a dim=%dim0 end=%end0 start=%start0 step=%step0 +Tensor.slice op_1 1 1 a out dim=%dim1 end=%end1 start=%start1 step=%step1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.slice"; + } + + const char* name_str() const + { + return "slice"; + } + + bool match(const std::map& captured_params) const + { + int dim0 = captured_params.at("dim0").i; + int dim1 = captured_params.at("dim1").i; + + return dim0 < dim1; + } + + void write(Operator* op, const std::map& captured_params) const + { + int dim0 = captured_params.at("dim0").i; + int dim1 = captured_params.at("dim1").i; + + int start0 = captured_params.at("start0").i; + int start1 = captured_params.at("start1").i; + + int end0 = captured_params.at("end0").i; + int end1 = captured_params.at("end1").i; + + int step0 = captured_params.at("step0").i; + int step1 = captured_params.at("step1").i; + + op->params["dims"] = Parameter{dim0, dim1}; + op->params["starts"] = Parameter{start0, start1}; + op->params["ends"] = Parameter{end0, end1}; + op->params["steps"] = Parameter{step0, step1}; + } +}; + +class fuse_slice_indices_pass_4 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +Tensor.slice op_0 1 1 input out dim=%dim0 end=%end0 start=%start0 step=%step0 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.slice"; + } + + const char* name_str() const + { + return "slice"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int dim0 = captured_params.at("dim0").i; + int start0 = captured_params.at("start0").i; + int end0 = captured_params.at("end0").i; + int step0 = captured_params.at("step0").i; + + op->params["dims"] = Parameter{dim0}; + op->params["starts"] = Parameter{start0}; + op->params["ends"] = Parameter{end0}; + op->params["steps"] = Parameter{step0}; + } +}; + +void fuse_slice_indices(Graph& graph) +{ + fuse_slice_indices_pass a; + fuse_slice_indices_pass_1 b; + fuse_slice_indices_pass_2 c; + fuse_slice_indices_pass_3 d; + fuse_slice_indices_pass_4 e; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); + pnnx_graph_rewrite(graph, &e, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_slice_indices.h b/tools/pnnx/src/pass_level5/fuse_slice_indices.h new file mode 100644 index 000000000000..ca8dfa32da81 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_slice_indices.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_slice_indices(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/unroll_rnn_op.cpp b/tools/pnnx/src/pass_level5/unroll_rnn_op.cpp new file mode 100644 index 000000000000..c832353be229 --- /dev/null +++ b/tools/pnnx/src/pass_level5/unroll_rnn_op.cpp @@ -0,0 +1,242 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "unroll_rnn_op.h" + +#include + +namespace pnnx { + +void unroll_rnn_op(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "nn.RNN" && op->type != "nn.LSTM" && op->type != "nn.GRU") + continue; + + int num_layers = op->params["num_layers"].i; + if (num_layers == 1) + continue; + + matched = true; + + bool has_input_hidden = op->inputs.size() >= 2; + bool has_input_cell = op->inputs.size() == 3; + bool has_output_hidden = op->outputs.size() >= 2; + bool has_output_cell = op->outputs.size() == 3; + const int hidden_size = op->params["hidden_size"].i; + bool has_bias = op->params["bias"].b; + bool is_bidirectional = op->params["bidirectional"].b; + + std::vector input_hiddens(num_layers); + std::vector input_cells(num_layers); + std::vector output_hiddens(num_layers); + std::vector output_cells(num_layers); + + // slice input hidden cell + if (has_input_hidden) + { + std::string opname = op->name + "_chunk_in_hidden"; + + Operator* op1 = graph.new_operator_before("torch.chunk", opname, op); + + op1->params["chunks"] = num_layers; + op1->params["dim"] = 0; + + op1->inputs.push_back(op->inputs[1]); + op->inputs[1]->remove_consumer(op); + op->inputs[1]->consumers.push_back(op1); + + for (int j = 0; j < num_layers; j++) + { + Operand* r0 = graph.new_operand(op1->name + "_in_hidden_" + std::to_string(j)); + r0->producer = op1; + op1->outputs.push_back(r0); + + input_hiddens[j] = r0; + } + } + if (has_input_cell) + { + std::string opname = op->name + "_chunk_in_cell"; + + Operator* op1 = graph.new_operator_before("torch.chunk", opname, op); + + op1->params["chunks"] = num_layers; + op1->params["dim"] = 0; + + op1->inputs.push_back(op->inputs[2]); + op->inputs[2]->remove_consumer(op); + op->inputs[2]->consumers.push_back(op1); + + for (int j = 0; j < num_layers; j++) + { + Operand* r0 = graph.new_operand(op1->name + "_in_cell_" + std::to_string(j)); + r0->producer = op1; + op1->outputs.push_back(r0); + + input_cells[j] = r0; + } + } + + // unroll + std::vector unrolled_ops(num_layers); + for (int j = 0; j < num_layers; j++) + { + std::string opname = op->name + "_unroll_" + std::to_string(j); + + Operator* op1 = graph.new_operator_before(op->type, opname, op); + + op1->params = op->params; + op1->params["num_layers"] = 1; + + // link + if (j == 0) + { + op1->inputs.push_back(op->inputs[0]); + op1->inputs[0]->remove_consumer(op); + op1->inputs[0]->consumers.push_back(op1); + } + else + { + op1->params["input_size"] = is_bidirectional ? hidden_size * 2 : hidden_size; + + op1->inputs.push_back(unrolled_ops[j - 1]->outputs[0]); + op1->inputs[0]->consumers.push_back(op1); + } + + if (has_input_hidden) + { + op1->inputs.push_back(input_hiddens[j]); + op1->inputs[1]->consumers.push_back(op1); + } + if (has_input_cell) + { + op1->inputs.push_back(input_cells[j]); + op1->inputs[2]->consumers.push_back(op1); + } + + if (j == num_layers - 1) + { + op1->outputs.push_back(op->outputs[0]); + op1->outputs[0]->producer = op1; + } + else + { + Operand* r0 = graph.new_operand(op1->name + "_out"); + r0->producer = op1; + op1->outputs.push_back(r0); + } + + if (has_output_hidden) + { + Operand* r1 = graph.new_operand(op1->name + "_out_hidden"); + r1->producer = op1; + op1->outputs.push_back(r1); + + output_hiddens[j] = r1; + } + if (has_output_cell) + { + Operand* r1 = graph.new_operand(op1->name + "_out_cell"); + r1->producer = op1; + op1->outputs.push_back(r1); + + output_cells[j] = r1; + } + + op1->attrs["weight_hh_l0"] = op->attrs["weight_hh_l" + std::to_string(j)]; + op1->attrs["weight_ih_l0"] = op->attrs["weight_ih_l" + std::to_string(j)]; + + if (has_bias) + { + op1->attrs["bias_hh_l0"] = op->attrs["bias_hh_l" + std::to_string(j)]; + op1->attrs["bias_ih_l0"] = op->attrs["bias_ih_l" + std::to_string(j)]; + } + + if (is_bidirectional) + { + op1->attrs["weight_hh_l0_reverse"] = op->attrs["weight_hh_l" + std::to_string(j) + "_reverse"]; + op1->attrs["weight_ih_l0_reverse"] = op->attrs["weight_ih_l" + std::to_string(j) + "_reverse"]; + + if (has_bias) + { + op1->attrs["bias_hh_l0_reverse"] = op->attrs["bias_hh_l" + std::to_string(j) + "_reverse"]; + op1->attrs["bias_ih_l0_reverse"] = op->attrs["bias_ih_l" + std::to_string(j) + "_reverse"]; + } + } + + unrolled_ops[j] = op1; + } + + // concat output hidden cell + if (has_output_hidden) + { + std::string opname = op->name + "_cat_out_hidden"; + + Operator* op1 = graph.new_operator_before("torch.cat", opname, op); + + op1->params["dim"] = 0; + + for (int j = 0; j < num_layers; j++) + { + Operand* r0 = output_hiddens[j]; + r0->consumers.push_back(op1); + op1->inputs.push_back(r0); + } + + op1->outputs.push_back(op->outputs[1]); + op1->outputs[0]->producer = op1; + } + if (has_output_cell) + { + std::string opname = op->name + "_cat_out_cell"; + + Operator* op1 = graph.new_operator_before("torch.cat", opname, op); + + op1->params["dim"] = 0; + + for (int j = 0; j < num_layers; j++) + { + Operand* r0 = output_cells[j]; + r0->consumers.push_back(op1); + op1->inputs.push_back(r0); + } + + op1->outputs.push_back(op->outputs[2]); + op1->outputs[0]->producer = op1; + } + + op->inputs.clear(); + op->outputs.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op)); + + delete op; + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/unroll_rnn_op.h b/tools/pnnx/src/pass_level5/unroll_rnn_op.h new file mode 100644 index 000000000000..a3d57a84f041 --- /dev/null +++ b/tools/pnnx/src/pass_level5/unroll_rnn_op.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level5.h" + +namespace pnnx { + +void unroll_rnn_op(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn.cpp b/tools/pnnx/src/pass_ncnn.cpp new file mode 100644 index 000000000000..2dffc2c2e07d --- /dev/null +++ b/tools/pnnx/src/pass_ncnn.cpp @@ -0,0 +1,109 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +#include "pass_ncnn/convert_attribute.h" +#include "pass_ncnn/convert_custom_op.h" +#include "pass_ncnn/convert_input.h" +#include "pass_ncnn/convert_torch_cat.h" +#include "pass_ncnn/convert_torch_chunk.h" +#include "pass_ncnn/convert_torch_split.h" +#include "pass_ncnn/eliminate_output.h" +#include "pass_ncnn/expand_expression.h" +#include "pass_ncnn/insert_split.h" +#include "pass_ncnn/chain_multi_output.h" +#include "pass_ncnn/solve_batch_index.h" + +#include "pass_ncnn/eliminate_noop.h" +#include "pass_ncnn/fuse_convolution_activation.h" +#include "pass_ncnn/fuse_convolution1d_activation.h" +#include "pass_ncnn/fuse_convolutiondepthwise_activation.h" +#include "pass_ncnn/fuse_convolutiondepthwise1d_activation.h" +#include "pass_ncnn/fuse_deconvolution_activation.h" +#include "pass_ncnn/fuse_deconvolutiondepthwise_activation.h" +#include "pass_ncnn/fuse_innerproduct_activation.h" + +#include "pass_level4/dead_code_elimination.h" +#include "pass_level4/canonicalize.h" +#include "pass_level5/unroll_rnn_op.h" + +namespace pnnx { + +static std::map > g_global_pnnx_ncnn_graph_rewriter_passes; + +NcnnGraphRewriterPassRegister::NcnnGraphRewriterPassRegister(const GraphRewriterPass* _pass, int priority) + : pass(_pass) +{ + if (g_global_pnnx_ncnn_graph_rewriter_passes.find(priority) == g_global_pnnx_ncnn_graph_rewriter_passes.end()) + { + g_global_pnnx_ncnn_graph_rewriter_passes[priority] = std::vector(); + } + + g_global_pnnx_ncnn_graph_rewriter_passes[priority].push_back(pass); +} + +NcnnGraphRewriterPassRegister::~NcnnGraphRewriterPassRegister() +{ + delete pass; +} + +void pass_ncnn(Graph& g) +{ + unroll_rnn_op(g); + + ncnn::expand_expression(g); + + ncnn::chain_multi_output(g); + + ncnn::solve_batch_index(g); + + int opindex = 0; + for (auto x : g_global_pnnx_ncnn_graph_rewriter_passes) + { + for (auto rewriter : x.second) + { + pnnx_graph_rewrite(g, rewriter, opindex); + } + } + + ncnn::convert_torch_cat(g); + ncnn::convert_torch_chunk(g); + ncnn::convert_torch_split(g); + + ncnn::insert_split(g); + + ncnn::eliminate_noop(g); + ncnn::fuse_convolution_activation(g); + ncnn::fuse_convolution1d_activation(g); + ncnn::fuse_convolutiondepthwise_activation(g); + ncnn::fuse_convolutiondepthwise1d_activation(g); + ncnn::fuse_deconvolution_activation(g); + ncnn::fuse_deconvolutiondepthwise_activation(g); + ncnn::fuse_innerproduct_activation(g); + + dead_code_elimination(g); + + canonicalize(g); + + ncnn::convert_custom_op(g); + + ncnn::convert_attribute(g); + + ncnn::convert_input(g); + + ncnn::eliminate_output(g); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn.h b/tools/pnnx/src/pass_ncnn.h new file mode 100644 index 000000000000..61510cf2e24f --- /dev/null +++ b/tools/pnnx/src/pass_ncnn.h @@ -0,0 +1,39 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_PASS_NCNN_H +#define PNNX_PASS_NCNN_H + +#include "ir.h" + +#include "pass_level2.h" + +namespace pnnx { + +class NcnnGraphRewriterPassRegister +{ +public: + NcnnGraphRewriterPassRegister(const GraphRewriterPass* pass, int priority); + ~NcnnGraphRewriterPassRegister(); + const GraphRewriterPass* pass; +}; + +#define REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(CLASS, PRIORITY) \ + static NcnnGraphRewriterPassRegister g_global_pnnx_ncnngraphrewriterpass_##CLASS##_register(new CLASS, PRIORITY); + +void pass_ncnn(Graph& g); + +} // namespace pnnx + +#endif // PNNX_PASS_NCNN_H diff --git a/tools/pnnx/src/pass_ncnn/F_adaptive_avg_pool2d.cpp b/tools/pnnx/src/pass_ncnn/F_adaptive_avg_pool2d.cpp new file mode 100644 index 000000000000..9f4737e974b5 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_adaptive_avg_pool2d.cpp @@ -0,0 +1,89 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_adaptive_avg_pool2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.adaptive_avg_pool2d op_0 1 1 input out output_size=(1,1) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Pooling"; + } + + const char* name_str() const + { + return "gap"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 1; + op->params["4"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_adaptive_avg_pool2d, 20) + +class F_adaptive_avg_pool2d_n : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.adaptive_avg_pool2d op_0 1 1 input out output_size=%output_size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Pooling"; + } + + const char* name_str() const + { + return "aap"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 1; + op->params["7"] = 1; + op->params["8"] = captured_params.at("output_size").ai[1]; + op->params["18"] = captured_params.at("output_size").ai[0]; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_adaptive_avg_pool2d_n, 21) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_adaptive_max_pool2d.cpp b/tools/pnnx/src/pass_ncnn/F_adaptive_max_pool2d.cpp new file mode 100644 index 000000000000..d514ed673939 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_adaptive_max_pool2d.cpp @@ -0,0 +1,89 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_adaptive_max_pool2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.adaptive_max_pool2d op_0 1 1 input out output_size=(1,1) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Pooling"; + } + + const char* name_str() const + { + return "gmp"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 0; + op->params["4"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_adaptive_max_pool2d, 20) + +class F_adaptive_max_pool2d_n : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.adaptive_max_pool2d op_0 1 1 input out output_size=%output_size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Pooling"; + } + + const char* name_str() const + { + return "amp"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 0; + op->params["7"] = 1; + op->params["8"] = captured_params.at("output_size").ai[1]; + op->params["18"] = captured_params.at("output_size").ai[0]; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_adaptive_max_pool2d_n, 21) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_batch_norm.cpp b/tools/pnnx/src/pass_ncnn/F_batch_norm.cpp new file mode 100644 index 000000000000..90fd2c661673 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_batch_norm.cpp @@ -0,0 +1,131 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_batch_norm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_mean 0 1 running_mean @qwq +pnnx.Attribute op_var 0 1 running_var @qwq +F.batch_norm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "BatchNorm"; + } + + const char* name_str() const + { + return "bn"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute running_mean; + Attribute running_var; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 8) == "op_mean.") + running_mean = x.second; + if (x.first.substr(0, 7) == "op_var.") + running_var = x.second; + } + + op->params["0"] = running_mean.shape[0]; + op->params["1"] = captured_params.at("eps"); + + const int channels = running_mean.shape[0]; + + op->attrs["0"] = Attribute({channels}, std::vector(channels, 1.f)); + op->attrs["1"] = running_mean; + op->attrs["2"] = running_var; + op->attrs["3"] = Attribute({channels}, std::vector(channels, 0.f)); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_batch_norm, 20) + +class F_batch_norm_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +pnnx.Attribute op_mean 0 1 running_mean @qwq +pnnx.Attribute op_var 0 1 running_var @qwq +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "BatchNorm"; + } + + const char* name_str() const + { + return "bn"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute running_mean; + Attribute running_var; + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 8) == "op_mean.") + running_mean = x.second; + if (x.first.substr(0, 7) == "op_var.") + running_var = x.second; + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["0"] = running_mean.shape[0]; + op->params["1"] = captured_params.at("eps"); + + op->attrs["0"] = weight; + op->attrs["1"] = running_mean; + op->attrs["2"] = running_var; + op->attrs["3"] = bias; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_batch_norm_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_conv2d.cpp b/tools/pnnx/src/pass_ncnn/F_conv2d.cpp new file mode 100644 index 000000000000..bb1c05c4ca53 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_conv2d.cpp @@ -0,0 +1,247 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_conv2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +F.conv2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution"; + } + + const char* name_str() const + { + return "conv2d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + } + + op->params["0"] = weight.shape[0]; + op->params["1"] = weight.shape[2]; + op->params["11"] = weight.shape[3]; + op->params["2"] = captured_params.at("dilation").ai[1]; + op->params["12"] = captured_params.at("dilation").ai[0]; + op->params["3"] = captured_params.at("stride").ai[1]; + op->params["13"] = captured_params.at("stride").ai[0]; + op->params["4"] = captured_params.at("padding").ai[1]; + op->params["14"] = captured_params.at("padding").ai[0]; + op->params["5"] = 0; + op->params["6"] = (int)(weight.data.size() / sizeof(float)); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = weight; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d, 20) + +class F_conv2d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.conv2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution"; + } + + const char* name_str() const + { + return "conv2d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["0"] = weight.shape[0]; + op->params["1"] = weight.shape[2]; + op->params["11"] = weight.shape[3]; + op->params["2"] = captured_params.at("dilation").ai[1]; + op->params["12"] = captured_params.at("dilation").ai[0]; + op->params["3"] = captured_params.at("stride").ai[1]; + op->params["13"] = captured_params.at("stride").ai[0]; + op->params["4"] = captured_params.at("padding").ai[1]; + op->params["14"] = captured_params.at("padding").ai[0]; + op->params["5"] = 1; + op->params["6"] = (int)(weight.data.size() / sizeof(float)); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = weight; + op->attrs["2"] = bias; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d_1, 20) + +class F_conv2d_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +F.conv2d op_0 2 1 input weight out bias=None stride=%stride padding=%padding dilation=%dilation groups=%groups +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ConvolutionDepthWise"; + } + + const char* name_str() const + { + return "convdw2d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + } + + op->params["0"] = weight.shape[0]; + op->params["1"] = weight.shape[2]; + op->params["11"] = weight.shape[3]; + op->params["2"] = captured_params.at("dilation").ai[1]; + op->params["12"] = captured_params.at("dilation").ai[0]; + op->params["3"] = captured_params.at("stride").ai[1]; + op->params["13"] = captured_params.at("stride").ai[0]; + op->params["4"] = captured_params.at("padding").ai[1]; + op->params["14"] = captured_params.at("padding").ai[0]; + op->params["5"] = 0; + op->params["6"] = (int)(weight.data.size() / sizeof(float)); + op->params["7"] = captured_params.at("groups"); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = weight; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d_2, 21) + +class F_conv2d_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.conv2d op_0 3 1 input weight bias out stride=%stride padding=%padding dilation=%dilation groups=%groups +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution"; + } + + const char* name_str() const + { + return "conv2d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["0"] = weight.shape[0]; + op->params["1"] = weight.shape[2]; + op->params["11"] = weight.shape[3]; + op->params["2"] = captured_params.at("dilation").ai[1]; + op->params["12"] = captured_params.at("dilation").ai[0]; + op->params["3"] = captured_params.at("stride").ai[1]; + op->params["13"] = captured_params.at("stride").ai[0]; + op->params["4"] = captured_params.at("padding").ai[1]; + op->params["14"] = captured_params.at("padding").ai[0]; + op->params["5"] = 1; + op->params["6"] = (int)(weight.data.size() / sizeof(float)); + op->params["7"] = captured_params.at("groups"); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = weight; + op->attrs["2"] = bias; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_conv2d_3, 21) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_elu.cpp b/tools/pnnx/src/pass_ncnn/F_elu.cpp new file mode 100644 index 000000000000..fc55e80cc7c7 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_elu.cpp @@ -0,0 +1,54 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_elu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.elu op_0 1 1 input out alpha=%alpha +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ELU"; + } + + const char* name_str() const + { + return "elu"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = captured_params.at("alpha"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_elu, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_gelu.cpp b/tools/pnnx/src/pass_ncnn/F_gelu.cpp new file mode 100644 index 000000000000..74249bf38314 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_gelu.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_gelu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.gelu op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "GELU"; + } + + const char* name_str() const + { + return "gelu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_gelu, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_group_norm.cpp b/tools/pnnx/src/pass_ncnn/F_group_norm.cpp new file mode 100644 index 000000000000..7aecbf238558 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_group_norm.cpp @@ -0,0 +1,114 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_group_norm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.group_norm op_0 1 1 input out weight=None bias=None num_groups=%num_groups eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "GroupNorm"; + } + + const char* name_str() const + { + return "gn"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int input_rank = op->inputs[0]->shape.size(); + + if (input_rank <= 2) + { + fprintf(stderr, "group_norm not possible for %d-rank tensor\n", input_rank); + return; + } + + op->params["0"] = captured_params.at("num_groups"); + op->params["1"] = op->inputs[0]->shape[1]; + op->params["2"] = captured_params.at("eps"); + op->params["3"] = 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_group_norm, 20) + +class F_group_norm_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.group_norm op_0 3 1 input weight bias out num_groups=%num_groups eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "GroupNorm"; + } + + const char* name_str() const + { + return "gn"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["0"] = captured_params.at("num_groups"); + op->params["1"] = weight.shape[0]; + op->params["2"] = captured_params.at("eps"); + op->params["3"] = 1; + + op->attrs["0"] = weight; + op->attrs["1"] = bias; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_group_norm_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_hardsigmoid.cpp b/tools/pnnx/src/pass_ncnn/F_hardsigmoid.cpp new file mode 100644 index 000000000000..5d0bf0f54697 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_hardsigmoid.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_hardsigmoid : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.hardsigmoid op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "HardSigmoid"; + } + + const char* name_str() const + { + return "hsigmoid"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 1.f / 6; + op->params["1"] = 0.5f; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_hardsigmoid, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_hardswish.cpp b/tools/pnnx/src/pass_ncnn/F_hardswish.cpp new file mode 100644 index 000000000000..b1294136ed49 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_hardswish.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_hardswish : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.hardswish op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "HardSwish"; + } + + const char* name_str() const + { + return "hswish"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 1.f / 6; + op->params["1"] = 0.5f; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_hardswish, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_hardtanh.cpp b/tools/pnnx/src/pass_ncnn/F_hardtanh.cpp new file mode 100644 index 000000000000..585b2e7ed181 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_hardtanh.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_hardtanh : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.hardtanh op_0 1 1 input out min_val=%0 max_val=%1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Clip"; + } + + const char* name_str() const + { + return "htanh"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_hardtanh, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_instance_norm.cpp b/tools/pnnx/src/pass_ncnn/F_instance_norm.cpp new file mode 100644 index 000000000000..003b49bc756c --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_instance_norm.cpp @@ -0,0 +1,112 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_instance_norm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.instance_norm op_0 1 1 input out weight=None bias=None running_mean=None running_var=None eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InstanceNorm"; + } + + const char* name_str() const + { + return "in"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int input_rank = op->inputs[0]->shape.size(); + + if (input_rank <= 2) + { + fprintf(stderr, "instance_norm not possible for %d-rank tensor\n", input_rank); + return; + } + + op->params["0"] = op->inputs[0]->shape[1]; + op->params["1"] = captured_params.at("eps"); + op->params["2"] = 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_instance_norm, 20) + +class F_instance_norm_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.instance_norm op_0 3 1 input weight bias out running_mean=None running_var=None eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InstanceNorm"; + } + + const char* name_str() const + { + return "in"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["0"] = weight.shape[0]; + op->params["1"] = captured_params.at("eps"); + op->params["2"] = 1; + + op->attrs["0"] = weight; + op->attrs["1"] = bias; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_instance_norm_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_interpolate.cpp b/tools/pnnx/src/pass_ncnn/F_interpolate.cpp new file mode 100644 index 000000000000..85c3257c662f --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_interpolate.cpp @@ -0,0 +1,151 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_interpolate : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.interpolate op_0 1 1 input out mode=%mode recompute_scale_factor=* scale_factor=%scale_factor +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& mode = captured_params.at("mode").s; + std::vector scale_factor; + if (captured_params.at("scale_factor").type == 3) + { + scale_factor.push_back(captured_params.at("scale_factor").f); + } + else + { + scale_factor = captured_params.at("scale_factor").af; + } + + if (mode == "nearest") + op->params["0"] = 1; + if (mode == "bilinear" || mode == "linear") + op->params["0"] = 2; + if (mode == "bicubic") + op->params["0"] = 3; + + if (scale_factor.size() == 1) + { + op->params["1"] = 1.f; + op->params["2"] = scale_factor[0]; + } + else if (scale_factor.size() == 2) + { + op->params["1"] = scale_factor[0]; + op->params["2"] = scale_factor[1]; + } + else + { + fprintf(stderr, "unsupported interpolate scale_factor\n"); + } + + op->params["6"] = 0; // align_corners + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_interpolate, 20) + +class F_interpolate_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.interpolate op_0 1 1 input out align_corners=%align_corners mode=%mode recompute_scale_factor=* scale_factor=%scale_factor +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "interpolate"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& mode = captured_params.at("mode").s; + std::vector scale_factor; + if (captured_params.at("scale_factor").type == 3) + { + scale_factor.push_back(captured_params.at("scale_factor").f); + } + else + { + scale_factor = captured_params.at("scale_factor").af; + } + + if (mode == "nearest") + op->params["0"] = 1; + if (mode == "bilinear" || mode == "linear") + op->params["0"] = 2; + if (mode == "bicubic") + op->params["0"] = 3; + + if (scale_factor.size() == 1) + { + op->params["1"] = 1.f; + op->params["2"] = scale_factor[0]; + } + else if (scale_factor.size() == 2) + { + op->params["1"] = scale_factor[0]; + op->params["2"] = scale_factor[1]; + } + else + { + fprintf(stderr, "unsupported interpolate scale_factor\n"); + } + + op->params["6"] = captured_params.at("align_corners").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_interpolate_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_layer_norm.cpp b/tools/pnnx/src/pass_ncnn/F_layer_norm.cpp new file mode 100644 index 000000000000..4ae1c5061c9c --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_layer_norm.cpp @@ -0,0 +1,118 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_layer_norm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.layer_norm op_0 1 1 input out weight=None bias=None normalized_shape=%normalized_shape eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "LayerNorm"; + } + + const char* name_str() const + { + return "ln"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& normalized_shape = captured_params.at("normalized_shape").ai; + int affine_size = normalized_shape[0]; + for (size_t i = 1; i < normalized_shape.size(); i++) + { + affine_size *= normalized_shape[i]; + } + + op->params["0"] = affine_size; + op->params["1"] = captured_params.at("eps"); + op->params["2"] = 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_layer_norm, 20) + +class F_layer_norm_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.layer_norm op_0 3 1 input weight bias out normalized_shape=%normalized_shape eps=%eps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "LayerNorm"; + } + + const char* name_str() const + { + return "ln"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + const std::vector& normalized_shape = captured_params.at("normalized_shape").ai; + int affine_size = normalized_shape[0]; + for (size_t i = 1; i < normalized_shape.size(); i++) + { + affine_size *= normalized_shape[i]; + } + + op->params["0"] = affine_size; + op->params["1"] = captured_params.at("eps"); + op->params["2"] = 1; + + op->attrs["0"] = weight; + op->attrs["1"] = bias; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_layer_norm_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_leaky_relu.cpp b/tools/pnnx/src/pass_ncnn/F_leaky_relu.cpp new file mode 100644 index 000000000000..f18a246eef8f --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_leaky_relu.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_leaky_relu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.leaky_relu op_0 1 1 input out negative_slope=%0 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ReLU"; + } + + const char* name_str() const + { + return "leakyrelu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_leaky_relu, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_linear.cpp b/tools/pnnx/src/pass_ncnn/F_linear.cpp new file mode 100644 index 000000000000..b76c444e4b63 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_linear.cpp @@ -0,0 +1,118 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_linear : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +F.linear op_0 2 1 input weight out bias=None +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InnerProduct"; + } + + const char* name_str() const + { + return "linear"; + } + + void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + Attribute weight; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + } + + op->params["0"] = weight.shape[0]; + op->params["1"] = 0; + op->params["2"] = (int)(weight.data.size() / sizeof(float)); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = weight; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_linear, 20) + +class F_linear_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +pnnx.Attribute op_bias 0 1 bias @qwq +F.linear op_0 3 1 input weight bias out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InnerProduct"; + } + + const char* name_str() const + { + return "linear"; + } + + void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + Attribute weight; + Attribute bias; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + if (x.first.substr(0, 8) == "op_bias.") + bias = x.second; + } + + op->params["0"] = weight.shape[0]; + op->params["1"] = 1; + op->params["2"] = (int)(weight.data.size() / sizeof(float)); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = weight; + op->attrs["2"] = bias; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_linear_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_local_response_norm.cpp b/tools/pnnx/src/pass_ncnn/F_local_response_norm.cpp new file mode 100644 index 000000000000..e4ee419b4dbb --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_local_response_norm.cpp @@ -0,0 +1,58 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_local_response_norm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.local_response_norm op_0 1 1 input out size=%size alpha=%alpha beta=%beta k=%k +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "LRN"; + } + + const char* name_str() const + { + return "lrn"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 0; // region_type ACROSS_CHANNELS + op->params["1"] = captured_params.at("size"); + op->params["2"] = captured_params.at("alpha"); + op->params["3"] = captured_params.at("beta"); + op->params["4"] = captured_params.at("k"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_local_response_norm, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_mish.cpp b/tools/pnnx/src/pass_ncnn/F_mish.cpp new file mode 100644 index 000000000000..93f3fe3896df --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_mish.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_mish : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.mish op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Mish"; + } + + const char* name_str() const + { + return "mish"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_mish, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_normalize.cpp b/tools/pnnx/src/pass_ncnn/F_normalize.cpp new file mode 100644 index 000000000000..60aa0b51fceb --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_normalize.cpp @@ -0,0 +1,71 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_normalize : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.normalize op_0 1 1 input out dim=%dim eps=%eps p=%p +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Normalize"; + } + + const char* name_str() const + { + return "norm"; + } + + void write(Operator* op, const std::map& captured_params) const + { + float p = 0.f; + if (captured_params.at("p").type == 2) + p = captured_params.at("p").i; + if (captured_params.at("p").type == 3) + p = captured_params.at("p").f; + + if (p != 2.f) + { + fprintf(stderr, "unsupported normalize p=%f\n", p); + return; + } + + op->params["1"] = 1; // channel_shared + op->params["2"] = captured_params.at("eps"); + op->params["3"] = 1; // scale_data_size + op->params["9"] = 1; // eps_mode + + op->attrs["0"] = Attribute({1}, std::vector(1, 1.f)); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_normalize, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_pad.cpp b/tools/pnnx/src/pass_ncnn/F_pad.cpp new file mode 100644 index 000000000000..0272fa02f11d --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_pad.cpp @@ -0,0 +1,146 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_pad : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.pad op_0 1 1 input out pad=%pad mode=constant value=%value +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Padding"; + } + + const char* name_str() const + { + return "pad"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& pad = captured_params.at("pad").ai; + + float pad_value = 0.f; + if (captured_params.at("value").type == 2) + pad_value = captured_params.at("value").i; + if (captured_params.at("value").type == 3) + pad_value = captured_params.at("value").f; + + if (pad.size() == 2) + { + op->params["0"] = 0; + op->params["1"] = 0; + op->params["2"] = pad[0]; + op->params["3"] = pad[1]; + } + else if (pad.size() >= 4) + { + op->params["0"] = pad[2]; + op->params["1"] = pad[3]; + op->params["2"] = pad[0]; + op->params["3"] = pad[1]; + } + if (pad.size() >= 6) + { + op->params["7"] = pad[4]; + op->params["8"] = pad[5]; + } + + op->params["4"] = 0; // constant + op->params["5"] = pad_value; + op->params["6"] = 0; // per_channel_pad_data_size + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_pad, 20) + +class F_pad_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.pad op_0 1 1 input out pad=%pad mode=%mode +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Padding"; + } + + const char* name_str() const + { + return "pad"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& pad = captured_params.at("pad").ai; + const std::string& mode = captured_params.at("mode").s; + + if (pad.size() == 2) + { + op->params["0"] = 0; + op->params["1"] = 0; + op->params["2"] = pad[0]; + op->params["3"] = pad[1]; + } + else if (pad.size() >= 4) + { + op->params["0"] = pad[2]; + op->params["1"] = pad[3]; + op->params["2"] = pad[0]; + op->params["3"] = pad[1]; + } + if (pad.size() >= 6) + { + op->params["7"] = pad[4]; + op->params["8"] = pad[5]; + } + + if (mode == "constant") + op->params["4"] = 0; + if (mode == "reflect") + op->params["4"] = 2; + if (mode == "replicate") + op->params["4"] = 1; + + op->params["5"] = 0; // value + op->params["6"] = 0; // per_channel_pad_data_size + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_pad_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_pixel_shuffle.cpp b/tools/pnnx/src/pass_ncnn/F_pixel_shuffle.cpp new file mode 100644 index 000000000000..0cedb92891c4 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_pixel_shuffle.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_pixel_shuffle : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.pixel_shuffle op_0 1 1 input out upscale_factor=%upscale_factor +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "PixelShuffle"; + } + + const char* name_str() const + { + return "pixel_shuffle"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = captured_params.at("upscale_factor"); + op->params["1"] = 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_pixel_shuffle, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_pixel_unshuffle.cpp b/tools/pnnx/src/pass_ncnn/F_pixel_unshuffle.cpp new file mode 100644 index 000000000000..5791ba1a8544 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_pixel_unshuffle.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_pixel_unshuffle : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.pixel_unshuffle op_0 1 1 input out downscale_factor=%downscale_factor +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reorg"; + } + + const char* name_str() const + { + return "pixelunshuffle"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = captured_params.at("downscale_factor"); + op->params["1"] = 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_pixel_unshuffle, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_prelu.cpp b/tools/pnnx/src/pass_ncnn/F_prelu.cpp new file mode 100644 index 000000000000..14ae609224e3 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_prelu.cpp @@ -0,0 +1,64 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_prelu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +F.prelu op_0 2 1 input weight out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "PReLU"; + } + + const char* name_str() const + { + return "prelu"; + } + + void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + Attribute weight; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + } + + op->params["0"] = weight.shape[0]; + + op->attrs["0"] = weight; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_prelu, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_relu.cpp b/tools/pnnx/src/pass_ncnn/F_relu.cpp new file mode 100644 index 000000000000..3a9794d570fa --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_relu.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_relu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.relu op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ReLU"; + } + + const char* name_str() const + { + return "relu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_relu, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_relu6.cpp b/tools/pnnx/src/pass_ncnn/F_relu6.cpp new file mode 100644 index 000000000000..6ebd658a05e2 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_relu6.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_relu6 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.relu6 op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Clip"; + } + + const char* name_str() const + { + return "relu6"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 0.f; + op->params["1"] = 6.f; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_relu6, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_selu.cpp b/tools/pnnx/src/pass_ncnn/F_selu.cpp new file mode 100644 index 000000000000..eb64b798268a --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_selu.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_selu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.selu op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "SELU"; + } + + const char* name_str() const + { + return "selu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_selu, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_sigmoid.cpp b/tools/pnnx/src/pass_ncnn/F_sigmoid.cpp new file mode 100644 index 000000000000..c6e78504669d --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_sigmoid.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_sigmoid : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.sigmoid op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Sigmoid"; + } + + const char* name_str() const + { + return "sigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_sigmoid, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_silu.cpp b/tools/pnnx/src/pass_ncnn/F_silu.cpp new file mode 100644 index 000000000000..ccad64027788 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_silu.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_silu : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.silu op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Swish"; + } + + const char* name_str() const + { + return "silu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_silu, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_softmax.cpp b/tools/pnnx/src/pass_ncnn/F_softmax.cpp new file mode 100644 index 000000000000..e85ad2f0f56a --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_softmax.cpp @@ -0,0 +1,56 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_softmax : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.softmax op_0 1 1 input out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Softmax"; + } + + const char* name_str() const + { + return "softmax"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int axis = captured_params.at("dim").i; + op->params["0"] = axis > 0 ? axis - 1 : axis; + op->params["1"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_softmax, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_tanh.cpp b/tools/pnnx/src/pass_ncnn/F_tanh.cpp new file mode 100644 index 000000000000..cb15285684cd --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_tanh.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_tanh : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.tanh op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "TanH"; + } + + const char* name_str() const + { + return "tanh"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_tanh, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_upsample.cpp b/tools/pnnx/src/pass_ncnn/F_upsample.cpp new file mode 100644 index 000000000000..6c61aa894bdb --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_upsample.cpp @@ -0,0 +1,247 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_upsample : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.upsample op_0 1 1 input out align_corners=%align_corners mode=%mode size=%size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& mode = captured_params.at("mode").s; + const std::vector& size = captured_params.at("size").ai; + + if (mode == "nearest") + op->params["0"] = 1; + if (mode == "bilinear" || mode == "linear") + op->params["0"] = 2; + if (mode == "bicubic") + op->params["0"] = 3; + + if (size.size() == 1) + { + op->params["3"] = 1; + op->params["4"] = size[0]; + } + else if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample size\n"); + } + + op->params["6"] = captured_params.at("align_corners").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_upsample, 20) + +class F_upsample_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.upsample op_0 1 1 input out align_corners=%align_corners mode=%mode scale_factor=%scale_factor +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& mode = captured_params.at("mode").s; + const std::vector& scale_factor = captured_params.at("scale_factor").af; + + if (mode == "nearest") + op->params["0"] = 1; + if (mode == "bilinear" || mode == "linear") + op->params["0"] = 2; + if (mode == "bicubic") + op->params["0"] = 3; + + if (scale_factor.size() == 1) + { + op->params["1"] = 1.f; + op->params["2"] = scale_factor[0]; + } + else if (scale_factor.size() == 2) + { + op->params["1"] = scale_factor[0]; + op->params["2"] = scale_factor[1]; + } + else + { + fprintf(stderr, "unsupported upsample scale_factor\n"); + } + + op->params["6"] = captured_params.at("align_corners").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_upsample_1, 20) + +class F_upsample_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.upsample op_0 1 1 input out mode=%mode size=%size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& mode = captured_params.at("mode").s; + const std::vector& size = captured_params.at("size").ai; + + if (mode == "nearest") + op->params["0"] = 1; + if (mode == "bilinear" || mode == "linear") + op->params["0"] = 2; + if (mode == "bicubic") + op->params["0"] = 3; + + if (size.size() == 1) + { + op->params["3"] = 1; + op->params["4"] = size[0]; + } + else if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample size\n"); + } + + op->params["6"] = 0; // align_corners + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_upsample_2, 20) + +class F_upsample_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.upsample op_0 1 1 input out mode=%mode scale_factor=%scale_factor +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& mode = captured_params.at("mode").s; + const std::vector& scale_factor = captured_params.at("scale_factor").af; + + if (mode == "nearest") + op->params["0"] = 1; + if (mode == "bilinear" || mode == "linear") + op->params["0"] = 2; + if (mode == "bicubic") + op->params["0"] = 3; + + if (scale_factor.size() == 1) + { + op->params["1"] = 1.f; + op->params["2"] = scale_factor[0]; + } + else if (scale_factor.size() == 2) + { + op->params["1"] = scale_factor[0]; + op->params["2"] = scale_factor[1]; + } + else + { + fprintf(stderr, "unsupported upsample scale_factor\n"); + } + + op->params["6"] = 0; // align_corners + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_upsample_3, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_upsample_bilinear.cpp b/tools/pnnx/src/pass_ncnn/F_upsample_bilinear.cpp new file mode 100644 index 000000000000..0b4dd0e47365 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_upsample_bilinear.cpp @@ -0,0 +1,113 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_upsample_bilinear : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.upsample_bilinear op_0 1 1 input out align_corners=%align_corners size=%size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample_bilinear"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& size = captured_params.at("size").ai; + + op->params["0"] = 2; // bilinear + + if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample_bilinear size\n"); + } + + op->params["6"] = captured_params.at("align_corners").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_upsample_bilinear, 20) + +class F_upsample_bilinear_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.upsample_bilinear op_0 1 1 input out align_corners=%align_corners scale_factor=%scale_factor +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample_bilinear"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& scale_factor = captured_params.at("scale_factor").af; + + op->params["0"] = 2; // bilinear + + if (scale_factor.size() == 2) + { + op->params["1"] = scale_factor[0]; + op->params["2"] = scale_factor[1]; + } + else + { + fprintf(stderr, "unsupported upsample_bilinear scale_factor\n"); + } + + op->params["6"] = captured_params.at("align_corners").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_upsample_bilinear_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_upsample_nearest.cpp b/tools/pnnx/src/pass_ncnn/F_upsample_nearest.cpp new file mode 100644 index 000000000000..897dbfbe2bc2 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_upsample_nearest.cpp @@ -0,0 +1,113 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_upsample_nearest : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.upsample_nearest op_0 1 1 input out size=%size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample_nearest"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& size = captured_params.at("size").ai; + + op->params["0"] = 1; // nearest + + if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample_nearest size\n"); + } + + op->params["6"] = 0; // align_corners + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_upsample_nearest, 20) + +class F_upsample_nearest_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +F.upsample_nearest op_0 1 1 input out scale_factor=%scale_factor +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample_nearest"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& scale_factor = captured_params.at("scale_factor").af; + + op->params["0"] = 1; // nearest + + if (scale_factor.size() == 2) + { + op->params["1"] = scale_factor[0]; + op->params["2"] = scale_factor[1]; + } + else + { + fprintf(stderr, "unsupported upsample_nearest scale_factor\n"); + } + + op->params["6"] = 0; // align_corners + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_upsample_nearest_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/Tensor_contiguous.cpp b/tools/pnnx/src/pass_ncnn/Tensor_contiguous.cpp new file mode 100644 index 000000000000..12f097b4ab86 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/Tensor_contiguous.cpp @@ -0,0 +1,53 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class Tensor_contiguous : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +Tensor.contiguous op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Noop"; + } + + const char* name_str() const + { + return "contiguous"; + } + + void write(Operator* /*op*/, const std::map& /*captured_params*/) const + { + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(Tensor_contiguous, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/Tensor_reshape.cpp b/tools/pnnx/src/pass_ncnn/Tensor_reshape.cpp new file mode 100644 index 000000000000..7d2f31dba36e --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/Tensor_reshape.cpp @@ -0,0 +1,101 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class Tensor_reshape : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +Tensor.reshape op_0 1 1 input out shape=%shape +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reshape"; + } + + const char* name_str() const + { + return "reshape"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& shape = captured_params.at("shape").ai; + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + if (batch_index != 0) + { + fprintf(stderr, "reshape tensor with batch index %d is not supported yet!\n", batch_index); + } + + // drop shape batch index + std::vector new_shape; + for (int i = 0; i < (int)shape.size(); i++) + { + if (i == batch_index && shape[i] == 1) + continue; + + new_shape.push_back(shape[i]); + } + + const int shape_rank = (int)new_shape.size(); + + if (shape_rank > 5) + { + fprintf(stderr, "reshape to %d-rank tensor is not supported yet!\n", shape_rank); + return; + } + + if (shape_rank == 1) + { + op->params["0"] = new_shape[0]; + } + if (shape_rank == 2) + { + op->params["0"] = new_shape[1]; + op->params["1"] = new_shape[0]; + } + if (shape_rank == 3) + { + op->params["0"] = new_shape[2]; + op->params["1"] = new_shape[1]; + op->params["2"] = new_shape[0]; + } + if (shape_rank == 4) + { + op->params["0"] = new_shape[3] == -1 || new_shape[2] == -1 ? -1 : new_shape[3] * new_shape[2]; + op->params["1"] = new_shape[1]; + op->params["2"] = new_shape[0]; + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(Tensor_reshape, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/Tensor_slice.cpp b/tools/pnnx/src/pass_ncnn/Tensor_slice.cpp new file mode 100644 index 000000000000..e0c7750ad670 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/Tensor_slice.cpp @@ -0,0 +1,102 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class Tensor_slice : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +Tensor.slice op_0 1 1 input out dims=%dims starts=%starts ends=%ends steps=%steps +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Crop"; + } + + const char* name_str() const + { + return "slice"; + } + + void write(Operator* op, const std::map& captured_params) const + { + std::vector axes = captured_params.at("dims").ai; + const std::vector& starts = captured_params.at("starts").ai; + std::vector ends = captured_params.at("ends").ai; + const std::vector& steps = captured_params.at("steps").ai; + int axes_rank = axes.size(); + + for (int i = 0; i < axes_rank; i++) + { + if (steps[i] != 1) + { + fprintf(stderr, "slice with step %d is not supported\n", steps[i]); + return; + } + } + + int input_rank = op->inputs[0]->shape.size(); + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + if (input_rank > 5) + { + fprintf(stderr, "slice %d-rank tensor with %d-rank axes is not possible!\n", input_rank, axes_rank); + return; + } + + for (int i = 0; i < axes_rank; i++) + { + if (axes[i] == batch_index && (starts[i] != 0 || ends[i] != -1)) + { + fprintf(stderr, "slice along batch axis is not supported\n"); + return; + } + + if (axes[i] < 0) + { + int input_rank = op->inputs[0]->shape.size(); + axes[i] = input_rank + axes[i]; + } + + if (axes[i] > batch_index) + axes[i] -= 1; + + if (ends[i] == -1) + ends[i] = -233; + } + + op->params["9"] = starts; + op->params["10"] = ends; + op->params["11"] = axes; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(Tensor_slice, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/Tensor_view.cpp b/tools/pnnx/src/pass_ncnn/Tensor_view.cpp new file mode 100644 index 000000000000..bb2113e3666b --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/Tensor_view.cpp @@ -0,0 +1,101 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class Tensor_view : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +Tensor.view op_0 1 1 input out shape=%shape +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reshape"; + } + + const char* name_str() const + { + return "view"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& shape = captured_params.at("shape").ai; + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + if (batch_index != 0) + { + fprintf(stderr, "reshape tensor with batch index %d is not supported yet!\n", batch_index); + } + + // drop shape batch index + std::vector new_shape; + for (int i = 0; i < (int)shape.size(); i++) + { + if (i == batch_index && shape[i] == 1) + continue; + + new_shape.push_back(shape[i]); + } + + const int shape_rank = (int)new_shape.size(); + + if (shape_rank > 5) + { + fprintf(stderr, "reshape to %d-rank tensor is not supported yet!\n", shape_rank); + return; + } + + if (shape_rank == 1) + { + op->params["0"] = new_shape[0]; + } + if (shape_rank == 2) + { + op->params["0"] = new_shape[1]; + op->params["1"] = new_shape[0]; + } + if (shape_rank == 3) + { + op->params["0"] = new_shape[2]; + op->params["1"] = new_shape[1]; + op->params["2"] = new_shape[0]; + } + if (shape_rank == 4) + { + op->params["0"] = new_shape[3] == -1 || new_shape[2] == -1 ? -1 : new_shape[3] * new_shape[2]; + op->params["1"] = new_shape[1]; + op->params["2"] = new_shape[0]; + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(Tensor_view, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/chain_multi_output.cpp b/tools/pnnx/src/pass_ncnn/chain_multi_output.cpp new file mode 100644 index 000000000000..b12379ffb266 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/chain_multi_output.cpp @@ -0,0 +1,127 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "chain_multi_output.h" + +#include + +namespace pnnx { + +namespace ncnn { + +void chain_multi_output(Graph& graph) +{ + for (;;) + { + bool need_eliminate = false; + + for (int i = (int)graph.ops.size() - 1; i >= 0; i--) + { + Operator* op = graph.ops[i]; + + if (op->type != "pnnx.Output") + continue; + + // prim::TupleConstruct pnnx_791 2 1 a b out + // pnnx.Expression pnnx_expr_0 3 1 a b c out expr=[@0,@1,@2] + // pnnx.Output pnnx_output_0 1 0 out + bool match_tuple_expr_output = false; + for (int j = 0; j < (int)op->inputs.size(); j++) + { + Operand* r = op->inputs[j]; + + if (r->consumers.size() != 1) + continue; + + Operator* op0 = r->producer; + + if (op0->type == "prim::TupleConstruct") + { + match_tuple_expr_output = true; + } + else if (op0->type == "pnnx.Expression") + { + const int op_expr_input_count = (int)op0->inputs.size(); + const std::string& expr = op0->params.at("expr").s; + + std::string pattern_expr = "["; + for (int k = 0; k < op_expr_input_count; k++) + { + pattern_expr += std::string("@") + std::to_string(k); + + if (k != op_expr_input_count - 1) + pattern_expr += ","; + } + pattern_expr += "]"; + + if (expr == pattern_expr) + { + match_tuple_expr_output = true; + } + } + + if (!match_tuple_expr_output) + continue; + + // chain op0 as output and delete op0 + std::vector new_inputs; + for (int k = 0; k < j; k++) + { + new_inputs.push_back(op->inputs[k]); + } + + for (Operand* r : op0->inputs) + { + r->remove_consumer(op0); + r->consumers.push_back(op); + new_inputs.push_back(r); + } + + for (int k = j + 1; k < (int)op->inputs.size(); k++) + { + new_inputs.push_back(op->inputs[k]); + } + + op->inputs = new_inputs; + + op0->inputs.clear(); + op0->outputs.clear(); + + Operand* op0_out = op0->outputs[0]; + op0_out->producer = 0; + op0_out->consumers.clear(); + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op0_out)); + delete op0_out; + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op0)); + delete op0; + + break; + } + + if (match_tuple_expr_output) + need_eliminate = true; + + break; + } + + if (!need_eliminate) + break; + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/chain_multi_output.h b/tools/pnnx/src/pass_ncnn/chain_multi_output.h new file mode 100644 index 000000000000..619ede9e3c0e --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/chain_multi_output.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void chain_multi_output(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_attribute.cpp b/tools/pnnx/src/pass_ncnn/convert_attribute.cpp new file mode 100644 index 000000000000..007b1b938cac --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_attribute.cpp @@ -0,0 +1,80 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "convert_attribute.h" + +namespace pnnx { + +namespace ncnn { + +void convert_attribute(Graph& graph) +{ + for (Operator* op : graph.ops) + { + if (op->type != "pnnx.Attribute") + continue; + + op->type = "MemoryData"; + + const std::string& key = op->attrs.begin()->first; + const Attribute& data = op->attrs.begin()->second; + + const int batch_index = op->outputs[0]->params["__batch_index"].i; + + if ((int)data.shape.size() > 5) + { + fprintf(stderr, "pnnx attribute %d-rank tensor is not supported yet!\n", (int)data.shape.size()); + return; + } + + // drop batch index + std::vector new_shape; + for (int i = 0; i < (int)data.shape.size(); i++) + { + if (i == batch_index) + continue; + + new_shape.push_back(data.shape[i]); + } + + if (new_shape.size() == 1) + { + op->params["0"] = new_shape[0]; + } + if (new_shape.size() == 2) + { + op->params["0"] = new_shape[1]; + op->params["1"] = new_shape[0]; + } + if (new_shape.size() == 3) + { + op->params["0"] = new_shape[2]; + op->params["1"] = new_shape[1]; + op->params["2"] = new_shape[0]; + } + if (new_shape.size() == 4) + { + op->params["0"] = new_shape[2] * new_shape[3]; + op->params["1"] = new_shape[1]; + op->params["2"] = new_shape[0]; + } + + op->attrs["0"] = data; + op->attrs.erase(key); + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_attribute.h b/tools/pnnx/src/pass_ncnn/convert_attribute.h new file mode 100644 index 000000000000..d7345dc9d020 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_attribute.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void convert_attribute(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_custom_op.cpp b/tools/pnnx/src/pass_ncnn/convert_custom_op.cpp new file mode 100644 index 000000000000..e4d8046232b3 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_custom_op.cpp @@ -0,0 +1,47 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "convert_custom_op.h" + +namespace pnnx { + +namespace ncnn { + +void convert_custom_op(Graph& graph) +{ + for (Operator* op : graph.ops) + { + if (op->type.substr(0, 15) == "pnnx.custom_op.") + { + op->type = op->type.substr(15); + + // handle arg_N + std::map new_params; + for (const auto& it : op->params) + { + fprintf(stderr, "%s %d\n", it.first.c_str(), it.second.type); + if (it.first.substr(0, 4) == "arg_") + { + new_params[it.first.substr(4)] = it.second; + } + } + + op->params = new_params; + } + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_custom_op.h b/tools/pnnx/src/pass_ncnn/convert_custom_op.h new file mode 100644 index 000000000000..3ee1455eda35 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_custom_op.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void convert_custom_op(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_input.cpp b/tools/pnnx/src/pass_ncnn/convert_input.cpp new file mode 100644 index 000000000000..71a471519cb0 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_input.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "convert_input.h" + +namespace pnnx { + +namespace ncnn { + +void convert_input(Graph& graph) +{ + int index = 0; + + for (Operator* op : graph.ops) + { + if (op->type != "pnnx.Input") + continue; + + op->type = "Input"; + op->name = std::string("in") + std::to_string(index); + + // canonicalize output name + op->outputs[0]->name = std::string("in") + std::to_string(index); + index++; + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_input.h b/tools/pnnx/src/pass_ncnn/convert_input.h new file mode 100644 index 000000000000..328c18cfb937 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_input.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void convert_input(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_cat.cpp b/tools/pnnx/src/pass_ncnn/convert_torch_cat.cpp new file mode 100644 index 000000000000..e15c24b51729 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_torch_cat.cpp @@ -0,0 +1,59 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "convert_torch_cat.h" + +namespace pnnx { + +namespace ncnn { + +void convert_torch_cat(Graph& graph) +{ + int op_index = 0; + + for (Operator* op : graph.ops) + { + if (op->type != "torch.cat") + continue; + + op->type = "Concat"; + op->name = std::string("cat_") + std::to_string(op_index++); + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + int axis = op->params.at("dim").i; + if (axis == batch_index) + { + fprintf(stderr, "cat along batch axis %d is not supported\n", batch_index); + continue; + } + + if (axis < 0) + { + int input_rank = op->inputs[0]->shape.size(); + axis = input_rank + axis; + } + + if (axis > batch_index) + axis -= 1; + + op->params["0"] = axis; + + op->params.erase("dim"); + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_cat.h b/tools/pnnx/src/pass_ncnn/convert_torch_cat.h new file mode 100644 index 000000000000..c7f4a26e662d --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_torch_cat.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void convert_torch_cat(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_chunk.cpp b/tools/pnnx/src/pass_ncnn/convert_torch_chunk.cpp new file mode 100644 index 000000000000..9ef40352f2b3 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_torch_chunk.cpp @@ -0,0 +1,64 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "convert_torch_chunk.h" + +namespace pnnx { + +namespace ncnn { + +void convert_torch_chunk(Graph& graph) +{ + int op_index = 0; + + for (Operator* op : graph.ops) + { + if (op->type != "torch.chunk") + continue; + + op->type = "Slice"; + op->name = std::string("chunk_") + std::to_string(op_index++); + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + int axis = op->params.at("dim").i; + if (axis == batch_index) + { + fprintf(stderr, "chunk along batch axis %d is not supported\n", batch_index); + continue; + } + + if (axis < 0) + { + int input_rank = op->inputs[0]->shape.size(); + axis = input_rank + axis; + } + + if (axis > batch_index) + axis -= 1; + + int chunks = op->params.at("chunks").i; + op->params["0"].type = 5; + op->params["0"].ai.resize(chunks, -233); + + op->params["1"] = axis; + + op->params.erase("chunks"); + op->params.erase("dim"); + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_chunk.h b/tools/pnnx/src/pass_ncnn/convert_torch_chunk.h new file mode 100644 index 000000000000..3f3002573bbc --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_torch_chunk.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void convert_torch_chunk(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_split.cpp b/tools/pnnx/src/pass_ncnn/convert_torch_split.cpp new file mode 100644 index 000000000000..bd18ca9a8db2 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_torch_split.cpp @@ -0,0 +1,79 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "convert_torch_split.h" + +namespace pnnx { + +namespace ncnn { + +void convert_torch_split(Graph& graph) +{ + int op_index = 0; + + for (Operator* op : graph.ops) + { + if (op->type != "torch.split") + continue; + + op->type = "Slice"; + op->name = std::string("split_") + std::to_string(op_index++); + + const Parameter& split_size_or_sections = op->params.at("split_size_or_sections"); + if (split_size_or_sections.type != 1 && split_size_or_sections.type != 5) + { + fprintf(stderr, "malformed split split_size_or_sections type %d\n", split_size_or_sections.type); + continue; + } + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + int axis = op->params.at("dim").i; + if (axis == batch_index) + { + fprintf(stderr, "split along batch axis %d is not supported\n", batch_index); + continue; + } + + if (axis < 0) + { + int input_rank = op->inputs[0]->shape.size(); + axis = input_rank + axis; + } + + if (axis > batch_index) + axis -= 1; + + if (split_size_or_sections.type == 1) + { + const size_t output_size = op->outputs.size(); + op->params["0"].type = 5; + op->params["0"].ai.resize(output_size, split_size_or_sections.i); + op->params["0"].ai[output_size - 1] = -233; + } + else + { + op->params["0"] = split_size_or_sections; + } + + op->params["1"] = axis; + + op->params.erase("split_size_or_sections"); + op->params.erase("dim"); + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_split.h b/tools/pnnx/src/pass_ncnn/convert_torch_split.h new file mode 100644 index 000000000000..2ec4c8246de0 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_torch_split.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void convert_torch_split(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/eliminate_noop.cpp b/tools/pnnx/src/pass_ncnn/eliminate_noop.cpp new file mode 100644 index 000000000000..4a2a92eb2764 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/eliminate_noop.cpp @@ -0,0 +1,75 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "eliminate_noop.h" + +#include + +namespace pnnx { + +namespace ncnn { + +void eliminate_noop(Graph& graph) +{ + for (;;) + { + bool need_eliminate = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "Noop") + continue; + + need_eliminate = true; + + op->inputs[0]->remove_consumer(op); + + Operand* op_out = op->outputs[0]; + + for (auto& x : op_out->consumers) + { + for (size_t j = 0; j < x->inputs.size(); j++) + { + if (x->inputs[j] == op_out) + x->inputs[j] = op->inputs[0]; + } + + op->inputs[0]->consumers.push_back(x); + } + + op_out->producer = 0; + op_out->consumers.clear(); + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), op_out)); + delete op_out; + + op->inputs.clear(); + op->outputs.clear(); + + graph.ops.erase(graph.ops.begin() + i); + delete op; + + break; + } + + if (!need_eliminate) + break; + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/eliminate_noop.h b/tools/pnnx/src/pass_ncnn/eliminate_noop.h new file mode 100644 index 000000000000..bfe278453038 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/eliminate_noop.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void eliminate_noop(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/eliminate_output.cpp b/tools/pnnx/src/pass_ncnn/eliminate_output.cpp new file mode 100644 index 000000000000..31f73f6e9db5 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/eliminate_output.cpp @@ -0,0 +1,69 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "eliminate_output.h" + +namespace pnnx { + +namespace ncnn { + +void eliminate_output(Graph& graph) +{ + for (;;) + { + bool need_eliminate = false; + + for (int i = (int)graph.ops.size() - 1; i >= 0; i--) + { + Operator* op = graph.ops[i]; + + if (op->type != "pnnx.Output") + continue; + + need_eliminate = true; + + // canonicalize output name + for (int j = 0; j < (int)op->inputs.size(); j++) + { + op->inputs[j]->name = std::string("out") + std::to_string(j); + } + + for (Operand* r : op->inputs) + { + r->remove_consumer(op); + } + + op->inputs.clear(); + + for (Operand* r : op->outputs) + { + r->producer = 0; + } + + op->outputs.clear(); + + graph.ops.erase(graph.ops.begin() + i); + delete op; + + break; + } + + if (!need_eliminate) + break; + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/eliminate_output.h b/tools/pnnx/src/pass_ncnn/eliminate_output.h new file mode 100644 index 000000000000..49149c77cf56 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/eliminate_output.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void eliminate_output(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/expand_expression.cpp b/tools/pnnx/src/pass_ncnn/expand_expression.cpp new file mode 100644 index 000000000000..f19ab5b726fe --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/expand_expression.cpp @@ -0,0 +1,309 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +#include +#include + +#include +#include +#include +#include +#include + +namespace pnnx { + +namespace ncnn { + +static bool token_is_argument(const std::string& t) +{ + if (t[0] != '@' || t.size() < 2) + return false; + + for (size_t i = 1; i < t.size(); i++) + { + if (t[i] < '0' || t[i] > '9') + return false; + } + + return true; +} + +static bool token_is_literal(const std::string& t) +{ + std::istringstream iss(t); + float f; + iss >> std::noskipws >> f; + return iss.eof() && !iss.fail(); + + // for (size_t i = 0; i < t.size(); i++) + // { + // if (i == 0 && t[i] == '-') + // continue; + // + // if (t[i] < '0' || t[i] > '9') + // { + // if (t[i] != '.' && t[i] != 'e') + // return false; + // } + // } + // + // return true; +} + +static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx_expr_index) +{ + std::string expr = op->params.at("expr").s; + + // fprintf(stderr, "ncnn expand_expression %s\n", expr.c_str()); + + // split into tokens + std::vector tokens; + { + std::string t; + for (size_t i = 0; i < expr.size(); i++) + { + char ch = expr[i]; + + if (ch == '[') // list + { + t += ch; + tokens.push_back(t); + t.clear(); + } + else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') + { + if (!t.empty()) + { + tokens.push_back(t); + t.clear(); + } + } + else + { + t += ch; + } + } + + if (!t.empty()) + { + tokens.push_back(t); + } + } + + // scan and stack + std::stack exprstack; + for (int i = (int)tokens.size() - 1; i >= 0; i--) + { + const std::string& t = tokens[i]; + + if (t == "size") + { + // not supported + return std::string(); + } + else if (t == "int") + { + // not supported + return std::string(); + } + else if (t == "sqrt" || t == "rsqrt" || t == "neg") + { + std::string a = exprstack.top(); + exprstack.pop(); + + std::string r = t + "(" + (token_is_argument(a) ? op->inputs[std::stoi(a.substr(1))]->name : a) + ")"; + exprstack.push(r); + + Operator* op_unary = graph.new_operator_before("UnaryOp", t + "_" + std::to_string(pnnx_expr_index++), op); + + if (t == "sqrt") op_unary->params["0"] = 5; + if (t == "rsqrt") op_unary->params["0"] = 6; + if (t == "neg") op_unary->params["0"] = 1; + + Operand* op_unary_in = token_is_argument(a) ? op->inputs[std::stoi(a.substr(1))] : graph.get_operand(op->name + "_" + a); + op_unary_in->consumers.push_back(op_unary); + + Operand* op_unary_out = graph.new_operand(op->name + "_" + r); + op_unary_out->producer = op_unary; + + op_unary->inputs.push_back(op_unary_in); + op_unary->outputs.push_back(op_unary_out); + } + else if (t == "add" || t == "sub" || t == "mul" || t == "div" || /*t == "floor_divide" || */ t == "pow") + { + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = t + "(" + (token_is_argument(a) ? op->inputs[std::stoi(a.substr(1))]->name : a) + "," + (token_is_argument(b) ? op->inputs[std::stoi(b.substr(1))]->name : b) + ")"; + exprstack.push(r); + + Operator* op_binary = graph.new_operator_before("BinaryOp", t + "_" + std::to_string(pnnx_expr_index++), op); + + if (t == "add") op_binary->params["0"] = 0; + if (t == "sub") op_binary->params["0"] = 1; + if (t == "mul") op_binary->params["0"] = 2; + if (t == "div") op_binary->params["0"] = 3; + if (t == "pow") op_binary->params["0"] = 6; + + if (token_is_literal(a)) + { + if (t == "sub") op_binary->params["0"] = 7; + if (t == "div") op_binary->params["0"] = 8; + + Operand* op_binary_inb = token_is_argument(b) ? op->inputs[std::stoi(b.substr(1))] : graph.get_operand(op->name + "_" + b); + op_binary_inb->consumers.push_back(op_binary); + + op_binary->params["1"] = 1; // with_scalar + op_binary->params["2"] = std::stof(a); + + Operand* op_binary_out = graph.new_operand(op->name + "_" + r); + op_binary_out->producer = op_binary; + + op_binary->inputs.push_back(op_binary_inb); + op_binary->outputs.push_back(op_binary_out); + } + else if (token_is_literal(b)) + { + Operand* op_binary_ina = token_is_argument(a) ? op->inputs[std::stoi(a.substr(1))] : graph.get_operand(op->name + "_" + a); + op_binary_ina->consumers.push_back(op_binary); + + op_binary->params["1"] = 1; // with_scalar + op_binary->params["2"] = std::stof(b); + + Operand* op_binary_out = graph.new_operand(op->name + "_" + r); + op_binary_out->producer = op_binary; + + op_binary->inputs.push_back(op_binary_ina); + op_binary->outputs.push_back(op_binary_out); + } + else + { + Operand* op_binary_ina = token_is_argument(a) ? op->inputs[std::stoi(a.substr(1))] : graph.get_operand(op->name + "_" + a); + op_binary_ina->consumers.push_back(op_binary); + + Operand* op_binary_inb = token_is_argument(b) ? op->inputs[std::stoi(b.substr(1))] : graph.get_operand(op->name + "_" + b); + op_binary_inb->consumers.push_back(op_binary); + + Operand* op_binary_out = graph.new_operand(op->name + "_" + r); + op_binary_out->producer = op_binary; + + op_binary->inputs.push_back(op_binary_ina); + op_binary->inputs.push_back(op_binary_inb); + op_binary->outputs.push_back(op_binary_out); + } + } + else if (t == "[") // list + { + // not supported + return std::string(); + } + else if (t[0] == '@') + { + exprstack.push(t); + } + else + { + // literal + exprstack.push(t); + } + } + + std::string r = exprstack.top(); + exprstack.pop(); + + // fprintf(stderr, "expand_expression return %s\n", r.c_str()); + + return r; +} + +void expand_expression(Graph& graph) +{ + int pnnx_expr_index = 0; + + std::set nonsupported_expr_ops; + + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + if (op->type != "pnnx.Expression") + continue; + + if (nonsupported_expr_ops.find(op) != nonsupported_expr_ops.end()) + continue; + + matched = true; + + std::string outname = expand_expression(graph, op, pnnx_expr_index); + + if (outname.empty()) + { + // not supported expr + nonsupported_expr_ops.insert(op); + break; + } + + // link new output + Operand* old_output_operand = op->outputs[0]; + Operand* new_output_operand = graph.get_operand(op->name + "_" + outname); + + for (auto r : op->inputs) + { + r->remove_consumer(op); + } + + for (auto& x : old_output_operand->consumers) + { + new_output_operand->consumers.push_back(x); + + for (size_t j = 0; j < x->inputs.size(); j++) + { + if (x->inputs[j] == old_output_operand) + { + x->inputs[j] = new_output_operand; + } + } + } + + new_output_operand->type = old_output_operand->type; + new_output_operand->shape = old_output_operand->shape; + new_output_operand->params = old_output_operand->params; + + old_output_operand->producer = 0; + old_output_operand->consumers.clear(); + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op)); + delete op; + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), old_output_operand)); + delete old_output_operand; + + break; + } + + if (!matched) + break; + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/expand_expression.h b/tools/pnnx/src/pass_ncnn/expand_expression.h new file mode 100644 index 000000000000..3e9e7c049354 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/expand_expression.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void expand_expression(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_convolution1d_activation.cpp b/tools/pnnx/src/pass_ncnn/fuse_convolution1d_activation.cpp new file mode 100644 index 000000000000..33c7cc53565d --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_convolution1d_activation.cpp @@ -0,0 +1,279 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_convolution1d_activation.h" + +#include "pass_level2.h" + +#include + +namespace pnnx { + +namespace ncnn { + +class fuse_convolution1d_relu_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Convolution1D op_0 1 1 input a %*=%* +ReLU op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution1D"; + } + + const char* name_str() const + { + return "conv1drelu"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float slope = 0.f; + if (captured_params.find("op_1.0") != captured_params.end()) + { + slope = captured_params.at("op_1.0").f; + } + + if (slope == 0.f) + { + op->params["9"] = 1; + } + else + { + op->params["9"] = 2; + op->params["10"] = Parameter{slope}; + } + } +}; + +class fuse_convolution1d_clip_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Convolution1D op_0 1 1 input a %*=%* +Clip op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution1D"; + } + + const char* name_str() const + { + return "conv1dclip"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float min = -FLT_MAX; + float max = FLT_MAX; + if (captured_params.find("op_1.0") != captured_params.end()) + { + min = captured_params.at("op_1.0").f; + } + if (captured_params.find("op_1.1") != captured_params.end()) + { + max = captured_params.at("op_1.1").f; + } + + op->params["9"] = 3; + op->params["10"] = Parameter{min, max}; + } +}; + +class fuse_convolution1d_sigmoid_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Convolution1D op_0 1 1 input a %*=%* +Sigmoid op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution1D"; + } + + const char* name_str() const + { + return "conv1dsigmoid"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 4; + } +}; + +class fuse_convolution1d_mish_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Convolution1D op_0 1 1 input a %*=%* +Mish op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution1D"; + } + + const char* name_str() const + { + return "conv1dmish"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 5; + } +}; + +void fuse_convolution1d_activation(Graph& graph) +{ + fuse_convolution1d_relu_pass a; + fuse_convolution1d_clip_pass b; + fuse_convolution1d_sigmoid_pass c; + fuse_convolution1d_mish_pass d; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_convolution1d_activation.h b/tools/pnnx/src/pass_ncnn/fuse_convolution1d_activation.h new file mode 100644 index 000000000000..91fefe89ba03 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_convolution1d_activation.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace ncnn { + +void fuse_convolution1d_activation(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_convolution_activation.cpp b/tools/pnnx/src/pass_ncnn/fuse_convolution_activation.cpp new file mode 100644 index 000000000000..577d53e30e3a --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_convolution_activation.cpp @@ -0,0 +1,279 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_convolution_activation.h" + +#include "pass_level2.h" + +#include + +namespace pnnx { + +namespace ncnn { + +class fuse_convolution_relu_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Convolution op_0 1 1 input a %*=%* +ReLU op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution"; + } + + const char* name_str() const + { + return "convrelu"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float slope = 0.f; + if (captured_params.find("op_1.0") != captured_params.end()) + { + slope = captured_params.at("op_1.0").f; + } + + if (slope == 0.f) + { + op->params["9"] = 1; + } + else + { + op->params["9"] = 2; + op->params["10"] = Parameter{slope}; + } + } +}; + +class fuse_convolution_clip_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Convolution op_0 1 1 input a %*=%* +Clip op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution"; + } + + const char* name_str() const + { + return "convclip"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float min = -FLT_MAX; + float max = FLT_MAX; + if (captured_params.find("op_1.0") != captured_params.end()) + { + min = captured_params.at("op_1.0").f; + } + if (captured_params.find("op_1.1") != captured_params.end()) + { + max = captured_params.at("op_1.1").f; + } + + op->params["9"] = 3; + op->params["10"] = Parameter{min, max}; + } +}; + +class fuse_convolution_sigmoid_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Convolution op_0 1 1 input a %*=%* +Sigmoid op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution"; + } + + const char* name_str() const + { + return "convsigmoid"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 4; + } +}; + +class fuse_convolution_mish_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Convolution op_0 1 1 input a %*=%* +Mish op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution"; + } + + const char* name_str() const + { + return "convmish"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 5; + } +}; + +void fuse_convolution_activation(Graph& graph) +{ + fuse_convolution_relu_pass a; + fuse_convolution_clip_pass b; + fuse_convolution_sigmoid_pass c; + fuse_convolution_mish_pass d; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_convolution_activation.h b/tools/pnnx/src/pass_ncnn/fuse_convolution_activation.h new file mode 100644 index 000000000000..cbd47d4fa34e --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_convolution_activation.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace ncnn { + +void fuse_convolution_activation(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise1d_activation.cpp b/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise1d_activation.cpp new file mode 100644 index 000000000000..ded8aa294ef5 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise1d_activation.cpp @@ -0,0 +1,279 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_convolutiondepthwise1d_activation.h" + +#include "pass_level2.h" + +#include + +namespace pnnx { + +namespace ncnn { + +class fuse_convolutiondepthwise1d_relu_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +ConvolutionDepthWise1D op_0 1 1 input a %*=%* +ReLU op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ConvolutionDepthWise1D"; + } + + const char* name_str() const + { + return "convdw1drelu"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float slope = 0.f; + if (captured_params.find("op_1.0") != captured_params.end()) + { + slope = captured_params.at("op_1.0").f; + } + + if (slope == 0.f) + { + op->params["9"] = 1; + } + else + { + op->params["9"] = 2; + op->params["10"] = Parameter{slope}; + } + } +}; + +class fuse_convolutiondepthwise1d_clip_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +ConvolutionDepthWise1D op_0 1 1 input a %*=%* +Clip op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ConvolutionDepthWise1D"; + } + + const char* name_str() const + { + return "convdw1dclip"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float min = -FLT_MAX; + float max = FLT_MAX; + if (captured_params.find("op_1.0") != captured_params.end()) + { + min = captured_params.at("op_1.0").f; + } + if (captured_params.find("op_1.1") != captured_params.end()) + { + max = captured_params.at("op_1.1").f; + } + + op->params["9"] = 3; + op->params["10"] = Parameter{min, max}; + } +}; + +class fuse_convolutiondepthwise1d_sigmoid_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +ConvolutionDepthWise1D op_0 1 1 input a %*=%* +Sigmoid op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ConvolutionDepthWise1D"; + } + + const char* name_str() const + { + return "convdw1dsigmoid"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 4; + } +}; + +class fuse_convolutiondepthwise1d_mish_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +ConvolutionDepthWise1D op_0 1 1 input a %*=%* +Mish op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ConvolutionDepthWise1D"; + } + + const char* name_str() const + { + return "convdw1dmish"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 5; + } +}; + +void fuse_convolutiondepthwise1d_activation(Graph& graph) +{ + fuse_convolutiondepthwise1d_relu_pass a; + fuse_convolutiondepthwise1d_clip_pass b; + fuse_convolutiondepthwise1d_sigmoid_pass c; + fuse_convolutiondepthwise1d_mish_pass d; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise1d_activation.h b/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise1d_activation.h new file mode 100644 index 000000000000..624f8607ac62 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise1d_activation.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace ncnn { + +void fuse_convolutiondepthwise1d_activation(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise_activation.cpp b/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise_activation.cpp new file mode 100644 index 000000000000..c23f59054274 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise_activation.cpp @@ -0,0 +1,279 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_convolutiondepthwise_activation.h" + +#include "pass_level2.h" + +#include + +namespace pnnx { + +namespace ncnn { + +class fuse_convolutiondepthwise_relu_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +ConvolutionDepthWise op_0 1 1 input a %*=%* +ReLU op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ConvolutionDepthWise"; + } + + const char* name_str() const + { + return "convdwrelu"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float slope = 0.f; + if (captured_params.find("op_1.0") != captured_params.end()) + { + slope = captured_params.at("op_1.0").f; + } + + if (slope == 0.f) + { + op->params["9"] = 1; + } + else + { + op->params["9"] = 2; + op->params["10"] = Parameter{slope}; + } + } +}; + +class fuse_convolutiondepthwise_clip_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +ConvolutionDepthWise op_0 1 1 input a %*=%* +Clip op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ConvolutionDepthWise"; + } + + const char* name_str() const + { + return "convdwclip"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float min = -FLT_MAX; + float max = FLT_MAX; + if (captured_params.find("op_1.0") != captured_params.end()) + { + min = captured_params.at("op_1.0").f; + } + if (captured_params.find("op_1.1") != captured_params.end()) + { + max = captured_params.at("op_1.1").f; + } + + op->params["9"] = 3; + op->params["10"] = Parameter{min, max}; + } +}; + +class fuse_convolutiondepthwise_sigmoid_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +ConvolutionDepthWise op_0 1 1 input a %*=%* +Sigmoid op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ConvolutionDepthWise"; + } + + const char* name_str() const + { + return "convdwsigmoid"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 4; + } +}; + +class fuse_convolutiondepthwise_mish_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +ConvolutionDepthWise op_0 1 1 input a %*=%* +Mish op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ConvolutionDepthWise"; + } + + const char* name_str() const + { + return "convdwmish"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 5; + } +}; + +void fuse_convolutiondepthwise_activation(Graph& graph) +{ + fuse_convolutiondepthwise_relu_pass a; + fuse_convolutiondepthwise_clip_pass b; + fuse_convolutiondepthwise_sigmoid_pass c; + fuse_convolutiondepthwise_mish_pass d; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise_activation.h b/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise_activation.h new file mode 100644 index 000000000000..cd3b6bbf9e4d --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_convolutiondepthwise_activation.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace ncnn { + +void fuse_convolutiondepthwise_activation(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_deconvolution_activation.cpp b/tools/pnnx/src/pass_ncnn/fuse_deconvolution_activation.cpp new file mode 100644 index 000000000000..abf405f048e3 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_deconvolution_activation.cpp @@ -0,0 +1,279 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_deconvolution_activation.h" + +#include "pass_level2.h" + +#include + +namespace pnnx { + +namespace ncnn { + +class fuse_deconvolution_relu_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Deconvolution op_0 1 1 input a %*=%* +ReLU op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Deconvolution"; + } + + const char* name_str() const + { + return "deconvrelu"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float slope = 0.f; + if (captured_params.find("op_1.0") != captured_params.end()) + { + slope = captured_params.at("op_1.0").f; + } + + if (slope == 0.f) + { + op->params["9"] = 1; + } + else + { + op->params["9"] = 2; + op->params["10"] = Parameter{slope}; + } + } +}; + +class fuse_deconvolution_clip_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Deconvolution op_0 1 1 input a %*=%* +Clip op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Deconvolution"; + } + + const char* name_str() const + { + return "deconvclip"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float min = -FLT_MAX; + float max = FLT_MAX; + if (captured_params.find("op_1.0") != captured_params.end()) + { + min = captured_params.at("op_1.0").f; + } + if (captured_params.find("op_1.1") != captured_params.end()) + { + max = captured_params.at("op_1.1").f; + } + + op->params["9"] = 3; + op->params["10"] = Parameter{min, max}; + } +}; + +class fuse_deconvolution_sigmoid_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Deconvolution op_0 1 1 input a %*=%* +Sigmoid op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Deconvolution"; + } + + const char* name_str() const + { + return "deconvsigmoid"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 4; + } +}; + +class fuse_deconvolution_mish_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +Deconvolution op_0 1 1 input a %*=%* +Mish op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Deconvolution"; + } + + const char* name_str() const + { + return "deconvmish"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 5; + } +}; + +void fuse_deconvolution_activation(Graph& graph) +{ + fuse_deconvolution_relu_pass a; + fuse_deconvolution_clip_pass b; + fuse_deconvolution_sigmoid_pass c; + fuse_deconvolution_mish_pass d; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_deconvolution_activation.h b/tools/pnnx/src/pass_ncnn/fuse_deconvolution_activation.h new file mode 100644 index 000000000000..14fc2184a7b8 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_deconvolution_activation.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace ncnn { + +void fuse_deconvolution_activation(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_deconvolutiondepthwise_activation.cpp b/tools/pnnx/src/pass_ncnn/fuse_deconvolutiondepthwise_activation.cpp new file mode 100644 index 000000000000..98c70ca29155 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_deconvolutiondepthwise_activation.cpp @@ -0,0 +1,279 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_deconvolutiondepthwise_activation.h" + +#include "pass_level2.h" + +#include + +namespace pnnx { + +namespace ncnn { + +class fuse_deconvolutiondepthwise_relu_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +DeconvolutionDepthWise op_0 1 1 input a %*=%* +ReLU op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "DeconvolutionDepthWise"; + } + + const char* name_str() const + { + return "deconvdwrelu"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float slope = 0.f; + if (captured_params.find("op_1.0") != captured_params.end()) + { + slope = captured_params.at("op_1.0").f; + } + + if (slope == 0.f) + { + op->params["9"] = 1; + } + else + { + op->params["9"] = 2; + op->params["10"] = Parameter{slope}; + } + } +}; + +class fuse_deconvolutiondepthwise_clip_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +DeconvolutionDepthWise op_0 1 1 input a %*=%* +Clip op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "DeconvolutionDepthWise"; + } + + const char* name_str() const + { + return "deconvdwclip"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float min = -FLT_MAX; + float max = FLT_MAX; + if (captured_params.find("op_1.0") != captured_params.end()) + { + min = captured_params.at("op_1.0").f; + } + if (captured_params.find("op_1.1") != captured_params.end()) + { + max = captured_params.at("op_1.1").f; + } + + op->params["9"] = 3; + op->params["10"] = Parameter{min, max}; + } +}; + +class fuse_deconvolutiondepthwise_sigmoid_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +DeconvolutionDepthWise op_0 1 1 input a %*=%* +Sigmoid op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "DeconvolutionDepthWise"; + } + + const char* name_str() const + { + return "deconvdwsigmoid"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 4; + } +}; + +class fuse_deconvolutiondepthwise_mish_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +DeconvolutionDepthWise op_0 1 1 input a %*=%* +Mish op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "DeconvolutionDepthWise"; + } + + const char* name_str() const + { + return "deconvdwmish"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 5; + } +}; + +void fuse_deconvolutiondepthwise_activation(Graph& graph) +{ + fuse_deconvolutiondepthwise_relu_pass a; + fuse_deconvolutiondepthwise_clip_pass b; + fuse_deconvolutiondepthwise_sigmoid_pass c; + fuse_deconvolutiondepthwise_mish_pass d; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_deconvolutiondepthwise_activation.h b/tools/pnnx/src/pass_ncnn/fuse_deconvolutiondepthwise_activation.h new file mode 100644 index 000000000000..befa7affb519 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_deconvolutiondepthwise_activation.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace ncnn { + +void fuse_deconvolutiondepthwise_activation(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_innerproduct_activation.cpp b/tools/pnnx/src/pass_ncnn/fuse_innerproduct_activation.cpp new file mode 100644 index 000000000000..67ac1f71ae07 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_innerproduct_activation.cpp @@ -0,0 +1,279 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_innerproduct_activation.h" + +#include "pass_level2.h" + +#include + +namespace pnnx { + +namespace ncnn { + +class fuse_innerproduct_relu_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +InnerProduct op_0 1 1 input a %*=%* +ReLU op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InnerProduct"; + } + + const char* name_str() const + { + return "fcrelu"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float slope = 0.f; + if (captured_params.find("op_1.0") != captured_params.end()) + { + slope = captured_params.at("op_1.0").f; + } + + if (slope == 0.f) + { + op->params["9"] = 1; + } + else + { + op->params["9"] = 2; + op->params["10"] = Parameter{slope}; + } + } +}; + +class fuse_innerproduct_clip_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +InnerProduct op_0 1 1 input a %*=%* +Clip op_1 1 1 a out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InnerProduct"; + } + + const char* name_str() const + { + return "fcclip"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + float min = -FLT_MAX; + float max = FLT_MAX; + if (captured_params.find("op_1.0") != captured_params.end()) + { + min = captured_params.at("op_1.0").f; + } + if (captured_params.find("op_1.1") != captured_params.end()) + { + max = captured_params.at("op_1.1").f; + } + + op->params["9"] = 3; + op->params["10"] = Parameter{min, max}; + } +}; + +class fuse_innerproduct_sigmoid_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +InnerProduct op_0 1 1 input a %*=%* +Sigmoid op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InnerProduct"; + } + + const char* name_str() const + { + return "fcsigmoid"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 4; + } +}; + +class fuse_innerproduct_mish_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +InnerProduct op_0 1 1 input a %*=%* +Mish op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InnerProduct"; + } + + const char* name_str() const + { + return "fcmish"; + } + + bool match_captured_params(const std::map& captured_params) const + { + return captured_params.find("op_0.9") == captured_params.end(); + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + for (const auto& p : captured_params) + { + const std::string& pkey = p.first; + const Parameter& pp = p.second; + + if (pkey.substr(0, 5) == "op_0.") + op->params[pkey.substr(5)] = pp; + } + + for (const auto& a : captured_attrs) + { + const std::string& akey = a.first; + const Attribute& ap = a.second; + + if (akey.substr(0, 5) == "op_0.") + op->attrs[akey.substr(5)] = ap; + } + + op->params["9"] = 5; + } +}; + +void fuse_innerproduct_activation(Graph& graph) +{ + fuse_innerproduct_relu_pass a; + fuse_innerproduct_clip_pass b; + fuse_innerproduct_sigmoid_pass c; + fuse_innerproduct_mish_pass d; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &c, opindex); + pnnx_graph_rewrite(graph, &d, opindex); +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/fuse_innerproduct_activation.h b/tools/pnnx/src/pass_ncnn/fuse_innerproduct_activation.h new file mode 100644 index 000000000000..8fa626cef8fd --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/fuse_innerproduct_activation.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace ncnn { + +void fuse_innerproduct_activation(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/insert_split.cpp b/tools/pnnx/src/pass_ncnn/insert_split.cpp new file mode 100644 index 000000000000..3a21eace008e --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/insert_split.cpp @@ -0,0 +1,80 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "insert_split.h" +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void insert_split(Graph& graph) +{ + int opindex = 0; + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + for (auto& x : op->outputs) + { + if (x->consumers.size() <= 1) + continue; + + matched = true; + + // insert split + Operator* split = graph.new_operator_before("Split", std::string("splitncnn_") + std::to_string(opindex++), graph.ops[i + 1]); + + split->inputs.push_back(x); + + for (size_t j = 0; j < x->consumers.size(); j++) + { + Operator* op2 = x->consumers[j]; + + Operand* operand = graph.new_operand(x->name + "_" + std::to_string(j)); + operand->producer = split; + operand->consumers.push_back(op2); + + split->outputs.push_back(operand); + + for (size_t k = 0; k < op2->inputs.size(); k++) + { + if (op2->inputs[k] == x) + { + op2->inputs[k] = operand; + } + } + } + + x->consumers.clear(); + x->consumers.push_back(split); + + break; + } + + if (matched) + break; + } + + if (!matched) + break; + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/insert_split.h b/tools/pnnx/src/pass_ncnn/insert_split.h new file mode 100644 index 000000000000..fc19b2dd93e6 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/insert_split.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void insert_split(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_AdaptiveAvgPool2d.cpp b/tools/pnnx/src/pass_ncnn/nn_AdaptiveAvgPool2d.cpp new file mode 100644 index 000000000000..869f6020de4d --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_AdaptiveAvgPool2d.cpp @@ -0,0 +1,89 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_AdaptiveAvgPool2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.AdaptiveAvgPool2d op_0 1 1 input out output_size=(1,1) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Pooling"; + } + + const char* name_str() const + { + return "gap"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 1; + op->params["4"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_AdaptiveAvgPool2d, 20) + +class nn_AdaptiveAvgPool2d_n : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.AdaptiveAvgPool2d op_0 1 1 input out output_size=%output_size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Pooling"; + } + + const char* name_str() const + { + return "aap"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 1; + op->params["7"] = 1; + op->params["8"] = captured_params.at("output_size").ai[1]; + op->params["18"] = captured_params.at("output_size").ai[0]; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_AdaptiveAvgPool2d_n, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_AdaptiveMaxPool2d.cpp b/tools/pnnx/src/pass_ncnn/nn_AdaptiveMaxPool2d.cpp new file mode 100644 index 000000000000..33ba4df01945 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_AdaptiveMaxPool2d.cpp @@ -0,0 +1,89 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_AdaptiveMaxPool2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.AdaptiveMaxPool2d op_0 1 1 input out output_size=(1,1) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Pooling"; + } + + const char* name_str() const + { + return "gmp"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 0; + op->params["4"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_AdaptiveMaxPool2d, 20) + +class nn_AdaptiveMaxPool2d_n : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.AdaptiveMaxPool2d op_0 1 1 input out output_size=%output_size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Pooling"; + } + + const char* name_str() const + { + return "amp"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 0; + op->params["7"] = 1; + op->params["8"] = captured_params.at("output_size").ai[1]; + op->params["18"] = captured_params.at("output_size").ai[0]; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_AdaptiveMaxPool2d_n, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_AvgPool2d.cpp b/tools/pnnx/src/pass_ncnn/nn_AvgPool2d.cpp new file mode 100644 index 000000000000..212e24129c07 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_AvgPool2d.cpp @@ -0,0 +1,68 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_AvgPool2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.AvgPool2d op_0 1 1 input out kernel_size=%kernel_size stride=%stride padding=%padding ceil_mode=%ceil_mode count_include_pad=%count_include_pad divisor_override=%divisor_override +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Pooling"; + } + + const char* name_str() const + { + return "avgpool2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + if (captured_params.at("divisor_override").type != 0) + { + fprintf(stderr, "unsupported avgpool2d divisor_override\n"); + return; + } + + op->params["0"] = 1; + op->params["1"] = captured_params.at("kernel_size").ai[1]; + op->params["11"] = captured_params.at("kernel_size").ai[0]; + op->params["2"] = captured_params.at("stride").ai[1]; + op->params["12"] = captured_params.at("stride").ai[0]; + op->params["3"] = captured_params.at("padding").ai[1]; + op->params["13"] = captured_params.at("padding").ai[0]; + op->params["5"] = captured_params.at("ceil_mode").b ? 0 : 1; + op->params["6"] = captured_params.at("count_include_pad").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_AvgPool2d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_BatchNorm1d.cpp b/tools/pnnx/src/pass_ncnn/nn_BatchNorm1d.cpp new file mode 100644 index 000000000000..a0273e8f6841 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_BatchNorm1d.cpp @@ -0,0 +1,72 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_BatchNorm1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.BatchNorm1d op_0 1 1 input out affine=%affine eps=%eps num_features=%num_features @running_mean @running_var @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "BatchNorm"; + } + + const char* name_str() const + { + return "bn"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("num_features"); + op->params["1"] = captured_params.at("eps"); + + op->attrs["1"] = captured_attrs.at("op_0.running_mean"); + op->attrs["2"] = captured_attrs.at("op_0.running_var"); + + if (captured_params.at("affine").b) + { + op->attrs["0"] = captured_attrs.at("op_0.weight"); + op->attrs["3"] = captured_attrs.at("op_0.bias"); + } + else + { + const int num_features = captured_params.at("num_features").i; + std::vector weight(num_features, 1.f); + std::vector bias(num_features, 0.f); + op->attrs["0"] = Attribute({num_features}, weight); + op->attrs["3"] = Attribute({num_features}, bias); + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_BatchNorm1d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_BatchNorm2d.cpp b/tools/pnnx/src/pass_ncnn/nn_BatchNorm2d.cpp new file mode 100644 index 000000000000..7dbb8cf9d048 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_BatchNorm2d.cpp @@ -0,0 +1,72 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_BatchNorm2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.BatchNorm2d op_0 1 1 input out affine=%affine eps=%eps num_features=%num_features @running_mean @running_var @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "BatchNorm"; + } + + const char* name_str() const + { + return "bn"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("num_features"); + op->params["1"] = captured_params.at("eps"); + + op->attrs["1"] = captured_attrs.at("op_0.running_mean"); + op->attrs["2"] = captured_attrs.at("op_0.running_var"); + + if (captured_params.at("affine").b) + { + op->attrs["0"] = captured_attrs.at("op_0.weight"); + op->attrs["3"] = captured_attrs.at("op_0.bias"); + } + else + { + const int num_features = captured_params.at("num_features").i; + std::vector weight(num_features, 1.f); + std::vector bias(num_features, 0.f); + op->attrs["0"] = Attribute({num_features}, weight); + op->attrs["3"] = Attribute({num_features}, bias); + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_BatchNorm2d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ChannelShuffle.cpp b/tools/pnnx/src/pass_ncnn/nn_ChannelShuffle.cpp new file mode 100644 index 000000000000..ae3ca75074ff --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ChannelShuffle.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ChannelShuffle : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ChannelShuffle op_0 1 1 input out groups=%groups +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ShuffleChannel"; + } + + const char* name_str() const + { + return "channelshuffle"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = captured_params.at("groups"); + op->params["1"] = 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ChannelShuffle, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ConstantPad1d.cpp b/tools/pnnx/src/pass_ncnn/nn_ConstantPad1d.cpp new file mode 100644 index 000000000000..fcdecc6f6254 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ConstantPad1d.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ConstantPad1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ConstantPad1d op_0 1 1 input out padding=%padding value=%value +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Padding"; + } + + const char* name_str() const + { + return "constpad1d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + float pad_value = 0.f; + if (captured_params.at("value").type == 2) + pad_value = captured_params.at("value").i; + if (captured_params.at("value").type == 3) + pad_value = captured_params.at("value").f; + + op->params["0"] = 0; + op->params["1"] = 0; + op->params["2"] = captured_params.at("padding").ai[0]; + op->params["3"] = captured_params.at("padding").ai[1]; + op->params["4"] = 0; + op->params["5"] = pad_value; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ConstantPad1d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ConstantPad2d.cpp b/tools/pnnx/src/pass_ncnn/nn_ConstantPad2d.cpp new file mode 100644 index 000000000000..0baec2d0ce8f --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ConstantPad2d.cpp @@ -0,0 +1,65 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ConstantPad2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ConstantPad2d op_0 1 1 input out padding=%padding value=%value +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Padding"; + } + + const char* name_str() const + { + return "constpad"; + } + + void write(Operator* op, const std::map& captured_params) const + { + float pad_value = 0.f; + if (captured_params.at("value").type == 2) + pad_value = captured_params.at("value").i; + if (captured_params.at("value").type == 3) + pad_value = captured_params.at("value").f; + + op->params["0"] = captured_params.at("padding").ai[2]; + op->params["1"] = captured_params.at("padding").ai[3]; + op->params["2"] = captured_params.at("padding").ai[0]; + op->params["3"] = captured_params.at("padding").ai[1]; + op->params["4"] = 0; + op->params["5"] = pad_value; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ConstantPad2d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Conv1d.cpp b/tools/pnnx/src/pass_ncnn/nn_Conv1d.cpp new file mode 100644 index 000000000000..c60e7b74a42d --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Conv1d.cpp @@ -0,0 +1,129 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Conv1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv1d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution1D"; + } + + const char* name_str() const + { + return "conv1d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("out_channels"); + op->params["1"] = captured_params.at("kernel_size").ai[0]; + op->params["2"] = captured_params.at("dilation").ai[0]; + op->params["3"] = captured_params.at("stride").ai[0]; + if (captured_params.at("padding").type == 4) + { + if (captured_params.at("padding").s == "same") + op->params["4"] = -233; + else if (captured_params.at("padding").s == "valid") + op->params["4"] = 0; + } + else + { + op->params["4"] = captured_params.at("padding").ai[0]; + } + op->params["5"] = captured_params.at("bias").b ? 1 : 0; + op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float)); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = captured_attrs.at("op_0.weight"); + if (captured_params.at("bias").b) + op->attrs["2"] = captured_attrs.at("op_0.bias"); + } +}; + +class nn_Conv1d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv1d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ConvolutionDepthWise1D"; + } + + const char* name_str() const + { + return "convdw1d"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("out_channels"); + op->params["1"] = captured_params.at("kernel_size").ai[0]; + op->params["2"] = captured_params.at("dilation").ai[0]; + op->params["3"] = captured_params.at("stride").ai[0]; + if (captured_params.at("padding").type == 4) + { + if (captured_params.at("padding").s == "same") + op->params["4"] = -233; + else if (captured_params.at("padding").s == "valid") + op->params["4"] = 0; + } + else + { + op->params["4"] = captured_params.at("padding").ai[0]; + } + op->params["5"] = captured_params.at("bias").b ? 1 : 0; + op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float)); + op->params["7"] = captured_params.at("groups"); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = captured_attrs.at("op_0.weight"); + if (captured_params.at("bias").b) + op->attrs["2"] = captured_attrs.at("op_0.bias"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv1d, 20) +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv1d_1, 21) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Conv2d.cpp b/tools/pnnx/src/pass_ncnn/nn_Conv2d.cpp new file mode 100644 index 000000000000..7c80c2ccfd4c --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Conv2d.cpp @@ -0,0 +1,137 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Conv2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv2d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Convolution"; + } + + const char* name_str() const + { + return "conv"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("out_channels"); + op->params["1"] = captured_params.at("kernel_size").ai[1]; + op->params["11"] = captured_params.at("kernel_size").ai[0]; + op->params["2"] = captured_params.at("dilation").ai[1]; + op->params["12"] = captured_params.at("dilation").ai[0]; + op->params["3"] = captured_params.at("stride").ai[1]; + op->params["13"] = captured_params.at("stride").ai[0]; + if (captured_params.at("padding").type == 4) + { + if (captured_params.at("padding").s == "same") + op->params["4"] = -233; + else if (captured_params.at("padding").s == "valid") + op->params["4"] = 0; + } + else + { + op->params["4"] = captured_params.at("padding").ai[1]; + op->params["14"] = captured_params.at("padding").ai[0]; + } + op->params["5"] = captured_params.at("bias").b ? 1 : 0; + op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float)); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = captured_attrs.at("op_0.weight"); + if (captured_params.at("bias").b) + op->attrs["2"] = captured_attrs.at("op_0.bias"); + } +}; + +class nn_Conv2d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Conv2d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ConvolutionDepthWise"; + } + + const char* name_str() const + { + return "convdw"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("out_channels"); + op->params["1"] = captured_params.at("kernel_size").ai[1]; + op->params["11"] = captured_params.at("kernel_size").ai[0]; + op->params["2"] = captured_params.at("dilation").ai[1]; + op->params["12"] = captured_params.at("dilation").ai[0]; + op->params["3"] = captured_params.at("stride").ai[1]; + op->params["13"] = captured_params.at("stride").ai[0]; + if (captured_params.at("padding").type == 4) + { + if (captured_params.at("padding").s == "same") + op->params["4"] = -233; + else if (captured_params.at("padding").s == "valid") + op->params["4"] = 0; + } + else + { + op->params["4"] = captured_params.at("padding").ai[1]; + op->params["14"] = captured_params.at("padding").ai[0]; + } + op->params["5"] = captured_params.at("bias").b ? 1 : 0; + op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float)); + op->params["7"] = captured_params.at("groups"); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = captured_attrs.at("op_0.weight"); + if (captured_params.at("bias").b) + op->attrs["2"] = captured_attrs.at("op_0.bias"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv2d, 20) +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv2d_1, 21) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ConvTranspose2d.cpp b/tools/pnnx/src/pass_ncnn/nn_ConvTranspose2d.cpp new file mode 100644 index 000000000000..832aadf70baa --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ConvTranspose2d.cpp @@ -0,0 +1,181 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ConvTranspose2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ConvTranspose2d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=1 bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Deconvolution"; + } + + const char* name_str() const + { + return "deconv"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("out_channels"); + op->params["1"] = captured_params.at("kernel_size").ai[1]; + op->params["11"] = captured_params.at("kernel_size").ai[0]; + op->params["2"] = captured_params.at("dilation").ai[1]; + op->params["12"] = captured_params.at("dilation").ai[0]; + op->params["3"] = captured_params.at("stride").ai[1]; + op->params["13"] = captured_params.at("stride").ai[0]; + op->params["4"] = captured_params.at("padding").ai[1]; + op->params["14"] = captured_params.at("padding").ai[0]; + op->params["18"] = captured_params.at("output_padding").ai[1]; + op->params["19"] = captured_params.at("output_padding").ai[0]; + op->params["5"] = captured_params.at("bias").b ? 1 : 0; + op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float)); + + // transpose inch-outch-kh-kw to outch-inch-kh-kw + const int inch = captured_params.at("in_channels").i; + const int outch = captured_params.at("out_channels").i; + const int kh = captured_params.at("kernel_size").ai[0]; + const int kw = captured_params.at("kernel_size").ai[1]; + std::vector new_weight; + { + const float* w = (const float*)captured_attrs.at("op_0.weight").data.data(); + + new_weight.resize(outch * inch * kh * kw); + float* w2 = (float*)new_weight.data(); + const int maxk = kh * kw; + + // reorder weight from inch-outch to outch-inch + for (int i = 0; i < outch; i++) + { + for (int j = 0; j < inch; j++) + { + for (int k = 0; k < maxk; k++) + { + w2[(i * inch + j) * maxk + k] = w[(j * outch + i) * maxk + k]; + } + } + } + } + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = Attribute({outch, inch, kh, kw}, new_weight); + if (captured_params.at("bias").b) + op->attrs["2"] = captured_attrs.at("op_0.bias"); + } +}; + +class nn_ConvTranspose2d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ConvTranspose2d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding=%padding dilation=%dilation output_padding=%output_padding groups=%groups bias=%bias @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "DeconvolutionDepthWise"; + } + + const char* name_str() const + { + return "deconvdw"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("out_channels"); + op->params["1"] = captured_params.at("kernel_size").ai[1]; + op->params["11"] = captured_params.at("kernel_size").ai[0]; + op->params["2"] = captured_params.at("dilation").ai[1]; + op->params["12"] = captured_params.at("dilation").ai[0]; + op->params["3"] = captured_params.at("stride").ai[1]; + op->params["13"] = captured_params.at("stride").ai[0]; + op->params["4"] = captured_params.at("padding").ai[1]; + op->params["14"] = captured_params.at("padding").ai[0]; + op->params["18"] = captured_params.at("output_padding").ai[1]; + op->params["19"] = captured_params.at("output_padding").ai[0]; + op->params["5"] = captured_params.at("bias").b ? 1 : 0; + op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float)); + op->params["7"] = captured_params.at("groups"); + + // transpose group-inch/group-outch/group-kh-kw to group-outch/group-inch/group-kh-kw + const int inch = captured_params.at("in_channels").i; + const int outch = captured_params.at("out_channels").i; + const int groups = captured_params.at("groups").i; + const int kh = captured_params.at("kernel_size").ai[0]; + const int kw = captured_params.at("kernel_size").ai[1]; + std::vector new_weight; + { + const float* w = (const float*)captured_attrs.at("op_0.weight").data.data(); + + new_weight.resize(outch / groups * inch * kh * kw); + float* w2 = (float*)new_weight.data(); + const int outch_g = outch / groups; + const int inch_g = inch / groups; + const int maxk = kh * kw; + + for (int g = 0; g < groups; g++) + { + // reorder weight from inch-outch to outch-inch + float* wg2 = w2 + g * outch_g * inch_g * maxk; + const float* wg = w + g * inch_g * outch_g * maxk; + for (int i = 0; i < outch_g; i++) + { + for (int j = 0; j < inch_g; j++) + { + for (int k = 0; k < maxk; k++) + { + wg2[(i * inch_g + j) * maxk + k] = wg[(j * outch_g + i) * maxk + k]; + } + } + } + } + } + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = Attribute({outch / groups, inch, kh, kw}, new_weight); + if (captured_params.at("bias").b) + op->attrs["2"] = captured_attrs.at("op_0.bias"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ConvTranspose2d, 20) +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ConvTranspose2d_1, 21) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Dropout.cpp b/tools/pnnx/src/pass_ncnn/nn_Dropout.cpp new file mode 100644 index 000000000000..ca295699bace --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Dropout.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Dropout : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Dropout op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Noop"; + } + + const char* name_str() const + { + return "dropout"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Dropout, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ELU.cpp b/tools/pnnx/src/pass_ncnn/nn_ELU.cpp new file mode 100644 index 000000000000..b54592f2610f --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ELU.cpp @@ -0,0 +1,54 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ELU : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ELU op_0 1 1 input out alpha=%alpha +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ELU"; + } + + const char* name_str() const + { + return "elu"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = captured_params.at("alpha"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ELU, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Embedding.cpp b/tools/pnnx/src/pass_ncnn/nn_Embedding.cpp new file mode 100644 index 000000000000..f8885c0250b1 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Embedding.cpp @@ -0,0 +1,61 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Embedding : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Embedding op_0 1 1 input out embedding_dim=%embedding_dim num_embeddings=%num_embeddings sparse=False @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Embed"; + } + + const char* name_str() const + { + return "embed"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("embedding_dim"); + op->params["1"] = captured_params.at("num_embeddings"); + op->params["2"] = 0; + op->params["3"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float)); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = captured_attrs.at("op_0.weight"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Embedding, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_GELU.cpp b/tools/pnnx/src/pass_ncnn/nn_GELU.cpp new file mode 100644 index 000000000000..bec078bbeb25 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_GELU.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_GELU : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.GELU op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "GELU"; + } + + const char* name_str() const + { + return "gelu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_GELU, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_GRU.cpp b/tools/pnnx/src/pass_ncnn/nn_GRU.cpp new file mode 100644 index 000000000000..7e4231e6a510 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_GRU.cpp @@ -0,0 +1,187 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" +#include + +namespace pnnx { + +namespace ncnn { + +class nn_GRU : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 3 +pnnx.Input input 0 1 input +nn.GRU op_0 1 2 input out out_hidden input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 2 0 out out_hidden +)PNNXIR"; + } + + const char* type_str() const + { + return "GRU"; + } + + const char* name_str() const + { + return "gru"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const bool bidirectional = captured_params.at("bidirectional").b; + const int num_directions = bidirectional ? 2 : 1; + const int num_output = captured_params.at("hidden_size").i; + const int input_size = captured_params.at("input_size").i; + + int weight_data_size = num_directions * num_output * input_size * 3; + + op->params["0"] = num_output; + op->params["1"] = weight_data_size; + op->params["2"] = bidirectional ? 2 : 0; + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + + // RUN-hidden-input_size + { + op->attrs["1"] = captured_attrs.at("op_0.weight_ih_l0"); + if (bidirectional) + op->attrs["2"] = captured_attrs.at("op_0.weight_ih_l0_reverse"); + } + + op->attrs["3"] = Attribute(); + op->attrs["3"].data = {0, 0, 0, 0}; + if (captured_params.at("bias").b) + { + // reduce bias_ih and bias_hh + std::vector new_bias; + { + const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0").data.data(); + const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0").data.data(); + + new_bias.resize(4 * num_output); + float* bias = (float*)new_bias.data(); + for (int i = 0; i < 2 * num_output; i++) + { + bias[i] = bias_ih[i] + bias_hh[i]; + } + memcpy(bias + num_output * 2, bias_ih + num_output * 2, num_output * sizeof(float)); + memcpy(bias + num_output * 3, bias_hh + num_output * 2, num_output * sizeof(float)); + } + + op->attrs["4"] = Attribute({4, num_output}, new_bias); + + if (bidirectional) + { + std::vector new_bias_reverse; + { + const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0_reverse").data.data(); + const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0_reverse").data.data(); + + new_bias_reverse.resize(4 * num_output); + float* bias = (float*)new_bias_reverse.data(); + for (int i = 0; i < 2 * num_output; i++) + { + bias[i] = bias_ih[i] + bias_hh[i]; + } + memcpy(bias + num_output * 2, bias_ih + num_output * 2, num_output * sizeof(float)); + memcpy(bias + num_output * 3, bias_hh + num_output * 2, num_output * sizeof(float)); + } + + op->attrs["5"] = Attribute({4, num_output}, new_bias_reverse); + } + } + else + { + std::vector bias(4 * num_output, 0.f); + op->attrs["4"] = Attribute({4, num_output}, bias); + + if (bidirectional) + { + op->attrs["5"] = Attribute({4, num_output}, bias); + } + } + + op->attrs["6"] = Attribute(); + op->attrs["6"].data = {0, 0, 0, 0}; + + // RUN-hidden-hidden + { + op->attrs["7"] = captured_attrs.at("op_0.weight_hh_l0"); + if (bidirectional) + op->attrs["8"] = captured_attrs.at("op_0.weight_hh_l0_reverse"); + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_GRU, 20) + +class nn_GRU_1 : public nn_GRU +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 4 +pnnx.Input input 0 1 input +pnnx.Input in_hidden 0 1 in_hidden +nn.GRU op_0 2 2 input in_hidden out out_hidden input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 2 0 out out_hidden +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_GRU_1, 20) + +class nn_GRU_2 : public nn_GRU +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.GRU op_0 1 1 input out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_GRU_2, 20) + +class nn_GRU_3 : public nn_GRU +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Input in_hidden 0 1 in_hidden +nn.GRU op_0 2 1 input in_hidden out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_GRU_3, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_GroupNorm.cpp b/tools/pnnx/src/pass_ncnn/nn_GroupNorm.cpp new file mode 100644 index 000000000000..3414de2dcbe8 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_GroupNorm.cpp @@ -0,0 +1,63 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_GroupNorm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.GroupNorm op_0 1 1 input out num_groups=%num_groups num_channels=%num_channels eps=%eps affine=%affine @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "GroupNorm"; + } + + const char* name_str() const + { + return "gn"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("num_groups"); + op->params["1"] = captured_params.at("num_channels"); + op->params["2"] = captured_params.at("eps"); + op->params["3"] = captured_params.at("affine").b ? 1 : 0; + + if (captured_params.at("affine").b) + { + op->attrs["0"] = captured_attrs.at("op_0.weight"); + op->attrs["1"] = captured_attrs.at("op_0.bias"); + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_GroupNorm, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Hardsigmoid.cpp b/tools/pnnx/src/pass_ncnn/nn_Hardsigmoid.cpp new file mode 100644 index 000000000000..f93da3a1f3cd --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Hardsigmoid.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Hardsigmoid : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Hardsigmoid op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "HardSigmoid"; + } + + const char* name_str() const + { + return "hsigmoid"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 1.f / 6; + op->params["1"] = 0.5f; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Hardsigmoid, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Hardswish.cpp b/tools/pnnx/src/pass_ncnn/nn_Hardswish.cpp new file mode 100644 index 000000000000..a2bfc5c4bb2b --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Hardswish.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Hardswish : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Hardswish op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "HardSwish"; + } + + const char* name_str() const + { + return "hswish"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 1.f / 6; + op->params["1"] = 0.5f; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Hardswish, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Hardtanh.cpp b/tools/pnnx/src/pass_ncnn/nn_Hardtanh.cpp new file mode 100644 index 000000000000..2563a1291b3d --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Hardtanh.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Hardtanh : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Hardtanh op_0 1 1 input out min_val=%0 max_val=%1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Clip"; + } + + const char* name_str() const + { + return "htanh"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Hardtanh, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_InstanceNorm2d.cpp b/tools/pnnx/src/pass_ncnn/nn_InstanceNorm2d.cpp new file mode 100644 index 000000000000..828cc1fc4cd2 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_InstanceNorm2d.cpp @@ -0,0 +1,62 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_InstanceNorm2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.InstanceNorm2d op_0 1 1 input out num_features=%num_features eps=%eps affine=%affine track_running_stats=%track_running_stats @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InstanceNorm"; + } + + const char* name_str() const + { + return "in"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("num_features"); + op->params["1"] = captured_params.at("eps"); + op->params["2"] = captured_params.at("affine").b ? 1 : 0; + + if (captured_params.at("affine").b) + { + op->attrs["0"] = captured_attrs.at("op_0.weight"); + op->attrs["1"] = captured_attrs.at("op_0.bias"); + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_InstanceNorm2d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_LSTM.cpp b/tools/pnnx/src/pass_ncnn/nn_LSTM.cpp new file mode 100644 index 000000000000..899ee1fac906 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_LSTM.cpp @@ -0,0 +1,324 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" +#include + +namespace pnnx { + +namespace ncnn { + +class nn_LSTM : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 4 +pnnx.Input input 0 1 input +nn.LSTM op_0 1 3 input out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 3 0 out out_hidden out_cell +)PNNXIR"; + } + + const char* type_str() const + { + return "LSTM"; + } + + const char* name_str() const + { + return "lstm"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const bool bidirectional = captured_params.at("bidirectional").b; + const int num_directions = bidirectional ? 2 : 1; + const int num_output = captured_params.at("hidden_size").i; + const int input_size = captured_params.at("input_size").i; + + int weight_data_size = num_directions * num_output * input_size * 4; + + op->params["0"] = num_output; + op->params["1"] = weight_data_size; + op->params["2"] = bidirectional ? 2 : 0; + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + + // reorder IFGO-hidden-input_size to IFOG-hidden-input_size + { + std::vector new_weight_ih; + { + const int weight_data_size_g = num_output * input_size; + + const float* weight_ih = (const float*)captured_attrs.at("op_0.weight_ih_l0").data.data(); + const float* iptr = weight_ih; + const float* fptr = weight_ih + weight_data_size_g; + const float* gptr = weight_ih + weight_data_size_g * 2; + const float* optr = weight_ih + weight_data_size_g * 3; + + new_weight_ih.resize(4 * num_output * input_size); + float* weight = (float*)new_weight_ih.data(); + float* w_iptr = weight; + float* w_fptr = weight + weight_data_size_g; + float* w_optr = weight + weight_data_size_g * 2; + float* w_gptr = weight + weight_data_size_g * 3; + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + } + op->attrs["1"] = Attribute({4, num_output, input_size}, new_weight_ih); + + if (bidirectional) + { + std::vector new_weight_ih_reverse; + { + const int weight_data_size_g = num_output * input_size; + + const float* weight_ih = (const float*)captured_attrs.at("op_0.weight_ih_l0_reverse").data.data(); + const float* iptr = weight_ih; + const float* fptr = weight_ih + weight_data_size_g; + const float* gptr = weight_ih + weight_data_size_g * 2; + const float* optr = weight_ih + weight_data_size_g * 3; + + new_weight_ih_reverse.resize(4 * num_output * input_size); + float* weight = (float*)new_weight_ih_reverse.data(); + float* w_iptr = weight; + float* w_fptr = weight + weight_data_size_g; + float* w_optr = weight + weight_data_size_g * 2; + float* w_gptr = weight + weight_data_size_g * 3; + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + } + op->attrs["2"] = Attribute({4, num_output, input_size}, new_weight_ih_reverse); + } + } + + op->attrs["3"] = Attribute(); + op->attrs["3"].data = {0, 0, 0, 0}; + if (captured_params.at("bias").b) + { + // reduce bias_ih and bias_hh + // reorder IFGO-hidden to IFOG-hidden + std::vector new_bias; + { + const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0").data.data(); + const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0").data.data(); + const float* bias_ih_iptr = bias_ih; + const float* bias_ih_fptr = bias_ih + num_output; + const float* bias_ih_gptr = bias_ih + num_output * 2; + const float* bias_ih_optr = bias_ih + num_output * 3; + const float* bias_hh_iptr = bias_hh; + const float* bias_hh_fptr = bias_hh + num_output; + const float* bias_hh_gptr = bias_hh + num_output * 2; + const float* bias_hh_optr = bias_hh + num_output * 3; + + new_bias.resize(4 * num_output); + float* bias = (float*)new_bias.data(); + float* b_iptr = bias; + float* b_fptr = bias + num_output; + float* b_optr = bias + num_output * 2; + float* b_gptr = bias + num_output * 3; + for (int i = 0; i < num_output; i++) + { + b_iptr[i] = bias_ih_iptr[i] + bias_hh_iptr[i]; + } + for (int i = 0; i < num_output; i++) + { + b_fptr[i] = bias_ih_fptr[i] + bias_hh_fptr[i]; + } + for (int i = 0; i < num_output; i++) + { + b_optr[i] = bias_ih_optr[i] + bias_hh_optr[i]; + } + for (int i = 0; i < num_output; i++) + { + b_gptr[i] = bias_ih_gptr[i] + bias_hh_gptr[i]; + } + } + + op->attrs["4"] = Attribute({4, num_output}, new_bias); + + if (bidirectional) + { + std::vector new_bias_reverse; + { + const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0_reverse").data.data(); + const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0_reverse").data.data(); + const float* bias_ih_iptr = bias_ih; + const float* bias_ih_fptr = bias_ih + num_output; + const float* bias_ih_gptr = bias_ih + num_output * 2; + const float* bias_ih_optr = bias_ih + num_output * 3; + const float* bias_hh_iptr = bias_hh; + const float* bias_hh_fptr = bias_hh + num_output; + const float* bias_hh_gptr = bias_hh + num_output * 2; + const float* bias_hh_optr = bias_hh + num_output * 3; + + new_bias_reverse.resize(4 * num_output); + float* bias = (float*)new_bias_reverse.data(); + float* b_iptr = bias; + float* b_fptr = bias + num_output; + float* b_optr = bias + num_output * 2; + float* b_gptr = bias + num_output * 3; + for (int i = 0; i < num_output; i++) + { + b_iptr[i] = bias_ih_iptr[i] + bias_hh_iptr[i]; + } + for (int i = 0; i < num_output; i++) + { + b_fptr[i] = bias_ih_fptr[i] + bias_hh_fptr[i]; + } + for (int i = 0; i < num_output; i++) + { + b_optr[i] = bias_ih_optr[i] + bias_hh_optr[i]; + } + for (int i = 0; i < num_output; i++) + { + b_gptr[i] = bias_ih_gptr[i] + bias_hh_gptr[i]; + } + } + + op->attrs["5"] = Attribute({4, num_output}, new_bias_reverse); + } + } + else + { + std::vector bias(4 * num_output, 0.f); + op->attrs["4"] = Attribute({4, num_output}, bias); + + if (bidirectional) + { + op->attrs["5"] = Attribute({4, num_output}, bias); + } + } + + op->attrs["6"] = Attribute(); + op->attrs["6"].data = {0, 0, 0, 0}; + + // reorder IFGO-hidden-hidden to IFOG-hidden-hidden + { + std::vector new_weight_hh; + { + const int weight_data_size_g = num_output * num_output; + + const float* weight_hh = (const float*)captured_attrs.at("op_0.weight_hh_l0").data.data(); + const float* iptr = weight_hh; + const float* fptr = weight_hh + weight_data_size_g; + const float* gptr = weight_hh + weight_data_size_g * 2; + const float* optr = weight_hh + weight_data_size_g * 3; + + new_weight_hh.resize(4 * num_output * num_output); + float* weight = (float*)new_weight_hh.data(); + float* w_iptr = weight; + float* w_fptr = weight + weight_data_size_g; + float* w_optr = weight + weight_data_size_g * 2; + float* w_gptr = weight + weight_data_size_g * 3; + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + } + op->attrs["7"] = Attribute({4, num_output, num_output}, new_weight_hh); + + if (bidirectional) + { + std::vector new_weight_hh_reverse; + { + const int weight_data_size_g = num_output * num_output; + + const float* weight_hh = (const float*)captured_attrs.at("op_0.weight_hh_l0_reverse").data.data(); + const float* iptr = weight_hh; + const float* fptr = weight_hh + weight_data_size_g; + const float* gptr = weight_hh + weight_data_size_g * 2; + const float* optr = weight_hh + weight_data_size_g * 3; + + new_weight_hh_reverse.resize(4 * num_output * num_output); + float* weight = (float*)new_weight_hh_reverse.data(); + float* w_iptr = weight; + float* w_fptr = weight + weight_data_size_g; + float* w_optr = weight + weight_data_size_g * 2; + float* w_gptr = weight + weight_data_size_g * 3; + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + } + op->attrs["8"] = Attribute({4, num_output, num_output}, new_weight_hh_reverse); + } + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_LSTM, 20) + +class nn_LSTM_1 : public nn_LSTM +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 6 +pnnx.Input input 0 1 input +pnnx.Input in_hidden 0 1 in_hidden +pnnx.Input in_hidden 0 1 in_cell +nn.LSTM op_0 3 3 input in_hidden in_cell out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 3 0 out out_hidden out_cell +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_LSTM_1, 20) + +class nn_LSTM_2 : public nn_LSTM +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.LSTM op_0 1 1 input out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_LSTM_2, 20) + +class nn_LSTM_3 : public nn_LSTM +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Input in_hidden 0 1 in_hidden +pnnx.Input in_hidden 0 1 in_cell +nn.LSTM op_0 3 1 input in_hidden in_cell out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_LSTM_3, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_LayerNorm.cpp b/tools/pnnx/src/pass_ncnn/nn_LayerNorm.cpp new file mode 100644 index 000000000000..a26d70e20b7a --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_LayerNorm.cpp @@ -0,0 +1,69 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_LayerNorm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.LayerNorm op_0 1 1 input out normalized_shape=%normalized_shape eps=%eps elementwise_affine=%elementwise_affine @weight @bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "LayerNorm"; + } + + const char* name_str() const + { + return "ln"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::vector& normalized_shape = captured_params.at("normalized_shape").ai; + int affine_size = normalized_shape[0]; + for (size_t i = 1; i < normalized_shape.size(); i++) + { + affine_size *= normalized_shape[i]; + } + + op->params["0"] = affine_size; + op->params["1"] = captured_params.at("eps"); + op->params["2"] = captured_params.at("elementwise_affine").b ? 1 : 0; + + if (captured_params.at("elementwise_affine").b) + { + op->attrs["0"] = captured_attrs.at("op_0.weight"); + op->attrs["1"] = captured_attrs.at("op_0.bias"); + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_LayerNorm, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_LeakyReLU.cpp b/tools/pnnx/src/pass_ncnn/nn_LeakyReLU.cpp new file mode 100644 index 000000000000..2251a32428ff --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_LeakyReLU.cpp @@ -0,0 +1,54 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_LeakyReLU : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.LeakyReLU op_0 1 1 input out negative_slope=%negative_slope +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ReLU"; + } + + const char* name_str() const + { + return "leakyrelu"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = captured_params.at("negative_slope"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_LeakyReLU, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Linear.cpp b/tools/pnnx/src/pass_ncnn/nn_Linear.cpp new file mode 100644 index 000000000000..06b028664469 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Linear.cpp @@ -0,0 +1,62 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Linear : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Linear op_0 1 1 input out in_features=%in_features out_features=%out_features bias=%bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InnerProduct"; + } + + const char* name_str() const + { + return "linear"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("out_features"); + op->params["1"] = captured_params.at("bias").b ? 1 : 0; + op->params["2"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float)); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = captured_attrs.at("op_0.weight"); + if (captured_params.at("bias").b) + op->attrs["2"] = captured_attrs.at("op_0.bias"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_LocalResponseNorm.cpp b/tools/pnnx/src/pass_ncnn/nn_LocalResponseNorm.cpp new file mode 100644 index 000000000000..1ad86da35595 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_LocalResponseNorm.cpp @@ -0,0 +1,58 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_LocalResponseNorm : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.LocalResponseNorm op_0 1 1 input out size=%size alpha=%alpha beta=%beta k=%k +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "LRN"; + } + + const char* name_str() const + { + return "lrn"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 0; // region_type ACROSS_CHANNELS + op->params["1"] = captured_params.at("size"); + op->params["2"] = captured_params.at("alpha"); + op->params["3"] = captured_params.at("beta"); + op->params["4"] = captured_params.at("k"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_LocalResponseNorm, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_MaxPool2d.cpp b/tools/pnnx/src/pass_ncnn/nn_MaxPool2d.cpp new file mode 100644 index 000000000000..89cc53ee2d0e --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_MaxPool2d.cpp @@ -0,0 +1,61 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_MaxPool2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.MaxPool2d op_0 1 1 input out kernel_size=%kernel_size stride=%stride dilation=(1,1) padding=%padding ceil_mode=%ceil_mode return_indices=False +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Pooling"; + } + + const char* name_str() const + { + return "maxpool2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 0; + op->params["1"] = captured_params.at("kernel_size").ai[1]; + op->params["11"] = captured_params.at("kernel_size").ai[0]; + op->params["2"] = captured_params.at("stride").ai[1]; + op->params["12"] = captured_params.at("stride").ai[0]; + op->params["3"] = captured_params.at("padding").ai[1]; + op->params["13"] = captured_params.at("padding").ai[0]; + op->params["5"] = captured_params.at("ceil_mode").b ? 0 : 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_MaxPool2d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Mish.cpp b/tools/pnnx/src/pass_ncnn/nn_Mish.cpp new file mode 100644 index 000000000000..e1d6fb1893d5 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Mish.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Mish : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Mish op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Mish"; + } + + const char* name_str() const + { + return "mish"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Mish, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_MultiheadAttention.cpp b/tools/pnnx/src/pass_ncnn/nn_MultiheadAttention.cpp new file mode 100644 index 000000000000..bcc9407f6053 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_MultiheadAttention.cpp @@ -0,0 +1,133 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +#include +#include + +namespace pnnx { + +namespace ncnn { + +class nn_MultiheadAttention : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.MultiheadAttention op_0 1 1 input out num_heads=%num_heads batch_first=%batch_first add_zero_attn=%add_zero_attn embed_dim=%embed_dim bias=%bias add_bias_kv=%add_bias_kv @in_proj_weight @in_proj_bias @bias_k @bias_v @out_proj.weight @out_proj.bias +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "MultiHeadAttention"; + } + + const char* name_str() const + { + return "attention"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("embed_dim"); + op->params["1"] = captured_params.at("num_heads"); + + if (captured_params.at("add_bias_kv").b) + { + fprintf(stderr, "MultiheadAttention add_bias_kv=True not supported\n"); + } + + const int embed_dim = captured_params.at("embed_dim").i; + + // split in_proj_weight and in_proj_bias into q k v + std::vector q_weight(embed_dim * embed_dim); + std::vector q_bias(embed_dim); + std::vector k_weight(embed_dim * embed_dim); + std::vector k_bias(embed_dim); + std::vector v_weight(embed_dim * embed_dim); + std::vector v_bias(embed_dim); + { + // qkv - embed_dim - embed_dim + const float* wptr = (const float*)captured_attrs.at("op_0.in_proj_weight").data.data(); + // qkv - embed_dim + const float* bptr = (const float*)captured_attrs.at("op_0.in_proj_bias").data.data(); + + { + memcpy(q_weight.data(), wptr, embed_dim * embed_dim * sizeof(float)); + memcpy(q_bias.data(), bptr, embed_dim * sizeof(float)); + wptr += embed_dim * embed_dim; + bptr += embed_dim; + } + + { + memcpy(k_weight.data(), wptr, embed_dim * embed_dim * sizeof(float)); + memcpy(k_bias.data(), bptr, embed_dim * sizeof(float)); + wptr += embed_dim * embed_dim; + bptr += embed_dim; + } + + { + memcpy(v_weight.data(), wptr, embed_dim * embed_dim * sizeof(float)); + memcpy(v_bias.data(), bptr, embed_dim * sizeof(float)); + } + } + + op->params["2"] = embed_dim * embed_dim; + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = Attribute({embed_dim, embed_dim}, q_weight); + op->attrs["2"] = Attribute({embed_dim}, q_bias); + op->attrs["3"] = Attribute(); + op->attrs["3"].data = {0, 0, 0, 0}; + op->attrs["4"] = Attribute({embed_dim, embed_dim}, k_weight); + op->attrs["5"] = Attribute({embed_dim}, k_bias); + op->attrs["6"] = Attribute(); + op->attrs["6"].data = {0, 0, 0, 0}; + op->attrs["7"] = Attribute({embed_dim, embed_dim}, v_weight); + op->attrs["8"] = Attribute({embed_dim}, v_bias); + op->attrs["9"] = Attribute(); + op->attrs["9"].data = {0, 0, 0, 0}; + op->attrs["a"] = captured_attrs.at("op_0.out_proj.weight"); + op->attrs["b"] = captured_attrs.at("op_0.out_proj.bias"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_MultiheadAttention, 20) + +class nn_MultiheadAttention_1 : public nn_MultiheadAttention +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.MultiheadAttention op_0 1 1 input out num_heads=%num_heads add_zero_attn=%add_zero_attn embed_dim=%embed_dim bias=%bias add_bias_kv=%add_bias_kv @in_proj_weight @in_proj_bias @bias_k @bias_v @out_proj.weight @out_proj.bias +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_MultiheadAttention_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_PReLU.cpp b/tools/pnnx/src/pass_ncnn/nn_PReLU.cpp new file mode 100644 index 000000000000..4d8c1445886b --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_PReLU.cpp @@ -0,0 +1,56 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_PReLU : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.PReLU op_0 1 1 input out num_parameters=%num_parameters @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "PReLU"; + } + + const char* name_str() const + { + return "prelu"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["0"] = captured_params.at("num_parameters"); + + op->attrs["0"] = captured_attrs.at("op_0.weight"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_PReLU, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_PixelShuffle.cpp b/tools/pnnx/src/pass_ncnn/nn_PixelShuffle.cpp new file mode 100644 index 000000000000..107c211b6aa8 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_PixelShuffle.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_PixelShuffle : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.PixelShuffle op_0 1 1 input out upscale_factor=%upscale_factor +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "PixelShuffle"; + } + + const char* name_str() const + { + return "pixelshuffle"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = captured_params.at("upscale_factor"); + op->params["1"] = 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_PixelShuffle, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_PixelUnshuffle.cpp b/tools/pnnx/src/pass_ncnn/nn_PixelUnshuffle.cpp new file mode 100644 index 000000000000..ff259e0a9395 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_PixelUnshuffle.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_PixelUnshuffle : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.PixelUnshuffle op_0 1 1 input out downscale_factor=%downscale_factor +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reorg"; + } + + const char* name_str() const + { + return "pixelunshuffle"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = captured_params.at("downscale_factor"); + op->params["1"] = 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_PixelUnshuffle, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_RNN.cpp b/tools/pnnx/src/pass_ncnn/nn_RNN.cpp new file mode 100644 index 000000000000..c9deab9d96f7 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_RNN.cpp @@ -0,0 +1,181 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_RNN : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 3 +pnnx.Input input 0 1 input +nn.RNN op_0 1 2 input out out_hidden input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 2 0 out out_hidden +)PNNXIR"; + } + + const char* type_str() const + { + return "RNN"; + } + + const char* name_str() const + { + return "rnn"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::string nonlinearity = captured_params.at("nonlinearity").s; + + if (nonlinearity != "tanh") + { + fprintf(stderr, "RNN nonlinearity=%s not supported\n", nonlinearity.c_str()); + } + + const bool bidirectional = captured_params.at("bidirectional").b; + const int num_directions = bidirectional ? 2 : 1; + const int num_output = captured_params.at("hidden_size").i; + const int input_size = captured_params.at("input_size").i; + + int weight_data_size = num_directions * num_output * input_size; + + op->params["0"] = num_output; + op->params["1"] = weight_data_size; + op->params["2"] = bidirectional ? 2 : 0; + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = captured_attrs.at("op_0.weight_ih_l0"); + if (bidirectional) + op->attrs["2"] = captured_attrs.at("op_0.weight_ih_l0_reverse"); + + op->attrs["3"] = Attribute(); + op->attrs["3"].data = {0, 0, 0, 0}; + if (captured_params.at("bias").b) + { + // reduce bias_ih and bias_hh + std::vector new_bias; + { + const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0").data.data(); + const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0").data.data(); + + new_bias.resize(num_output); + float* bias = (float*)new_bias.data(); + for (int i = 0; i < num_output; i++) + { + bias[i] = bias_ih[i] + bias_hh[i]; + } + } + + op->attrs["4"] = Attribute({num_output}, new_bias); + + if (bidirectional) + { + std::vector new_bias_reverse; + { + const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0_reverse").data.data(); + const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0_reverse").data.data(); + + new_bias_reverse.resize(num_output); + float* bias = (float*)new_bias_reverse.data(); + for (int i = 0; i < num_output; i++) + { + bias[i] = bias_ih[i] + bias_hh[i]; + } + } + + op->attrs["5"] = Attribute({num_output}, new_bias_reverse); + } + } + else + { + std::vector bias(num_output, 0.f); + op->attrs["4"] = Attribute({num_output}, bias); + + if (bidirectional) + { + op->attrs["5"] = Attribute({num_output}, bias); + } + } + + op->attrs["6"] = Attribute(); + op->attrs["6"].data = {0, 0, 0, 0}; + op->attrs["7"] = captured_attrs.at("op_0.weight_hh_l0"); + if (bidirectional) + op->attrs["8"] = captured_attrs.at("op_0.weight_hh_l0_reverse"); + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN, 20) + +class nn_RNN_1 : public nn_RNN +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 4 +pnnx.Input input 0 1 input +pnnx.Input in_hidden 0 1 in_hidden +nn.RNN op_0 2 2 input in_hidden out out_hidden input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 2 0 out out_hidden +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN_1, 20) + +class nn_RNN_2 : public nn_RNN +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.RNN op_0 1 1 input out input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN_2, 20) + +class nn_RNN_3 : public nn_RNN +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Input in_hidden 0 1 in_hidden +nn.RNN op_0 2 1 input in_hidden out input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN_3, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ReLU.cpp b/tools/pnnx/src/pass_ncnn/nn_ReLU.cpp new file mode 100644 index 000000000000..dfe78c6c47cc --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ReLU.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ReLU : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ReLU op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ReLU"; + } + + const char* name_str() const + { + return "relu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ReLU, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ReLU6.cpp b/tools/pnnx/src/pass_ncnn/nn_ReLU6.cpp new file mode 100644 index 000000000000..bed321e35415 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ReLU6.cpp @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ReLU6 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ReLU6 op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Clip"; + } + + const char* name_str() const + { + return "relu6"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 0.f; + op->params["1"] = 6.f; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ReLU6, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ReflectionPad1d.cpp b/tools/pnnx/src/pass_ncnn/nn_ReflectionPad1d.cpp new file mode 100644 index 000000000000..9deb31d59e93 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ReflectionPad1d.cpp @@ -0,0 +1,59 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ReflectionPad1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ReflectionPad1d op_0 1 1 input out padding=%padding +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Padding"; + } + + const char* name_str() const + { + return "reflectpad1d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& padding = captured_params.at("padding").ai; + op->params["0"] = 0; + op->params["1"] = 0; + op->params["2"] = padding[0]; + op->params["3"] = padding[1]; + op->params["4"] = 2; // type + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ReflectionPad1d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ReflectionPad2d.cpp b/tools/pnnx/src/pass_ncnn/nn_ReflectionPad2d.cpp new file mode 100644 index 000000000000..c071303b07c3 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ReflectionPad2d.cpp @@ -0,0 +1,59 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ReflectionPad2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ReflectionPad2d op_0 1 1 input out padding=%padding +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Padding"; + } + + const char* name_str() const + { + return "reflectpad2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& padding = captured_params.at("padding").ai; + op->params["0"] = padding[2]; + op->params["1"] = padding[3]; + op->params["2"] = padding[0]; + op->params["3"] = padding[1]; + op->params["4"] = 2; // type + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ReflectionPad2d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ReplicationPad1d.cpp b/tools/pnnx/src/pass_ncnn/nn_ReplicationPad1d.cpp new file mode 100644 index 000000000000..75c01df105ae --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ReplicationPad1d.cpp @@ -0,0 +1,59 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ReplicationPad1d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ReplicationPad1d op_0 1 1 input out padding=%padding +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Padding"; + } + + const char* name_str() const + { + return "replicatepad1d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& padding = captured_params.at("padding").ai; + op->params["0"] = 0; + op->params["1"] = 0; + op->params["2"] = padding[0]; + op->params["3"] = padding[1]; + op->params["4"] = 1; // type + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ReplicationPad1d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ReplicationPad2d.cpp b/tools/pnnx/src/pass_ncnn/nn_ReplicationPad2d.cpp new file mode 100644 index 000000000000..693337f699c2 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ReplicationPad2d.cpp @@ -0,0 +1,59 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ReplicationPad2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ReplicationPad2d op_0 1 1 input out padding=%padding +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Padding"; + } + + const char* name_str() const + { + return "replicatepad2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& padding = captured_params.at("padding").ai; + op->params["0"] = padding[2]; + op->params["1"] = padding[3]; + op->params["2"] = padding[0]; + op->params["3"] = padding[1]; + op->params["4"] = 1; // type + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ReplicationPad2d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_SELU.cpp b/tools/pnnx/src/pass_ncnn/nn_SELU.cpp new file mode 100644 index 000000000000..bb03092b8385 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_SELU.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_SELU : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.SELU op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "SELU"; + } + + const char* name_str() const + { + return "selu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_SELU, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_SiLU.cpp b/tools/pnnx/src/pass_ncnn/nn_SiLU.cpp new file mode 100644 index 000000000000..60da16c818ed --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_SiLU.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_SiLU : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.SiLU op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Swish"; + } + + const char* name_str() const + { + return "silu"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_SiLU, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Sigmoid.cpp b/tools/pnnx/src/pass_ncnn/nn_Sigmoid.cpp new file mode 100644 index 000000000000..62639369a73f --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Sigmoid.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Sigmoid : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Sigmoid op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Sigmoid"; + } + + const char* name_str() const + { + return "sigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Sigmoid, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Softmax.cpp b/tools/pnnx/src/pass_ncnn/nn_Softmax.cpp new file mode 100644 index 000000000000..64bf4e5554e1 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Softmax.cpp @@ -0,0 +1,56 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Softmax : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Softmax op_0 1 1 input out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Softmax"; + } + + const char* name_str() const + { + return "softmax"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int axis = captured_params.at("dim").i; + op->params["0"] = axis > 0 ? axis - 1 : axis; + op->params["1"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Softmax, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Tanh.cpp b/tools/pnnx/src/pass_ncnn/nn_Tanh.cpp new file mode 100644 index 000000000000..f2cc3db3e115 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Tanh.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Tanh : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Tanh op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "TanH"; + } + + const char* name_str() const + { + return "tanh"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Tanh, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_Upsample.cpp b/tools/pnnx/src/pass_ncnn/nn_Upsample.cpp new file mode 100644 index 000000000000..854a2c290a88 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_Upsample.cpp @@ -0,0 +1,269 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_Upsample : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Upsample op_0 1 1 input out mode=%mode scale_factor=%scale_factor size=%size align_corners=%align_corners +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& mode = captured_params.at("mode").s; + const std::vector& scale_factor = captured_params.at("scale_factor").af; + const std::vector& size = captured_params.at("size").ai; + + if (mode == "nearest") + op->params["0"] = 1; + if (mode == "bilinear" || mode == "linear") + op->params["0"] = 2; + if (mode == "bicubic") + op->params["0"] = 3; + + if (scale_factor.size() == 1) + { + op->params["1"] = 1.f; + op->params["2"] = scale_factor[0]; + } + else if (scale_factor.size() == 2) + { + op->params["1"] = scale_factor[0]; + op->params["2"] = scale_factor[1]; + } + else if (size.size() == 1) + { + op->params["3"] = 1; + op->params["4"] = size[0]; + } + else if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample scale_factor or size\n"); + } + + op->params["6"] = captured_params.at("align_corners").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Upsample, 20) + +class nn_Upsample_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Upsample op_0 1 1 input out mode=%mode size=%size align_corners=%align_corners +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& mode = captured_params.at("mode").s; + const std::vector& size = captured_params.at("size").ai; + + if (mode == "nearest") + op->params["0"] = 1; + if (mode == "bilinear" || mode == "linear") + op->params["0"] = 2; + if (mode == "bicubic") + op->params["0"] = 3; + + if (size.size() == 1) + { + op->params["3"] = 1; + op->params["4"] = size[0]; + } + else if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample size\n"); + } + + op->params["6"] = captured_params.at("align_corners").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Upsample_1, 20) + +class nn_Upsample_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Upsample op_0 1 1 input out mode=%mode scale_factor=%scale_factor size=%size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& mode = captured_params.at("mode").s; + const std::vector& scale_factor = captured_params.at("scale_factor").af; + const std::vector& size = captured_params.at("size").ai; + + if (mode == "nearest") + op->params["0"] = 1; + if (mode == "bilinear" || mode == "linear") + op->params["0"] = 2; + if (mode == "bicubic") + op->params["0"] = 3; + + if (scale_factor.size() == 1) + { + op->params["1"] = 1.f; + op->params["2"] = scale_factor[0]; + } + else if (scale_factor.size() == 2) + { + op->params["1"] = scale_factor[0]; + op->params["2"] = scale_factor[1]; + } + else if (size.size() == 1) + { + op->params["3"] = 1; + op->params["4"] = size[0]; + } + else if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample scale_factor or size\n"); + } + + op->params["6"] = 0; // align_corner + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Upsample_2, 20) + +class nn_Upsample_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.Upsample op_0 1 1 input out mode=%mode size=%size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsample"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& mode = captured_params.at("mode").s; + const std::vector& size = captured_params.at("size").ai; + + if (mode == "nearest") + op->params["0"] = 1; + if (mode == "bilinear" || mode == "linear") + op->params["0"] = 2; + if (mode == "bicubic") + op->params["0"] = 3; + + if (size.size() == 1) + { + op->params["3"] = 1; + op->params["4"] = size[0]; + } + else if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample size\n"); + } + + op->params["6"] = 0; // align_corner + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Upsample_3, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_UpsamplingBilinear2d.cpp b/tools/pnnx/src/pass_ncnn/nn_UpsamplingBilinear2d.cpp new file mode 100644 index 000000000000..8803cfa0f27b --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_UpsamplingBilinear2d.cpp @@ -0,0 +1,119 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_UpsamplingBilinear2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.UpsamplingBilinear2d op_0 1 1 input out scale_factor=%scale_factor size=%size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsamplebilinear2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& scale_factor = captured_params.at("scale_factor").af; + const std::vector& size = captured_params.at("size").ai; + + op->params["0"] = 2; + + if (scale_factor.size() == 2) + { + op->params["1"] = scale_factor[0]; + op->params["2"] = scale_factor[1]; + } + else if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample scale_factor or size\n"); + } + + op->params["6"] = 1; // align_corner + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_UpsamplingBilinear2d, 20) + +class nn_UpsamplingBilinear2d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.UpsamplingBilinear2d op_0 1 1 input out size=%size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsamplebilinear2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& size = captured_params.at("size").ai; + + op->params["0"] = 2; + + if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample size\n"); + } + + op->params["6"] = 1; // align_corner + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_UpsamplingBilinear2d_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_UpsamplingNearest2d.cpp b/tools/pnnx/src/pass_ncnn/nn_UpsamplingNearest2d.cpp new file mode 100644 index 000000000000..cc235a446a5c --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_UpsamplingNearest2d.cpp @@ -0,0 +1,115 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_UpsamplingNearest2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.UpsamplingNearest2d op_0 1 1 input out scale_factor=%scale_factor size=%size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsamplenearest2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& scale_factor = captured_params.at("scale_factor").af; + const std::vector& size = captured_params.at("size").ai; + + op->params["0"] = 1; + + if (scale_factor.size() == 2) + { + op->params["1"] = scale_factor[0]; + op->params["2"] = scale_factor[1]; + } + else if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample scale_factor or size\n"); + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_UpsamplingNearest2d, 20) + +class nn_UpsamplingNearest2d_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.UpsamplingNearest2d op_0 1 1 input out size=%size +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Interp"; + } + + const char* name_str() const + { + return "upsamplenearest2d"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& size = captured_params.at("size").ai; + + op->params["0"] = 1; + + if (size.size() == 2) + { + op->params["3"] = size[0]; + op->params["4"] = size[1]; + } + else + { + fprintf(stderr, "unsupported upsample size\n"); + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_UpsamplingNearest2d_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/nn_ZeroPad2d.cpp b/tools/pnnx/src/pass_ncnn/nn_ZeroPad2d.cpp new file mode 100644 index 000000000000..26fd52fcd191 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/nn_ZeroPad2d.cpp @@ -0,0 +1,59 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class nn_ZeroPad2d : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.ZeroPad2d op_0 1 1 input out padding=%padding +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Padding"; + } + + const char* name_str() const + { + return "zeropad"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = captured_params.at("padding").ai[2]; + op->params["1"] = captured_params.at("padding").ai[3]; + op->params["2"] = captured_params.at("padding").ai[0]; + op->params["3"] = captured_params.at("padding").ai[1]; + op->params["4"] = 0; + op->params["5"] = 0.f; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_ZeroPad2d, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp b/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp new file mode 100644 index 000000000000..601e84e98f4b --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp @@ -0,0 +1,155 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "solve_batch_index.h" + +#include + +namespace pnnx { + +namespace ncnn { + +void solve_batch_index_backward(Operand* operand); +void solve_batch_index_forward(Operand* operand) +{ + int batch_index = operand->params["__batch_index"].i; + + for (Operator* op : operand->consumers) + { + if (op->type == "torch.permute" || op->type == "Tensor.permute") + { + const std::vector& dims = op->params.at("dims").ai; + + int batch_index_permuted = -1; + for (int i = 0; i < (int)dims.size(); i++) + { + if (dims[i] == batch_index) + { + batch_index_permuted = i; + break; + } + } + + for (Operand* r : op->outputs) + { + if (r->params.find("__batch_index") != r->params.end()) + continue; + + r->params["__batch_index"] = batch_index_permuted; + + solve_batch_index_forward(r); + solve_batch_index_backward(r); + } + } + else if (op->type == "nn.RNN" || op->type == "nn.LSTM" || op->type == "nn.GRU") + { + { + Operand* r = op->outputs[0]; + if (r->params.find("__batch_index") != r->params.end()) + continue; + + r->params["__batch_index"] = batch_index; + + solve_batch_index_forward(r); + solve_batch_index_backward(r); + } + + for (size_t i = 1; i < op->outputs.size(); i++) + { + Operand* r = op->outputs[i]; + if (r->params.find("__batch_index") != r->params.end()) + continue; + + r->params["__batch_index"] = 1; + + solve_batch_index_forward(r); + solve_batch_index_backward(r); + } + } + else + { + for (Operand* r : op->outputs) + { + if (r->params.find("__batch_index") != r->params.end()) + continue; + + r->params["__batch_index"] = batch_index; + + solve_batch_index_forward(r); + solve_batch_index_backward(r); + } + } + } +} + +void solve_batch_index_backward(Operand* operand) +{ + int batch_index = operand->params["__batch_index"].i; + + Operator* op = operand->producer; + + if (op->type == "torch.permute") + { + const std::vector& dims = op->params.at("dims").ai; + + int batch_index_permuted = dims[batch_index]; + + for (Operand* r : op->inputs) + { + if (r->params.find("__batch_index") != r->params.end()) + continue; + + r->params["__batch_index"] = batch_index_permuted; + + solve_batch_index_backward(r); + solve_batch_index_forward(r); + } + } + else + { + for (Operand* r : op->inputs) + { + if (r->params.find("__batch_index") != r->params.end()) + continue; + + r->params["__batch_index"] = batch_index; + + solve_batch_index_backward(r); + solve_batch_index_forward(r); + } + } +} + +void solve_batch_index(Graph& graph) +{ + // assign input and ongoing + for (int i = 0; i < (int)graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "pnnx.Input") + continue; + + for (Operand* r : op->outputs) + { + r->params["__batch_index"] = 0; + + solve_batch_index_forward(r); + } + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/solve_batch_index.h b/tools/pnnx/src/pass_ncnn/solve_batch_index.h new file mode 100644 index 000000000000..645ca4757ecf --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/solve_batch_index.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void solve_batch_index(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_clamp.cpp b/tools/pnnx/src/pass_ncnn/torch_clamp.cpp new file mode 100644 index 000000000000..745e9dfc147a --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_clamp.cpp @@ -0,0 +1,49 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_clamp : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.clamp op_0 1 1 input out min=%0 max=%1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Clip"; + } + + const char* name_str() const + { + return "clamp"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_clamp, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_flatten.cpp b/tools/pnnx/src/pass_ncnn/torch_flatten.cpp new file mode 100644 index 000000000000..e943de38eb3e --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_flatten.cpp @@ -0,0 +1,89 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_flatten : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.flatten op_0 1 1 input out start_dim=1 end_dim=-1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Flatten"; + } + + const char* name_str() const + { + return "flatten"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_flatten, 20) + +class torch_flatten_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.flatten op_0 1 1 input out start_dim=2 end_dim=-1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reshape"; + } + + const char* name_str() const + { + return "flatten"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + int input_rank = op->inputs[0]->shape.size(); + + if (input_rank <= 2) + { + fprintf(stderr, "flatten 2 to -1 not possible for %d-rank tensor\n", input_rank); + return; + } + + op->params["0"] = -1; + op->params["1"] = op->inputs[0]->shape[1]; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_flatten_2, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_mean.cpp b/tools/pnnx/src/pass_ncnn/torch_mean.cpp new file mode 100644 index 000000000000..e0eecce29938 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_mean.cpp @@ -0,0 +1,104 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_mean : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.mean op_0 1 1 input out dim=(2,3) keepdim=False +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Pooling"; + } + + const char* name_str() const + { + return "gap"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 1; + op->params["4"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean, 20) + +class torch_mean_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.mean op_0 1 1 input out dim=%dim keepdim=%keepdim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reduction"; + } + + const char* name_str() const + { + return "mean"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& dims = captured_params.at("dim").ai; + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + // drop mean batch index + std::vector new_dims; + for (int i = 0; i < (int)dims.size(); i++) + { + if (dims[i] == batch_index) + continue; + + int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; + new_dims.push_back(new_dim); + } + + op->params["0"] = 3; + op->params["1"] = 0; + op->params["3"] = new_dims; + op->params["4"] = captured_params.at("keepdim").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_permute.cpp b/tools/pnnx/src/pass_ncnn/torch_permute.cpp new file mode 100644 index 000000000000..488dc4bbf9d4 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_permute.cpp @@ -0,0 +1,149 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_permute : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.permute op_0 1 1 input out dims=%dims +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Permute"; + } + + const char* name_str() const + { + return "permute"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 0; + + const std::vector& dims = captured_params.at("dims").ai; + + int input_rank = (int)op->inputs[0]->shape.size(); + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + if (input_rank > 5) + { + fprintf(stderr, "permute %d-rank tensor is not supported yet!\n", input_rank); + return; + } + + if (input_rank == 0) + { + // assume input is fine + input_rank = (int)dims.size(); + } + + if (input_rank != (int)dims.size()) + { + fprintf(stderr, "permute %d-rank tensor with %d-rank dims is not possible\n", input_rank, (int)dims.size()); + return; + } + + // drop permute batch index + std::vector new_dims; + for (int i = 0; i < (int)dims.size(); i++) + { + if (dims[i] == batch_index) + continue; + + int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; + new_dims.push_back(new_dim); + } + + if (input_rank == 2) + { + // noop + } + if (input_rank == 3) + { + if (new_dims == std::vector{0, 1}) + op->type = "Noop"; + else if (new_dims == std::vector{1, 0}) + op->params["0"] = 1; + } + if (input_rank == 4) + { + if (new_dims == std::vector{0, 1, 2}) + op->type = "Noop"; + else if (new_dims == std::vector{0, 2, 1}) + op->params["0"] = 1; + else if (new_dims == std::vector{1, 0, 2}) + op->params["0"] = 2; + else if (new_dims == std::vector{1, 2, 0}) + op->params["0"] = 3; + else if (new_dims == std::vector{2, 0, 1}) + op->params["0"] = 4; + else if (new_dims == std::vector{2, 1, 0}) + op->params["0"] = 5; + } + if (input_rank == 5) + { + if (new_dims == std::vector{0, 1, 2, 3}) + op->type = "Noop"; + else if (new_dims == std::vector{0, 2, 3, 1}) + op->params["0"] = 1; + else if (new_dims == std::vector{1, 0, 2, 3}) + op->params["0"] = 2; + else if (new_dims == std::vector{1, 2, 3, 0}) + op->params["0"] = 3; + else if (new_dims == std::vector{2, 3, 0, 1}) + op->params["0"] = 4; + else if (new_dims == std::vector{2, 3, 1, 0}) + op->params["0"] = 5; + else + fprintf(stderr, "unsupported permute dims!\n"); + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_permute, 20) + +class Tensor_permute : public torch_permute +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +Tensor.permute op_0 1 1 input out dims=%dims +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(Tensor_permute, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_squeeze.cpp b/tools/pnnx/src/pass_ncnn/torch_squeeze.cpp new file mode 100644 index 000000000000..1b475bbd7555 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_squeeze.cpp @@ -0,0 +1,108 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_squeeze : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.squeeze op_0 1 1 input out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Squeeze"; + } + + const char* name_str() const + { + return "squeeze"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + int dim = captured_params.at("dim").i; + if (dim == batch_index) + { + fprintf(stderr, "squeeze batch dim %d is not supported yet!\n", batch_index); + return; + } + + int input_rank = op->inputs[0]->shape.size(); + + if (input_rank > 4) + { + fprintf(stderr, "squeeze %d-rank tensor is not supported yet!\n", input_rank); + return; + } + + if (dim > batch_index) + dim -= 1; + + std::vector axes = {dim}; + op->params["3"] = axes; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_squeeze, 20) + +class torch_squeeze_0 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.squeeze op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Squeeze"; + } + + const char* name_str() const + { + return "squeeze"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["0"] = 1; + op->params["1"] = 1; + op->params["2"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_squeeze_0, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_transpose.cpp b/tools/pnnx/src/pass_ncnn/torch_transpose.cpp new file mode 100644 index 000000000000..897c414e8feb --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_transpose.cpp @@ -0,0 +1,116 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_transpose : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.transpose op_0 1 1 input out dim0=%dim0 dim1=%dim1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Permute"; + } + + const char* name_str() const + { + return "transpose"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 0; + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + int dim0 = captured_params.at("dim0").i; + int dim1 = captured_params.at("dim1").i; + if (dim0 == batch_index || dim1 == batch_index) + { + fprintf(stderr, "permute across batch dim is not supported yet!\n"); + return; + } + + int input_rank = op->inputs[0]->shape.size(); + + if (input_rank > 5) + { + fprintf(stderr, "permute %d-rank tensor is not supported yet!\n", input_rank); + return; + } + + if (dim0 > batch_index) + dim0 -= 1; + if (dim1 > batch_index) + dim1 -= 1; + + if (input_rank == 1) + { + fprintf(stderr, "permute across one-rank tensor is not supported yet!\n"); + // should never reach here + } + if (input_rank == 2) + { + // noop + } + if (input_rank == 3) + { + if (dim0 == 0 && dim1 == 1) op->params["0"] = 1; + if (dim0 == 1 && dim1 == 0) op->params["0"] = 1; + } + if (input_rank == 4) + { + if (dim0 == 0 && dim1 == 1) op->params["0"] = 2; + if (dim0 == 1 && dim1 == 0) op->params["0"] = 2; + if (dim0 == 0 && dim1 == 2) op->params["0"] = 5; + if (dim0 == 2 && dim1 == 0) op->params["0"] = 5; + if (dim0 == 1 && dim1 == 2) op->params["0"] = 1; + if (dim0 == 2 && dim1 == 1) op->params["0"] = 1; + } + if (input_rank == 5) + { + if (dim0 == 3 || dim1 == 3) + { + fprintf(stderr, "permute across 5-rank tensor is not supported yet!\n"); + return; + } + + if (dim0 == 0 && dim1 == 1) op->params["0"] = 2; + if (dim0 == 1 && dim1 == 0) op->params["0"] = 2; + if (dim0 == 0 && dim1 == 2) op->params["0"] = 5; + if (dim0 == 2 && dim1 == 0) op->params["0"] = 5; + if (dim0 == 1 && dim1 == 2) op->params["0"] = 1; + if (dim0 == 2 && dim1 == 1) op->params["0"] = 1; + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_transpose, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_unsqueeze.cpp b/tools/pnnx/src/pass_ncnn/torch_unsqueeze.cpp new file mode 100644 index 000000000000..3dc2084d8f15 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_unsqueeze.cpp @@ -0,0 +1,75 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_unsqueeze : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.unsqueeze op_0 1 1 input out dim=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "ExpandDims"; + } + + const char* name_str() const + { + return "unsqueeze"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + int dim = captured_params.at("dim").i; + if (dim == batch_index) + { + fprintf(stderr, "unsqueeze batch dim %d is not supported yet!\n", batch_index); + return; + } + + int input_rank = op->inputs[0]->shape.size(); + + if (input_rank > 3) + { + fprintf(stderr, "unsqueeze %d-rank tensor is not supported yet!\n", input_rank); + return; + } + + if (dim > batch_index) + dim -= 1; + + std::vector axes = {dim}; + op->params["3"] = axes; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_unsqueeze, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/storezip.cpp b/tools/pnnx/src/storezip.cpp new file mode 100644 index 000000000000..8722c591b362 --- /dev/null +++ b/tools/pnnx/src/storezip.cpp @@ -0,0 +1,395 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "storezip.h" + +#include +#include +#include +#include +#include + +namespace pnnx { + +// https://stackoverflow.com/questions/1537964/visual-c-equivalent-of-gccs-attribute-packed +#ifdef _MSC_VER +#define PACK(__Declaration__) __pragma(pack(push, 1)) __Declaration__ __pragma(pack(pop)) +#else +#define PACK(__Declaration__) __Declaration__ __attribute__((__packed__)) +#endif + +PACK(struct local_file_header { + uint16_t version; + uint16_t flag; + uint16_t compression; + uint16_t last_modify_time; + uint16_t last_modify_date; + uint32_t crc32; + uint32_t compressed_size; + uint32_t uncompressed_size; + uint16_t file_name_length; + uint16_t extra_field_length; +}); + +PACK(struct central_directory_file_header { + uint16_t version_made; + uint16_t version; + uint16_t flag; + uint16_t compression; + uint16_t last_modify_time; + uint16_t last_modify_date; + uint32_t crc32; + uint32_t compressed_size; + uint32_t uncompressed_size; + uint16_t file_name_length; + uint16_t extra_field_length; + uint16_t file_comment_length; + uint16_t start_disk; + uint16_t internal_file_attrs; + uint32_t external_file_attrs; + uint32_t lfh_offset; +}); + +PACK(struct end_of_central_directory_record { + uint16_t disk_number; + uint16_t start_disk; + uint16_t cd_records; + uint16_t total_cd_records; + uint32_t cd_size; + uint32_t cd_offset; + uint16_t comment_length; +}); + +static uint32_t CRC32_TABLE[256]; + +static void CRC32_TABLE_INIT() +{ + for (int i = 0; i < 256; i++) + { + uint32_t c = i; + for (int j = 0; j < 8; j++) + { + if (c & 1) + c = (c >> 1) ^ 0xedb88320; + else + c >>= 1; + } + CRC32_TABLE[i] = c; + } +} + +static uint32_t CRC32(uint32_t x, unsigned char ch) +{ + return (x >> 8) ^ CRC32_TABLE[(x ^ ch) & 0xff]; +} + +static uint32_t CRC32_buffer(const unsigned char* data, int len) +{ + uint32_t x = 0xffffffff; + + for (int i = 0; i < len; i++) + x = CRC32(x, data[i]); + + return x ^ 0xffffffff; +} + +StoreZipReader::StoreZipReader() +{ + fp = 0; +} + +StoreZipReader::~StoreZipReader() +{ + close(); +} + +int StoreZipReader::open(const std::string& path) +{ + close(); + + fp = fopen(path.c_str(), "rb"); + if (!fp) + { + fprintf(stderr, "open failed\n"); + return -1; + } + + while (!feof(fp)) + { + // peek signature + uint32_t signature; + int nread = fread((char*)&signature, sizeof(signature), 1, fp); + if (nread != 1) + break; + + if (signature == 0x04034b50) + { + local_file_header lfh; + fread((char*)&lfh, sizeof(lfh), 1, fp); + + if (lfh.flag & 0x08) + { + fprintf(stderr, "zip file contains data descriptor, this is not supported yet\n"); + return -1; + } + + if (lfh.compression != 0 || lfh.compressed_size != lfh.uncompressed_size) + { + fprintf(stderr, "not stored zip file %d %d\n", lfh.compressed_size, lfh.uncompressed_size); + return -1; + } + + // file name + std::string name; + name.resize(lfh.file_name_length); + fread((char*)name.data(), name.size(), 1, fp); + + // skip extra field + fseek(fp, lfh.extra_field_length, SEEK_CUR); + + StoreZipMeta fm; + fm.offset = ftell(fp); + fm.size = lfh.compressed_size; + + filemetas[name] = fm; + + // fprintf(stderr, "%s = %d %d\n", name.c_str(), fm.offset, fm.size); + + fseek(fp, lfh.compressed_size, SEEK_CUR); + } + else if (signature == 0x02014b50) + { + central_directory_file_header cdfh; + fread((char*)&cdfh, sizeof(cdfh), 1, fp); + + // skip file name + fseek(fp, cdfh.file_name_length, SEEK_CUR); + + // skip extra field + fseek(fp, cdfh.extra_field_length, SEEK_CUR); + + // skip file comment + fseek(fp, cdfh.file_comment_length, SEEK_CUR); + } + else if (signature == 0x06054b50) + { + end_of_central_directory_record eocdr; + fread((char*)&eocdr, sizeof(eocdr), 1, fp); + + // skip comment + fseek(fp, eocdr.comment_length, SEEK_CUR); + } + else + { + fprintf(stderr, "unsupported signature %x\n", signature); + return -1; + } + } + + return 0; +} + +size_t StoreZipReader::get_file_size(const std::string& name) +{ + if (filemetas.find(name) == filemetas.end()) + { + fprintf(stderr, "no such file %s\n", name.c_str()); + return 0; + } + + return filemetas[name].size; +} + +int StoreZipReader::read_file(const std::string& name, char* data) +{ + if (filemetas.find(name) == filemetas.end()) + { + fprintf(stderr, "no such file %s\n", name.c_str()); + return -1; + } + + size_t offset = filemetas[name].offset; + size_t size = filemetas[name].size; + + fseek(fp, offset, SEEK_SET); + fread(data, size, 1, fp); + + return 0; +} + +int StoreZipReader::close() +{ + if (!fp) + return 0; + + fclose(fp); + fp = 0; + + return 0; +} + +StoreZipWriter::StoreZipWriter() +{ + fp = 0; + + CRC32_TABLE_INIT(); +} + +StoreZipWriter::~StoreZipWriter() +{ + close(); +} + +int StoreZipWriter::open(const std::string& path) +{ + close(); + + fp = fopen(path.c_str(), "wb"); + if (!fp) + { + fprintf(stderr, "open failed\n"); + return -1; + } + + return 0; +} + +int StoreZipWriter::write_file(const std::string& name, const char* data, size_t size) +{ + int offset = ftell(fp); + + uint32_t signature = 0x04034b50; + fwrite((char*)&signature, sizeof(signature), 1, fp); + + uint32_t crc32 = CRC32_buffer((const unsigned char*)data, size); + + local_file_header lfh; + lfh.version = 0; + lfh.flag = 0; + lfh.compression = 0; + lfh.last_modify_time = 0; + lfh.last_modify_date = 0; + lfh.crc32 = crc32; + lfh.compressed_size = size; + lfh.uncompressed_size = size; + lfh.file_name_length = name.size(); + lfh.extra_field_length = 0; + + fwrite((char*)&lfh, sizeof(lfh), 1, fp); + + fwrite((char*)name.c_str(), name.size(), 1, fp); + + fwrite(data, size, 1, fp); + + StoreZipMeta szm; + szm.name = name; + szm.lfh_offset = offset; + szm.crc32 = crc32; + szm.size = size; + + filemetas.push_back(szm); + + return 0; +} + +int StoreZipWriter::close() +{ + if (!fp) + return 0; + + int offset = ftell(fp); + + for (const StoreZipMeta& szm : filemetas) + { + uint32_t signature = 0x02014b50; + fwrite((char*)&signature, sizeof(signature), 1, fp); + + central_directory_file_header cdfh; + cdfh.version_made = 0; + cdfh.version = 0; + cdfh.flag = 0; + cdfh.compression = 0; + cdfh.last_modify_time = 0; + cdfh.last_modify_date = 0; + cdfh.crc32 = szm.crc32; + cdfh.compressed_size = szm.size; + cdfh.uncompressed_size = szm.size; + cdfh.file_name_length = szm.name.size(); + cdfh.extra_field_length = 0; + cdfh.file_comment_length = 0; + cdfh.start_disk = 0; + cdfh.internal_file_attrs = 0; + cdfh.external_file_attrs = 0; + cdfh.lfh_offset = szm.lfh_offset; + + fwrite((char*)&cdfh, sizeof(cdfh), 1, fp); + + fwrite((char*)szm.name.c_str(), szm.name.size(), 1, fp); + } + + int offset2 = ftell(fp); + + { + uint32_t signature = 0x06054b50; + fwrite((char*)&signature, sizeof(signature), 1, fp); + + end_of_central_directory_record eocdr; + eocdr.disk_number = 0; + eocdr.start_disk = 0; + eocdr.cd_records = filemetas.size(); + eocdr.total_cd_records = filemetas.size(); + eocdr.cd_size = offset2 - offset; + eocdr.cd_offset = offset; + eocdr.comment_length = 0; + + fwrite((char*)&eocdr, sizeof(eocdr), 1, fp); + } + + fclose(fp); + fp = 0; + + return 0; +} + +} // namespace pnnx + +#if 0 +int main() +{ + StoreZipReader sz; + + sz.open("test.zip"); + + std::vector data1; + sz.read_file("pnnx2.py", data1); + + std::vector data2; + sz.read_file("pnnx2.param", data2); + + sz.close(); + + + StoreZipWriter szw; + + szw.open("szw.zip"); + + szw.write_file("a.py", data1); + szw.write_file("zzzz.param", data2); + + szw.close(); + + + return 0; +} +#endif diff --git a/tools/pnnx/src/storezip.h b/tools/pnnx/src/storezip.h new file mode 100644 index 000000000000..2f6e17b2a0a7 --- /dev/null +++ b/tools/pnnx/src/storezip.h @@ -0,0 +1,78 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_STOREZIP_H +#define PNNX_STOREZIP_H + +#include +#include +#include + +namespace pnnx { + +class StoreZipReader +{ +public: + StoreZipReader(); + ~StoreZipReader(); + + int open(const std::string& path); + + size_t get_file_size(const std::string& name); + + int read_file(const std::string& name, char* data); + + int close(); + +private: + FILE* fp; + + struct StoreZipMeta + { + size_t offset; + size_t size; + }; + + std::map filemetas; +}; + +class StoreZipWriter +{ +public: + StoreZipWriter(); + ~StoreZipWriter(); + + int open(const std::string& path); + + int write_file(const std::string& name, const char* data, size_t size); + + int close(); + +private: + FILE* fp; + + struct StoreZipMeta + { + std::string name; + size_t lfh_offset; + uint32_t crc32; + uint32_t size; + }; + + std::vector filemetas; +}; + +} // namespace pnnx + +#endif // PNNX_STOREZIP_H diff --git a/tools/pnnx/src/utils.cpp b/tools/pnnx/src/utils.cpp new file mode 100644 index 000000000000..cbd07dcb4b6d --- /dev/null +++ b/tools/pnnx/src/utils.cpp @@ -0,0 +1,30 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "utils.h" + +namespace pnnx { + +const torch::jit::Node* find_node_by_kind(const std::shared_ptr& graph, const std::string& kind) +{ + for (const auto& n : graph->nodes()) + { + if (n->kind().toDisplayString() == kind) + return n; + } + + return 0; +} + +} // namespace pnnx diff --git a/tools/pnnx/src/utils.h b/tools/pnnx/src/utils.h new file mode 100644 index 000000000000..8eba09e7f282 --- /dev/null +++ b/tools/pnnx/src/utils.h @@ -0,0 +1,27 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_UTILS_H +#define PNNX_UTILS_H + +#include +#include + +namespace pnnx { + +const torch::jit::Node* find_node_by_kind(const std::shared_ptr& graph, const std::string& kind); + +} // namespace pnnx + +#endif // PNNX_UTILS_H diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt new file mode 100644 index 000000000000..237354d0f17b --- /dev/null +++ b/tools/pnnx/tests/CMakeLists.txt @@ -0,0 +1,183 @@ + +find_package(Python3 REQUIRED COMPONENTS Interpreter) + +macro(pnnx_add_test name) + add_test(NAME test_${name} COMMAND ${CMAKE_COMMAND} -DPYTHON_EXECUTABLE=${Python3_EXECUTABLE} -DPYTHON_SCRIPT=${CMAKE_CURRENT_SOURCE_DIR}/test_${name}.py -P ${CMAKE_CURRENT_SOURCE_DIR}/run_test.cmake) +endmacro() + +pnnx_add_test(F_adaptive_avg_pool1d) +pnnx_add_test(F_adaptive_avg_pool2d) +pnnx_add_test(F_adaptive_avg_pool3d) +pnnx_add_test(F_adaptive_max_pool1d) +pnnx_add_test(F_adaptive_max_pool2d) +pnnx_add_test(F_adaptive_max_pool3d) +pnnx_add_test(F_affine_grid) +pnnx_add_test(F_avg_pool1d) +pnnx_add_test(F_avg_pool2d) +pnnx_add_test(F_avg_pool3d) +pnnx_add_test(F_batch_norm) +pnnx_add_test(F_celu) +pnnx_add_test(F_conv1d) +pnnx_add_test(F_conv2d) +pnnx_add_test(F_conv3d) +pnnx_add_test(F_conv_transpose1d) +pnnx_add_test(F_conv_transpose2d) +pnnx_add_test(F_conv_transpose3d) +pnnx_add_test(F_elu) +pnnx_add_test(F_gelu) +pnnx_add_test(F_grid_sample) +pnnx_add_test(F_group_norm) +pnnx_add_test(F_hardshrink) +pnnx_add_test(F_hardsigmoid) +pnnx_add_test(F_hardswish) +pnnx_add_test(F_hardtanh) +pnnx_add_test(F_instance_norm) +pnnx_add_test(F_interpolate) +pnnx_add_test(F_layer_norm) +pnnx_add_test(F_leaky_relu) +pnnx_add_test(F_linear) +pnnx_add_test(F_local_response_norm) +pnnx_add_test(F_log_softmax) +pnnx_add_test(F_logsigmoid) +pnnx_add_test(F_lp_pool1d) +pnnx_add_test(F_lp_pool2d) +pnnx_add_test(F_max_pool1d) +pnnx_add_test(F_max_pool2d) +pnnx_add_test(F_max_pool3d) +pnnx_add_test(F_normalize) +pnnx_add_test(F_pad) +pnnx_add_test(F_pixel_shuffle) +pnnx_add_test(F_pixel_unshuffle) +pnnx_add_test(F_prelu) +pnnx_add_test(F_relu) +pnnx_add_test(F_relu6) +pnnx_add_test(F_rrelu) +pnnx_add_test(F_selu) +pnnx_add_test(F_sigmoid) +pnnx_add_test(F_silu) +pnnx_add_test(F_softmax) +pnnx_add_test(F_softmin) +pnnx_add_test(F_softplus) +pnnx_add_test(F_softshrink) +pnnx_add_test(F_softsign) +pnnx_add_test(F_tanh) +pnnx_add_test(F_tanhshrink) +pnnx_add_test(F_threshold) +pnnx_add_test(F_upsample_bilinear) +pnnx_add_test(F_upsample_nearest) +pnnx_add_test(F_upsample) + +pnnx_add_test(nn_AdaptiveAvgPool1d) +pnnx_add_test(nn_AdaptiveAvgPool2d) +pnnx_add_test(nn_AdaptiveAvgPool3d) +pnnx_add_test(nn_AdaptiveMaxPool1d) +pnnx_add_test(nn_AdaptiveMaxPool2d) +pnnx_add_test(nn_AdaptiveMaxPool3d) +pnnx_add_test(nn_AvgPool1d) +pnnx_add_test(nn_AvgPool2d) +pnnx_add_test(nn_AvgPool3d) +pnnx_add_test(nn_BatchNorm1d) +pnnx_add_test(nn_BatchNorm2d) +pnnx_add_test(nn_BatchNorm3d) +pnnx_add_test(nn_CELU) +pnnx_add_test(nn_ChannelShuffle) +pnnx_add_test(nn_ConstantPad1d) +pnnx_add_test(nn_ConstantPad2d) +pnnx_add_test(nn_ConstantPad3d) +pnnx_add_test(nn_Conv1d) +pnnx_add_test(nn_Conv2d) +pnnx_add_test(nn_Conv3d) +pnnx_add_test(nn_ConvTranspose1d) +pnnx_add_test(nn_ConvTranspose2d) +pnnx_add_test(nn_ConvTranspose3d) +pnnx_add_test(nn_ELU) +pnnx_add_test(nn_GELU) +pnnx_add_test(nn_GroupNorm) +pnnx_add_test(nn_GRU) +pnnx_add_test(nn_Hardshrink) +pnnx_add_test(nn_Hardsigmoid) +pnnx_add_test(nn_Hardswish) +pnnx_add_test(nn_Hardtanh) +pnnx_add_test(nn_InstanceNorm1d) +pnnx_add_test(nn_InstanceNorm2d) +pnnx_add_test(nn_InstanceNorm3d) +pnnx_add_test(nn_LayerNorm) +pnnx_add_test(nn_LeakyReLU) +pnnx_add_test(nn_Linear) +pnnx_add_test(nn_LocalResponseNorm) +pnnx_add_test(nn_LogSigmoid) +pnnx_add_test(nn_LogSoftmax) +pnnx_add_test(nn_LPPool1d) +pnnx_add_test(nn_LPPool2d) +pnnx_add_test(nn_LSTM) +pnnx_add_test(nn_MaxPool1d) +pnnx_add_test(nn_MaxPool2d) +pnnx_add_test(nn_MaxPool3d) +pnnx_add_test(nn_MultiheadAttention) +pnnx_add_test(nn_PixelShuffle) +pnnx_add_test(nn_PixelUnshuffle) +pnnx_add_test(nn_PReLU) +pnnx_add_test(nn_ReflectionPad1d) +pnnx_add_test(nn_ReflectionPad2d) +pnnx_add_test(nn_ReLU) +pnnx_add_test(nn_ReLU6) +pnnx_add_test(nn_ReplicationPad1d) +pnnx_add_test(nn_ReplicationPad2d) +pnnx_add_test(nn_ReplicationPad3d) +pnnx_add_test(nn_RNN) +pnnx_add_test(nn_RReLU) +pnnx_add_test(nn_SELU) +pnnx_add_test(nn_Sigmoid) +pnnx_add_test(nn_SiLU) +pnnx_add_test(nn_Softmax) +pnnx_add_test(nn_Softmin) +pnnx_add_test(nn_Softplus) +pnnx_add_test(nn_Softshrink) +pnnx_add_test(nn_Softsign) +pnnx_add_test(nn_Tanh) +pnnx_add_test(nn_Tanhshrink) +pnnx_add_test(nn_Threshold) +pnnx_add_test(nn_Upsample) +pnnx_add_test(nn_UpsamplingBilinear2d) +pnnx_add_test(nn_UpsamplingNearest2d) +pnnx_add_test(nn_ZeroPad2d) + +pnnx_add_test(Tensor_contiguous) +pnnx_add_test(Tensor_new_empty) +pnnx_add_test(Tensor_repeat) +pnnx_add_test(Tensor_reshape) +pnnx_add_test(Tensor_select) +pnnx_add_test(Tensor_slice) +pnnx_add_test(Tensor_view) + +pnnx_add_test(torch_cat) +pnnx_add_test(torch_chunk) +pnnx_add_test(torch_clamp) +pnnx_add_test(torch_flatten) +pnnx_add_test(torch_mean) +pnnx_add_test(torch_permute) +pnnx_add_test(torch_sum) +pnnx_add_test(torch_split) +pnnx_add_test(torch_squeeze) +pnnx_add_test(torch_transpose) +pnnx_add_test(torch_unsqueeze) + +pnnx_add_test(mobilenet_v2) +pnnx_add_test(mobilenet_v3_small) +pnnx_add_test(resnet18) +pnnx_add_test(shufflenet_v2_x1_0) +pnnx_add_test(squeezenet1_1) + +# TODO enable end2end quantization model test +#pnnx_add_test(quantization_shufflenet_v2_x1_0) + +pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d) +pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d) +pnnx_add_test(pnnx_fuse_linear_batchnorm1d) + +if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") + pnnx_add_test(F_mish) + pnnx_add_test(nn_Mish) +endif() + +add_subdirectory(ncnn) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt new file mode 100644 index 000000000000..85a45dcfc04f --- /dev/null +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -0,0 +1,93 @@ + +find_package(Python3 REQUIRED COMPONENTS Interpreter) + +macro(pnnx_ncnn_add_test name) + add_test(NAME test_ncnn_${name} COMMAND ${CMAKE_COMMAND} -DPYTHON_EXECUTABLE=${Python3_EXECUTABLE} -DPYTHON_SCRIPT=${CMAKE_CURRENT_SOURCE_DIR}/test_${name}.py -P ${CMAKE_CURRENT_SOURCE_DIR}/../run_test.cmake) +endmacro() + +pnnx_ncnn_add_test(F_elu) +pnnx_ncnn_add_test(F_gelu) +pnnx_ncnn_add_test(F_hardsigmoid) +pnnx_ncnn_add_test(F_hardswish) +pnnx_ncnn_add_test(F_hardtanh) +pnnx_ncnn_add_test(F_interpolate) +pnnx_ncnn_add_test(F_leaky_relu) +pnnx_ncnn_add_test(F_local_response_norm) +pnnx_ncnn_add_test(F_normalize) +pnnx_ncnn_add_test(F_pad) +pnnx_ncnn_add_test(F_pixel_shuffle) +pnnx_ncnn_add_test(F_pixel_unshuffle) +pnnx_ncnn_add_test(F_relu) +pnnx_ncnn_add_test(F_relu6) +pnnx_ncnn_add_test(F_sigmoid) +pnnx_ncnn_add_test(F_silu) +pnnx_ncnn_add_test(F_softmax) +pnnx_ncnn_add_test(F_tanh) +pnnx_ncnn_add_test(F_upsample_bilinear) +pnnx_ncnn_add_test(F_upsample_nearest) +pnnx_ncnn_add_test(F_upsample) + +pnnx_ncnn_add_test(nn_AdaptiveAvgPool2d) +pnnx_ncnn_add_test(nn_AvgPool2d) +pnnx_ncnn_add_test(nn_BatchNorm1d) +pnnx_ncnn_add_test(nn_BatchNorm2d) +pnnx_ncnn_add_test(nn_ChannelShuffle) +pnnx_ncnn_add_test(nn_ConstantPad1d) +pnnx_ncnn_add_test(nn_ConstantPad2d) +pnnx_ncnn_add_test(nn_Conv1d) +pnnx_ncnn_add_test(nn_Conv2d) +pnnx_ncnn_add_test(nn_ConvTranspose2d) +pnnx_ncnn_add_test(nn_ELU) +pnnx_ncnn_add_test(nn_GELU) +pnnx_ncnn_add_test(nn_GroupNorm) +pnnx_ncnn_add_test(nn_GRU) +pnnx_ncnn_add_test(nn_Hardsigmoid) +pnnx_ncnn_add_test(nn_Hardswish) +pnnx_ncnn_add_test(nn_Hardtanh) +pnnx_ncnn_add_test(nn_InstanceNorm2d) +pnnx_ncnn_add_test(nn_LayerNorm) +pnnx_ncnn_add_test(nn_LeakyReLU) +pnnx_ncnn_add_test(nn_Linear) +pnnx_ncnn_add_test(nn_LocalResponseNorm) +pnnx_ncnn_add_test(nn_LSTM) +pnnx_ncnn_add_test(nn_MaxPool2d) +pnnx_ncnn_add_test(nn_MultiheadAttention) +pnnx_ncnn_add_test(nn_PixelShuffle) +pnnx_ncnn_add_test(nn_PixelUnshuffle) +pnnx_ncnn_add_test(nn_PReLU) +pnnx_ncnn_add_test(nn_ReflectionPad1d) +pnnx_ncnn_add_test(nn_ReflectionPad2d) +pnnx_ncnn_add_test(nn_ReLU) +pnnx_ncnn_add_test(nn_ReLU6) +pnnx_ncnn_add_test(nn_ReplicationPad1d) +pnnx_ncnn_add_test(nn_ReplicationPad2d) +pnnx_ncnn_add_test(nn_RNN) +pnnx_ncnn_add_test(nn_SELU) +pnnx_ncnn_add_test(nn_Sigmoid) +pnnx_ncnn_add_test(nn_SiLU) +pnnx_ncnn_add_test(nn_Softmax) +pnnx_ncnn_add_test(nn_Tanh) +pnnx_ncnn_add_test(nn_Upsample) +pnnx_ncnn_add_test(nn_UpsamplingBilinear2d) +pnnx_ncnn_add_test(nn_UpsamplingNearest2d) +pnnx_ncnn_add_test(nn_ZeroPad2d) + +pnnx_ncnn_add_test(Tensor_reshape) +pnnx_ncnn_add_test(Tensor_slice) +pnnx_ncnn_add_test(Tensor_view) + +pnnx_ncnn_add_test(torch_permute) +pnnx_ncnn_add_test(torch_squeeze) +pnnx_ncnn_add_test(torch_transpose) +pnnx_ncnn_add_test(torch_unsqueeze) + +pnnx_ncnn_add_test(mobilenet_v2) +pnnx_ncnn_add_test(mobilenet_v3_small) +pnnx_ncnn_add_test(resnet18) +pnnx_ncnn_add_test(shufflenet_v2_x1_0) +pnnx_ncnn_add_test(squeezenet1_1) + +if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") + pnnx_ncnn_add_test(F_mish) + pnnx_ncnn_add_test(nn_Mish) +endif() diff --git a/tools/pnnx/tests/ncnn/test_F_elu.py b/tools/pnnx/tests/ncnn/test_F_elu.py new file mode 100644 index 000000000000..ea32eff96e77 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_elu.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.elu(x) + y = F.elu(y, 1.2) + z = F.elu(z, -0.6) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_elu.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_elu.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_elu_ncnn + b = test_F_elu_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_gelu.py b/tools/pnnx/tests/ncnn/test_F_gelu.py new file mode 100644 index 000000000000..d063bbafcfd1 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_gelu.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.gelu(x) + y = F.gelu(y) + z = F.gelu(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_gelu.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_gelu.pt inputshape=[1,16],[12,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_gelu_ncnn + b = test_F_gelu_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_hardsigmoid.py b/tools/pnnx/tests/ncnn/test_F_hardsigmoid.py new file mode 100644 index 000000000000..91b88716e0f7 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_hardsigmoid.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def hardsigmoid_forward_0(x): + return F.relu6(x + 3., True) / 6. + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.hardsigmoid(x) + y = F.hardsigmoid(y) + z = hardsigmoid_forward_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_hardsigmoid.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_hardsigmoid.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_hardsigmoid_ncnn + b = test_F_hardsigmoid_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_hardswish.py b/tools/pnnx/tests/ncnn/test_F_hardswish.py new file mode 100644 index 000000000000..27c2ad49e53a --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_hardswish.py @@ -0,0 +1,72 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def hardswish_forward_0(x): + return x * F.hardsigmoid(x) + +def hardswish_forward_1(x): + return x * F.hardtanh(x + 3, 0., 6.) / 6. + +def hardswish_forward_2(x): + out = F.relu6(x + 3., True) / 6. + return out * x + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.hardswish(x) + y = hardswish_forward_0(y) + z = hardswish_forward_1(z) + z = hardswish_forward_2(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_hardswish.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_hardswish.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_hardswish_ncnn + b = test_F_hardswish_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_hardtanh.py b/tools/pnnx/tests/ncnn/test_F_hardtanh.py new file mode 100644 index 000000000000..ee67ede532c1 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_hardtanh.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.hardtanh(x) + y = F.hardtanh(y, -1, 1) + z = F.hardtanh(z, -0.1, 0.1) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_hardtanh.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_hardtanh.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_hardtanh_ncnn + b = test_F_hardtanh_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_interpolate.py b/tools/pnnx/tests/ncnn/test_F_interpolate.py new file mode 100644 index 000000000000..3f0a429743e0 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_interpolate.py @@ -0,0 +1,95 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = F.interpolate(x, size=16) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = F.interpolate(x, size=(20), mode='nearest') + x = F.interpolate(x, scale_factor=(4), mode='nearest') + x = F.interpolate(x, size=16, mode='linear') + x = F.interpolate(x, scale_factor=2, mode='linear') + x = F.interpolate(x, size=(24), mode='linear', align_corners=True) + x = F.interpolate(x, scale_factor=(3), mode='linear', align_corners=True) + + x = F.interpolate(x, scale_factor=1.5, mode='nearest', recompute_scale_factor=True) + x = F.interpolate(x, scale_factor=1.2, mode='linear', align_corners=False, recompute_scale_factor=True) + x = F.interpolate(x, scale_factor=0.8, mode='linear', align_corners=True, recompute_scale_factor=True) + + y = F.interpolate(y, size=16) + y = F.interpolate(y, scale_factor=2, mode='nearest') + y = F.interpolate(y, size=(20,20), mode='nearest') + y = F.interpolate(y, scale_factor=(4,4), mode='nearest') + y = F.interpolate(y, size=(16,24), mode='nearest') + y = F.interpolate(y, scale_factor=(2,3), mode='nearest') + y = F.interpolate(y, size=16, mode='bilinear') + y = F.interpolate(y, scale_factor=2, mode='bilinear') + y = F.interpolate(y, size=(20,20), mode='bilinear', align_corners=False) + y = F.interpolate(y, scale_factor=(4,4), mode='bilinear', align_corners=False) + y = F.interpolate(y, size=(16,24), mode='bilinear', align_corners=True) + y = F.interpolate(y, scale_factor=(2,3), mode='bilinear', align_corners=True) + y = F.interpolate(y, size=16, mode='bicubic') + y = F.interpolate(y, scale_factor=2, mode='bicubic') + y = F.interpolate(y, size=(20,20), mode='bicubic', align_corners=False) + y = F.interpolate(y, scale_factor=(4,4), mode='bicubic', align_corners=False) + y = F.interpolate(y, size=(16,24), mode='bicubic', align_corners=True) + y = F.interpolate(y, scale_factor=(2,3), mode='bicubic', align_corners=True) + + y = F.interpolate(y, scale_factor=(1.7,2), mode='nearest', recompute_scale_factor=True) + y = F.interpolate(y, scale_factor=(2,1.2), mode='bilinear', align_corners=False, recompute_scale_factor=True) + y = F.interpolate(y, scale_factor=(0.5,0.4), mode='bilinear', align_corners=True, recompute_scale_factor=True) + y = F.interpolate(y, scale_factor=(0.8,0.9), mode='bicubic', align_corners=False, recompute_scale_factor=True) + y = F.interpolate(y, scale_factor=(1.1,0.5), mode='bicubic', align_corners=True, recompute_scale_factor=True) + + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 32) + y = torch.rand(1, 3, 32, 32) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_F_interpolate.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_interpolate.pt inputshape=[1,3,32],[1,3,32,32]") + + # ncnn inference + import test_F_interpolate_ncnn + b = test_F_interpolate_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_leaky_relu.py b/tools/pnnx/tests/ncnn/test_F_leaky_relu.py new file mode 100644 index 000000000000..1ab092f900c7 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_leaky_relu.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.leaky_relu(x) + y = F.leaky_relu(y, 0.1) + z = F.leaky_relu(z, -0.22) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_leaky_relu.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_leaky_relu.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_leaky_relu_ncnn + b = test_F_leaky_relu_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_local_response_norm.py b/tools/pnnx/tests/ncnn/test_F_local_response_norm.py new file mode 100644 index 000000000000..9c92eb9c7882 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_local_response_norm.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.local_response_norm(x, 4) + x = F.local_response_norm(x, size=4, alpha=0.001, beta=0.2, k=1.9) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 12, 16) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_local_response_norm.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_local_response_norm.pt inputshape=[1,3,12,16]") + + # ncnn inference + import test_F_local_response_norm_ncnn + b = test_F_local_response_norm_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_mish.py b/tools/pnnx/tests/ncnn/test_F_mish.py new file mode 100644 index 000000000000..42b060a7fe38 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_mish.py @@ -0,0 +1,67 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def mish_forward_0(x): + return x * F.softplus(x).tanh() + +def mish_forward_1(x): + return x.mul(torch.tanh(F.softplus(x))) + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.mish(x) + y = mish_forward_0(y) + z = mish_forward_1(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_mish.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_mish.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_mish_ncnn + b = test_F_mish_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_normalize.py b/tools/pnnx/tests/ncnn/test_F_normalize.py new file mode 100644 index 000000000000..c04117de0b64 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_normalize.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.normalize(x) + x = F.normalize(x, eps=1e-3) + + # TODO + #y = F.normalize(y, p=1, dim=1) + #y = F.normalize(y, dim=2) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_normalize.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_normalize.pt inputshape=[1,12,24,64]") + + # ncnn inference + import test_F_normalize_ncnn + b = test_F_normalize_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_pad.py b/tools/pnnx/tests/ncnn/test_F_pad.py new file mode 100644 index 000000000000..88590649883c --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_pad.py @@ -0,0 +1,71 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.pad(x, (3,4), mode='constant', value=1.3) + x = F.pad(x, (2,2)) + + y = F.pad(y, (5,6), mode='reflect') + y = F.pad(y, (2,1), mode='replicate') + y = F.pad(y, (3,4), mode='constant', value=1.3) + y = F.pad(y, (1,1)) + + z = F.pad(z, (3,4,3,4), mode='reflect') + z = F.pad(z, (2,1,2,0), mode='replicate') + z = F.pad(z, (1,0,2,0), mode='constant', value=1.3) + z = F.pad(z, (3,3,3,3)) + + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_pad.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_pad.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_pad_ncnn + b = test_F_pad_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_pixel_shuffle.py b/tools/pnnx/tests/ncnn/test_F_pixel_shuffle.py new file mode 100644 index 000000000000..6149b843c58d --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_pixel_shuffle.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.pixel_shuffle(x, 2) + x = F.pixel_shuffle(x, 4) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 128, 6, 7) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_pixel_shuffle.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_pixel_shuffle.pt inputshape=[1,128,6,7]") + + # ncnn inference + import test_F_pixel_shuffle_ncnn + b = test_F_pixel_shuffle_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_pixel_unshuffle.py b/tools/pnnx/tests/ncnn/test_F_pixel_unshuffle.py new file mode 100644 index 000000000000..2652ba998cc7 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_pixel_unshuffle.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.pixel_unshuffle(x, 4) + x = F.pixel_unshuffle(x, 2) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 128, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_pixel_unshuffle.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_pixel_unshuffle.pt inputshape=[1,3,128,128]") + + # ncnn inference + import test_F_pixel_unshuffle_ncnn + b = test_F_pixel_unshuffle_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_relu.py b/tools/pnnx/tests/ncnn/test_F_relu.py new file mode 100644 index 000000000000..f08dcf005444 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_relu.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.relu(x) + y = F.relu(y) + z = F.relu(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_relu.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_relu.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_relu_ncnn + b = test_F_relu_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_relu6.py b/tools/pnnx/tests/ncnn/test_F_relu6.py new file mode 100644 index 000000000000..3cc43350465d --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_relu6.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.relu6(x) + y = F.relu6(y) + z = F.relu6(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_relu6.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_relu6.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_relu6_ncnn + b = test_F_relu6_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_sigmoid.py b/tools/pnnx/tests/ncnn/test_F_sigmoid.py new file mode 100644 index 000000000000..fab556d11194 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_sigmoid.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.sigmoid(x) + y = F.sigmoid(y) + z = F.sigmoid(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_sigmoid.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_sigmoid.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_sigmoid_ncnn + b = test_F_sigmoid_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_silu.py b/tools/pnnx/tests/ncnn/test_F_silu.py new file mode 100644 index 000000000000..e02b66a6e425 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_silu.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def silu_forward_0(x): + return x * torch.sigmoid(x) + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.silu(x) + y = F.silu(y) + z = silu_forward_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_silu.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_silu.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_silu_ncnn + b = test_F_silu_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_softmax.py b/tools/pnnx/tests/ncnn/test_F_softmax.py new file mode 100644 index 000000000000..abf38dab0f0b --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_softmax.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.softmax(x, 1) + y = F.softmax(y, 1) + z = F.softmax(z, 2) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_softmax.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_softmax.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_softmax_ncnn + b = test_F_softmax_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_tanh.py b/tools/pnnx/tests/ncnn/test_F_tanh.py new file mode 100644 index 000000000000..e3fb8c434d3e --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_tanh.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.tanh(x) + y = F.tanh(y) + z = F.tanh(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 3, 12, 16) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_tanh.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_tanh.pt inputshape=[1,16],[1,2,16],[1,3,12,16]") + + # ncnn inference + import test_F_tanh_ncnn + b = test_F_tanh_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_upsample.py b/tools/pnnx/tests/ncnn/test_F_upsample.py new file mode 100644 index 000000000000..f453524bd67b --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_upsample.py @@ -0,0 +1,85 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = F.upsample(x, size=16) + x = F.upsample(x, scale_factor=2, mode='nearest') + x = F.upsample(x, size=(20), mode='nearest') + x = F.upsample(x, scale_factor=(4), mode='nearest') + x = F.upsample(x, size=16, mode='linear') + x = F.upsample(x, scale_factor=2, mode='linear') + x = F.upsample(x, size=(24), mode='linear', align_corners=True) + x = F.upsample(x, scale_factor=(3), mode='linear', align_corners=True) + + y = F.upsample(y, size=16) + y = F.upsample(y, scale_factor=2, mode='nearest') + y = F.upsample(y, size=(20,20), mode='nearest') + y = F.upsample(y, scale_factor=(4,4), mode='nearest') + y = F.upsample(y, size=(16,24), mode='nearest') + y = F.upsample(y, scale_factor=(2,3), mode='nearest') + y = F.upsample(y, size=16, mode='bilinear') + y = F.upsample(y, scale_factor=2, mode='bilinear') + y = F.upsample(y, size=(20,20), mode='bilinear', align_corners=False) + y = F.upsample(y, scale_factor=(4,4), mode='bilinear', align_corners=False) + y = F.upsample(y, size=(16,24), mode='bilinear', align_corners=True) + y = F.upsample(y, scale_factor=(2,3), mode='bilinear', align_corners=True) + y = F.upsample(y, size=16, mode='bicubic') + y = F.upsample(y, scale_factor=2, mode='bicubic') + y = F.upsample(y, size=(20,20), mode='bicubic', align_corners=False) + y = F.upsample(y, scale_factor=(4,4), mode='bicubic', align_corners=False) + y = F.upsample(y, size=(16,24), mode='bicubic', align_corners=True) + y = F.upsample(y, scale_factor=(2,3), mode='bicubic', align_corners=True) + + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 32) + y = torch.rand(1, 3, 32, 32) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_F_upsample.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_upsample.pt inputshape=[1,3,32],[1,3,32,32]") + + # ncnn inference + import test_F_upsample_ncnn + b = test_F_upsample_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_upsample_bilinear.py b/tools/pnnx/tests/ncnn/test_F_upsample_bilinear.py new file mode 100644 index 000000000000..178794ecaa78 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_upsample_bilinear.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.upsample_bilinear(x, size=(12,12)) + x = F.upsample_bilinear(x, scale_factor=2) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_upsample_bilinear.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_upsample_bilinear.pt inputshape=[1,12,24,64]") + + # ncnn inference + import test_F_upsample_bilinear_ncnn + b = test_F_upsample_bilinear_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_F_upsample_nearest.py b/tools/pnnx/tests/ncnn/test_F_upsample_nearest.py new file mode 100644 index 000000000000..edcdc8d7427c --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_upsample_nearest.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.upsample_nearest(x, size=(12,12)) + x = F.upsample_nearest(x, scale_factor=2) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_upsample_nearest.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_upsample_nearest.pt inputshape=[1,12,24,64]") + + # ncnn inference + import test_F_upsample_nearest_ncnn + b = test_F_upsample_nearest_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_Tensor_reshape.py b/tools/pnnx/tests/ncnn/test_Tensor_reshape.py new file mode 100644 index 000000000000..d83e09e27d7d --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_Tensor_reshape.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = x.reshape(1, 2, 24) + x = x.reshape(1, 48) + y = y.reshape(1, 11, 5, 9) + y = y.reshape(1, 99, 5) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_Tensor_reshape.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_Tensor_reshape.pt inputshape=[1,3,16],[1,5,9,11]") + + # ncnn inference + import test_Tensor_reshape_ncnn + b = test_Tensor_reshape_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_Tensor_slice.py b/tools/pnnx/tests/ncnn/test_Tensor_slice.py new file mode 100644 index 000000000000..f39323ac8f0a --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_Tensor_slice.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = x[:,:12,1:14:1] + x = x[...,1:] + x = x[:,:,:x.size(2)-1] + y = y[:,1:,5:,3:] + y = y[:,:,1:13:1,:14] + y = y[:,:y.size(1):,:,:] + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 13, 26) + y = torch.rand(1, 15, 19, 21) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_Tensor_slice.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_Tensor_slice.pt inputshape=[1,13,26],[1,15,19,21]") + + # ncnn inference + import test_Tensor_slice_ncnn + b = test_Tensor_slice_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_Tensor_view.py b/tools/pnnx/tests/ncnn/test_Tensor_view.py new file mode 100644 index 000000000000..839b1ce3baed --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_Tensor_view.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = x.view(1, 2, 24) + x = x.view(1, 48) + y = y.view(1, 11, 5, 9) + y = y.view(1, 99, 5) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_Tensor_view.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_Tensor_view.pt inputshape=[1,3,16],[1,5,9,11]") + + # ncnn inference + import test_Tensor_view_ncnn + b = test_Tensor_view_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_mobilenet_v2.py b/tools/pnnx/tests/ncnn/test_mobilenet_v2.py new file mode 100644 index 000000000000..29158c1be701 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_mobilenet_v2.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torchvision.models as models + +def test(): + net = models.mobilenet_v2() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 224, 224) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_mobilenet_v2.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_mobilenet_v2.pt inputshape=[1,3,224,224]") + + # ncnn inference + import test_mobilenet_v2_ncnn + b = test_mobilenet_v2_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_mobilenet_v3_small.py b/tools/pnnx/tests/ncnn/test_mobilenet_v3_small.py new file mode 100644 index 000000000000..ae07ce249f5f --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_mobilenet_v3_small.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torchvision.models as models + +def test(): + net = models.mobilenet_v3_small() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 224, 224) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_mobilenet_v3_small.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_mobilenet_v3_small.pt inputshape=[1,3,224,224]") + + # ncnn inference + import test_mobilenet_v3_small_ncnn + b = test_mobilenet_v3_small_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_AdaptiveAvgPool2d.py b/tools/pnnx/tests/ncnn/test_nn_AdaptiveAvgPool2d.py new file mode 100644 index 000000000000..1a434862812b --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_AdaptiveAvgPool2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.AdaptiveAvgPool2d(output_size=(7,6)) + self.pool_1 = nn.AdaptiveAvgPool2d(output_size=1) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 128, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_AdaptiveAvgPool2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_AdaptiveAvgPool2d.pt inputshape=[1,128,13,13]") + + # ncnn inference + import test_nn_AdaptiveAvgPool2d_ncnn + b = test_nn_AdaptiveAvgPool2d_ncnn.test_inference() + + b = b.reshape_as(a) + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_AvgPool2d.py b/tools/pnnx/tests/ncnn/test_nn_AvgPool2d.py new file mode 100644 index 000000000000..609afea93629 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_AvgPool2d.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.AvgPool2d(kernel_size=3) + self.pool_1 = nn.AvgPool2d(kernel_size=4, stride=2, padding=2) + self.pool_2 = nn.AvgPool2d(kernel_size=(1,3), stride=1, padding=(0,1), ceil_mode=False, count_include_pad=True) + self.pool_3 = nn.AvgPool2d(kernel_size=(4,5), stride=(1,2), padding=(1,2), ceil_mode=True, count_include_pad=False) + self.pool_4 = nn.AvgPool2d(kernel_size=(5,3), stride=(2,1), padding=1, ceil_mode=False, count_include_pad=True) + self.pool_5 = nn.AvgPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + x = self.pool_2(x) + x = self.pool_3(x) + x = self.pool_4(x) + x = self.pool_5(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 128, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_AvgPool2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_AvgPool2d.pt inputshape=[1,12,128,128]") + + # ncnn inference + import test_nn_AvgPool2d_ncnn + b = test_nn_AvgPool2d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_BatchNorm1d.py b/tools/pnnx/tests/ncnn/test_nn_BatchNorm1d.py new file mode 100644 index 000000000000..06bb0ead25a6 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_BatchNorm1d.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.bn_0 = nn.BatchNorm1d(num_features=32) + self.bn_1 = nn.BatchNorm1d(num_features=32, eps=1e-1, affine=False) + self.bn_2 = nn.BatchNorm1d(num_features=11, affine=True) + + def forward(self, x, y): + x = self.bn_0(x) + x = self.bn_1(x) + + y = self.bn_2(y) + + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 32, 64) + y = torch.rand(1, 11, 1) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_BatchNorm1d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_BatchNorm1d.pt inputshape=[1,32,64],[1,11,1]") + + # ncnn inference + import test_nn_BatchNorm1d_ncnn + b0, b1 = test_nn_BatchNorm1d_ncnn.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_BatchNorm2d.py b/tools/pnnx/tests/ncnn/test_nn_BatchNorm2d.py new file mode 100644 index 000000000000..10e3b1e299b2 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_BatchNorm2d.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.bn_0 = nn.BatchNorm2d(num_features=32) + self.bn_1 = nn.BatchNorm2d(num_features=32, eps=1e-1, affine=False) + self.bn_2 = nn.BatchNorm2d(num_features=11, affine=True) + + def forward(self, x, y): + x = self.bn_0(x) + x = self.bn_1(x) + + y = self.bn_2(y) + + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 32, 12, 64) + y = torch.rand(1, 11, 1, 1) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_BatchNorm2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_BatchNorm2d.pt inputshape=[1,32,12,64],[1,11,1,1]") + + # ncnn inference + import test_nn_BatchNorm2d_ncnn + b0, b1 = test_nn_BatchNorm2d_ncnn.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ChannelShuffle.py b/tools/pnnx/tests/ncnn/test_nn_ChannelShuffle.py new file mode 100644 index 000000000000..1c31b82a678a --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ChannelShuffle.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.shuffle_0 = nn.ChannelShuffle(2) + self.shuffle_1 = nn.ChannelShuffle(16) + + def forward(self, x, y): + x = self.shuffle_0(x) + x = self.shuffle_1(x) + + y = self.shuffle_0(y) + y = self.shuffle_1(y) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 64, 6, 8) + y = torch.rand(1, 96, 7, 9) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_ChannelShuffle.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ChannelShuffle.pt inputshape=[1,64,6,8],[1,96,7,9]") + + # ncnn inference + import test_nn_ChannelShuffle_ncnn + b0, b1 = test_nn_ChannelShuffle_ncnn.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ConstantPad1d.py b/tools/pnnx/tests/ncnn/test_nn_ConstantPad1d.py new file mode 100644 index 000000000000..4e30aff4482c --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ConstantPad1d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ConstantPad1d(2, 0.1) + self.pad_1 = nn.ConstantPad1d(padding=(3,4), value=-1) + self.pad_2 = nn.ConstantPad1d(padding=(1,0), value=123) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ConstantPad1d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ConstantPad1d.pt inputshape=[1,12,13]") + + # ncnn inference + import test_nn_ConstantPad1d_ncnn + b = test_nn_ConstantPad1d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ConstantPad2d.py b/tools/pnnx/tests/ncnn/test_nn_ConstantPad2d.py new file mode 100644 index 000000000000..e8bb36e20187 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ConstantPad2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ConstantPad2d(2, 0.1) + self.pad_1 = nn.ConstantPad2d(padding=(3,4,5,6), value=-2) + self.pad_2 = nn.ConstantPad2d(padding=(1,0,2,0), value=0) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ConstantPad2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ConstantPad2d.pt inputshape=[1,12,13,13]") + + # ncnn inference + import test_nn_ConstantPad2d_ncnn + b = test_nn_ConstantPad2d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Conv1d.py b/tools/pnnx/tests/ncnn/test_nn_Conv1d.py new file mode 100644 index 000000000000..b398a0268019 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Conv1d.py @@ -0,0 +1,73 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.conv_0 = nn.Conv1d(in_channels=12, out_channels=16, kernel_size=3) + self.conv_1 = nn.Conv1d(in_channels=16, out_channels=20, kernel_size=2, stride=2, padding=2, dilation=1) + self.conv_2 = nn.Conv1d(in_channels=20, out_channels=24, kernel_size=3, stride=1, padding=(4), dilation=1, groups=1, bias=False) + if torch.__version__ < '1.9': + self.conv_3 = nn.Conv1d(in_channels=24, out_channels=28, kernel_size=5, stride=1, padding=0, dilation=1, groups=4, bias=True) + self.conv_4 = nn.Conv1d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=2, groups=2, bias=False, padding_mode='zeros') + else: + self.conv_3 = nn.Conv1d(in_channels=24, out_channels=28, kernel_size=5, stride=1, padding='valid', dilation=1, groups=4, bias=True) + self.conv_4 = nn.Conv1d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=2, groups=2, bias=False, padding_mode='zeros') + #self.conv_5 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect') + #self.conv_6 = nn.Conv1d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate') + + def forward(self, x): + x = self.conv_0(x) + x = self.conv_1(x) + x = self.conv_2(x) + x = self.conv_3(x) + x = self.conv_4(x) + #x = self.conv_5(x) + #x = self.conv_6(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_Conv1d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Conv1d.pt inputshape=[1,12,64]") + + # ncnn inference + import test_nn_Conv1d_ncnn + b = test_nn_Conv1d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Conv2d.py b/tools/pnnx/tests/ncnn/test_nn_Conv2d.py new file mode 100644 index 000000000000..33e83729c4c4 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Conv2d.py @@ -0,0 +1,73 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.conv_0 = nn.Conv2d(in_channels=12, out_channels=16, kernel_size=3) + self.conv_1 = nn.Conv2d(in_channels=16, out_channels=20, kernel_size=(2,4), stride=(2,1), padding=2, dilation=1) + self.conv_2 = nn.Conv2d(in_channels=20, out_channels=24, kernel_size=(1,3), stride=1, padding=(2,4), dilation=1, groups=1, bias=False) + if torch.__version__ < '1.9': + self.conv_3 = nn.Conv2d(in_channels=24, out_channels=28, kernel_size=(5,4), stride=1, padding=0, dilation=1, groups=4, bias=True) + self.conv_4 = nn.Conv2d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=(1,2), groups=2, bias=False, padding_mode='zeros') + else: + self.conv_3 = nn.Conv2d(in_channels=24, out_channels=28, kernel_size=(5,4), stride=1, padding='valid', dilation=1, groups=4, bias=True) + self.conv_4 = nn.Conv2d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=(1,2), groups=2, bias=False, padding_mode='zeros') + #self.conv_5 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect') + #self.conv_6 = nn.Conv2d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate') + + def forward(self, x): + x = self.conv_0(x) + x = self.conv_1(x) + x = self.conv_2(x) + x = self.conv_3(x) + x = self.conv_4(x) + #x = self.conv_5(x) + #x = self.conv_6(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 64, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_Conv2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Conv2d.pt inputshape=[1,12,64,64]") + + # ncnn inference + import test_nn_Conv2d_ncnn + b = test_nn_Conv2d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ConvTranspose2d.py b/tools/pnnx/tests/ncnn/test_nn_ConvTranspose2d.py new file mode 100644 index 000000000000..f7ad8f83a726 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ConvTranspose2d.py @@ -0,0 +1,71 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.deconv_0 = nn.ConvTranspose2d(in_channels=12, out_channels=16, kernel_size=3) + self.deconv_1 = nn.ConvTranspose2d(in_channels=16, out_channels=20, kernel_size=(2,4), stride=(2,1), padding=2, output_padding=0) + self.deconv_2 = nn.ConvTranspose2d(in_channels=20, out_channels=24, kernel_size=(1,3), stride=1, padding=(2,4), output_padding=(0,0), dilation=1, groups=1, bias=False) + self.deconv_3 = nn.ConvTranspose2d(in_channels=24, out_channels=28, kernel_size=(5,4), stride=2, padding=0, output_padding=(0,1), dilation=1, groups=4, bias=True) + self.deconv_4 = nn.ConvTranspose2d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, output_padding=0, dilation=(1,2), groups=2, bias=False) + self.deconv_5 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, output_padding=1, dilation=1, groups=32, bias=True) + self.deconv_6 = nn.ConvTranspose2d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) + self.deconv_7 = nn.ConvTranspose2d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), output_padding=(1,0), dilation=2, groups=1, bias=True) + + def forward(self, x): + x = self.deconv_0(x) + x = self.deconv_1(x) + x = self.deconv_2(x) + x = self.deconv_3(x) + x = self.deconv_4(x) + x = self.deconv_5(x) + x = self.deconv_6(x) + x = self.deconv_7(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 10, 10) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ConvTranspose2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ConvTranspose2d.pt inputshape=[1,12,10,10]") + + # ncnn inference + import test_nn_ConvTranspose2d_ncnn + b = test_nn_ConvTranspose2d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ELU.py b/tools/pnnx/tests/ncnn/test_nn_ELU.py new file mode 100644 index 000000000000..cf1975f3cc3a --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ELU.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.ELU() + self.act_1 = nn.ELU(alpha=1.3) + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_ELU.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ELU.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_ELU_ncnn + b = test_nn_ELU_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_GELU.py b/tools/pnnx/tests/ncnn/test_nn_GELU.py new file mode 100644 index 000000000000..0165e747a0b5 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_GELU.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.GELU() + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_GELU.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_GELU.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_GELU_ncnn + b = test_nn_GELU_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_GRU.py b/tools/pnnx/tests/ncnn/test_nn_GRU.py new file mode 100644 index 000000000000..73c33d9a5a1e --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_GRU.py @@ -0,0 +1,87 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.gru_0_0 = nn.GRU(input_size=32, hidden_size=16) + self.gru_0_1 = nn.GRU(input_size=16, hidden_size=16, num_layers=3, bias=False) + self.gru_0_2 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + self.gru_0_3 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + self.gru_0_4 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + + self.gru_1_0 = nn.GRU(input_size=25, hidden_size=16, batch_first=True) + self.gru_1_1 = nn.GRU(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True) + self.gru_1_2 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + self.gru_1_3 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + self.gru_1_4 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + + def forward(self, x, y): + x = x.permute(1, 0, 2) + + x0, _ = self.gru_0_0(x) + x1, _ = self.gru_0_1(x0) + x2, h0 = self.gru_0_2(x1) + x3, h1 = self.gru_0_3(x1, h0) + x4, _ = self.gru_0_4(x1, h1) + + y0, _ = self.gru_1_0(y) + y1, _ = self.gru_1_1(y0) + y2, h2 = self.gru_1_2(y1) + y3, h3 = self.gru_1_3(y1, h2) + y4, _ = self.gru_1_4(y1, h3) + + x2 = x2.permute(1, 0, 2) + x3 = x3.permute(1, 0, 2) + x4 = x4.permute(1, 0, 2) + + return x2, x3, x4, y2, y3, y4 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 10, 32) + y = torch.rand(1, 12, 25) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_GRU.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_GRU.pt inputshape=[1,10,32],[1,12,25]") + + # ncnn inference + import test_nn_GRU_ncnn + b = test_nn_GRU_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_GroupNorm.py b/tools/pnnx/tests/ncnn/test_nn_GroupNorm.py new file mode 100644 index 000000000000..71f7d684fc5d --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_GroupNorm.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.gn_0 = nn.GroupNorm(num_groups=4, num_channels=12) + self.gn_1 = nn.GroupNorm(num_groups=12, num_channels=12, eps=1e-2, affine=True) + self.gn_2 = nn.GroupNorm(num_groups=1, num_channels=12, eps=1e-4, affine=True) + + def forward(self, x): + x = self.gn_0(x) + x = self.gn_1(x) + x = self.gn_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 64) + + a0 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_GroupNorm.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_GroupNorm.pt inputshape=[1,12,24,64]") + + # ncnn inference + import test_nn_GroupNorm_ncnn + b0 = test_nn_GroupNorm_ncnn.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Hardsigmoid.py b/tools/pnnx/tests/ncnn/test_nn_Hardsigmoid.py new file mode 100644 index 000000000000..07454383c033 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Hardsigmoid.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Hardsigmoid() + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_Hardsigmoid.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Hardsigmoid.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_Hardsigmoid_ncnn + b = test_nn_Hardsigmoid_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Hardswish.py b/tools/pnnx/tests/ncnn/test_nn_Hardswish.py new file mode 100644 index 000000000000..069e02d5d74a --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Hardswish.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Hardswish() + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_Hardswish.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Hardswish.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_Hardswish_ncnn + b = test_nn_Hardswish_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Hardtanh.py b/tools/pnnx/tests/ncnn/test_nn_Hardtanh.py new file mode 100644 index 000000000000..4393c5918cb0 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Hardtanh.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Hardtanh() + self.act_1 = nn.Hardtanh(-0.2, 0.2) + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_Hardtanh.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Hardtanh.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_Hardtanh_ncnn + b = test_nn_Hardtanh_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_InstanceNorm2d.py b/tools/pnnx/tests/ncnn/test_nn_InstanceNorm2d.py new file mode 100644 index 000000000000..386ec62cbe8a --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_InstanceNorm2d.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.in_0 = nn.InstanceNorm2d(num_features=12, affine=True) + self.in_1 = nn.InstanceNorm2d(num_features=12, eps=1e-2, affine=True) + + def forward(self, x): + x = self.in_0(x) + x = self.in_1(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_InstanceNorm2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_InstanceNorm2d.pt inputshape=[1,12,24,64]") + + # ncnn inference + import test_nn_InstanceNorm2d_ncnn + b = test_nn_InstanceNorm2d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_LSTM.py b/tools/pnnx/tests/ncnn/test_nn_LSTM.py new file mode 100644 index 000000000000..55384e8885db --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_LSTM.py @@ -0,0 +1,87 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.lstm_0_0 = nn.LSTM(input_size=32, hidden_size=16) + self.lstm_0_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False) + self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + self.lstm_0_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + self.lstm_0_4 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + + self.lstm_1_0 = nn.LSTM(input_size=25, hidden_size=16, batch_first=True) + self.lstm_1_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True) + self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + self.lstm_1_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + self.lstm_1_4 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + + def forward(self, x, y): + x = x.permute(1, 0, 2) + + x0, _ = self.lstm_0_0(x) + x1, _ = self.lstm_0_1(x0) + x2, (h0, c0) = self.lstm_0_2(x1) + x3, (h1, c1) = self.lstm_0_3(x1, (h0, c0)) + x4, _ = self.lstm_0_4(x1, (h1, c1)) + + y0, _ = self.lstm_1_0(y) + y1, _ = self.lstm_1_1(y0) + y2, (h2, c2) = self.lstm_1_2(y1) + y3, (h3, c3) = self.lstm_1_3(y1, (h2, c2)) + y4, _ = self.lstm_1_4(y1, (h3, c3)) + + x2 = x2.permute(1, 0, 2) + x3 = x3.permute(1, 0, 2) + x4 = x4.permute(1, 0, 2) + + return x2, x3, x4, y2, y3, y4 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 10, 32) + y = torch.rand(1, 12, 25) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_LSTM.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_LSTM.pt inputshape=[1,10,32],[1,12,25]") + + # ncnn inference + import test_nn_LSTM_ncnn + b = test_nn_LSTM_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py b/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py new file mode 100644 index 000000000000..d411759e917e --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.ln_0 = nn.LayerNorm(64) + self.ln_1 = nn.LayerNorm(normalized_shape=(24,64), eps=1e-2, elementwise_affine=False) + + def forward(self, x, y): + x = self.ln_0(x) + y = self.ln_1(y) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 24, 64) + y = torch.rand(1, 12, 24, 64) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_LayerNorm.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_LayerNorm.pt inputshape=[1,24,64],[1,12,24,64]") + + # ncnn inference + import test_nn_LayerNorm_ncnn + b0, b1 = test_nn_LayerNorm_ncnn.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_LeakyReLU.py b/tools/pnnx/tests/ncnn/test_nn_LeakyReLU.py new file mode 100644 index 000000000000..8089f689f4dc --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_LeakyReLU.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.LeakyReLU() + self.act_1 = nn.LeakyReLU(negative_slope=-0.24) + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_LeakyReLU.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_LeakyReLU.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_LeakyReLU_ncnn + b = test_nn_LeakyReLU_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Linear.py b/tools/pnnx/tests/ncnn/test_nn_Linear.py new file mode 100644 index 000000000000..4a68d662ca5a --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Linear.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.linear_0 = nn.Linear(in_features=64, out_features=16, bias=False) + self.linear_1 = nn.Linear(in_features=16, out_features=3, bias=True) + + def forward(self, x, y): + x = self.linear_0(x) + x = self.linear_1(x) + + y = self.linear_0(y) + y = self.linear_1(y) + + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 64) + y = torch.rand(1, 12, 64) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_Linear.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Linear.pt inputshape=[1,64],[1,12,64]") + + # ncnn inference + import test_nn_Linear_ncnn + b0, b1 = test_nn_Linear_ncnn.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_LocalResponseNorm.py b/tools/pnnx/tests/ncnn/test_nn_LocalResponseNorm.py new file mode 100644 index 000000000000..ac12886d6138 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_LocalResponseNorm.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.ln_0 = nn.LocalResponseNorm(3) + self.ln_1 = nn.LocalResponseNorm(size=5, alpha=0.001, beta=0.8, k=0.9) + + def forward(self, x): + x = self.ln_0(x) + x = self.ln_1(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 64) + + a0 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_LocalResponseNorm.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_LocalResponseNorm.pt inputshape=[1,12,24,64]") + + # ncnn inference + import test_nn_LocalResponseNorm_ncnn + b0 = test_nn_LocalResponseNorm_ncnn.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py b/tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py new file mode 100644 index 000000000000..d424a939ee7a --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_MaxPool2d.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.MaxPool2d(kernel_size=3) + self.pool_1 = nn.MaxPool2d(kernel_size=4, stride=2, padding=2, dilation=1) + self.pool_2 = nn.MaxPool2d(kernel_size=(1,3), stride=1, padding=(0,1), dilation=1, ceil_mode=False) + self.pool_3 = nn.MaxPool2d(kernel_size=(4,5), stride=(1,2), padding=(1,2), dilation=1, ceil_mode=True) + self.pool_4 = nn.MaxPool2d(kernel_size=(2,3), stride=1, padding=1, dilation=1, ceil_mode=False) + self.pool_5 = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=True) + self.pool_6 = nn.MaxPool2d(kernel_size=(5,4), stride=1, padding=2, dilation=1, ceil_mode=False) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + x = self.pool_2(x) + x = self.pool_3(x) + x = self.pool_4(x) + x = self.pool_5(x) + x = self.pool_6(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 64, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_MaxPool2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_MaxPool2d.pt inputshape=[1,12,64,64]") + + # ncnn inference + import test_nn_MaxPool2d_ncnn + b = test_nn_MaxPool2d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Mish.py b/tools/pnnx/tests/ncnn/test_nn_Mish.py new file mode 100644 index 000000000000..52b994242ae9 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Mish.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Mish() + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_Mish.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Mish.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_Mish_ncnn + b = test_nn_Mish_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_MultiheadAttention.py b/tools/pnnx/tests/ncnn/test_nn_MultiheadAttention.py new file mode 100644 index 000000000000..76058da2d286 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_MultiheadAttention.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.attention_0_0 = nn.MultiheadAttention(embed_dim=64, num_heads=4) + + if torch.__version__ >= '1.9': + self.attention_1_0 = nn.MultiheadAttention(embed_dim=40, num_heads=4, batch_first=True) + + def forward(self, x, y): + x0, _ = self.attention_0_0(x, x, x) + + if torch.__version__ < '1.9': + return x0 + + y0, _ = self.attention_1_0(y, y, y) + + return x0, y0 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 1, 64) + y = torch.rand(1, 15, 40) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_MultiheadAttention.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_MultiheadAttention.pt inputshape=[1,1,64],[1,15,40]") + + # ncnn inference + import test_nn_MultiheadAttention_ncnn + b = test_nn_MultiheadAttention_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_PReLU.py b/tools/pnnx/tests/ncnn/test_nn_PReLU.py new file mode 100644 index 000000000000..53c9534d5c86 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_PReLU.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.prelu_0 = nn.PReLU(num_parameters=12) + self.prelu_1 = nn.PReLU(num_parameters=1, init=0.12) + + def forward(self, x, y, z): + x = self.prelu_0(x) + x = self.prelu_1(x) + + y = self.prelu_0(y) + y = self.prelu_1(y) + + z = self.prelu_0(z) + z = self.prelu_1(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_PReLU.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_PReLU.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_PReLU_ncnn + b = test_nn_PReLU_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_PixelShuffle.py b/tools/pnnx/tests/ncnn/test_nn_PixelShuffle.py new file mode 100644 index 000000000000..446a7933d6bb --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_PixelShuffle.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.up_0 = nn.PixelShuffle(4) + self.up_1 = nn.PixelShuffle(2) + + def forward(self, x): + x = self.up_0(x) + x = self.up_1(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 128, 6, 8) + + a0 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_PixelShuffle.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_PixelShuffle.pt inputshape=[1,128,6,8]") + + # ncnn inference + import test_nn_PixelShuffle_ncnn + b0 = test_nn_PixelShuffle_ncnn.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_PixelUnshuffle.py b/tools/pnnx/tests/ncnn/test_nn_PixelUnshuffle.py new file mode 100644 index 000000000000..68db3508c557 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_PixelUnshuffle.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.down_0 = nn.PixelUnshuffle(2) + self.down_1 = nn.PixelUnshuffle(4) + + def forward(self, x): + x = self.down_0(x) + x = self.down_1(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 128, 128) + + a0 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_PixelUnshuffle.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_PixelUnshuffle.pt inputshape=[1,3,128,128]") + + # ncnn inference + import test_nn_PixelUnshuffle_ncnn + b0 = test_nn_PixelUnshuffle_ncnn.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_RNN.py b/tools/pnnx/tests/ncnn/test_nn_RNN.py new file mode 100644 index 000000000000..bc891b3485a8 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_RNN.py @@ -0,0 +1,87 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.rnn_0_0 = nn.RNN(input_size=32, hidden_size=16) + self.rnn_0_1 = nn.RNN(input_size=16, hidden_size=16, num_layers=3, nonlinearity='tanh', bias=False) + self.rnn_0_2 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='tanh', bias=True, bidirectional=True) + self.rnn_0_3 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='tanh', bias=True, bidirectional=True) + self.rnn_0_4 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='tanh', bias=True, bidirectional=True) + + self.rnn_1_0 = nn.RNN(input_size=25, hidden_size=16, batch_first=True) + self.rnn_1_1 = nn.RNN(input_size=16, hidden_size=16, num_layers=3, nonlinearity='tanh', bias=False, batch_first=True) + self.rnn_1_2 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='tanh', bias=True, batch_first=True, bidirectional=True) + self.rnn_1_3 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='tanh', bias=True, batch_first=True, bidirectional=True) + self.rnn_1_4 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='tanh', bias=True, batch_first=True, bidirectional=True) + + def forward(self, x, y): + x = x.permute(1, 0, 2) + + x0, _ = self.rnn_0_0(x) + x1, _ = self.rnn_0_1(x0) + x2, h0 = self.rnn_0_2(x1) + x3, h1 = self.rnn_0_3(x1, h0) + x4, _ = self.rnn_0_4(x1, h1) + + y0, _ = self.rnn_1_0(y) + y1, _ = self.rnn_1_1(y0) + y2, h2 = self.rnn_1_2(y1) + y3, h3 = self.rnn_1_3(y1, h2) + y4, _ = self.rnn_1_3(y1, h3) + + x2 = x2.permute(1, 0, 2) + x3 = x3.permute(1, 0, 2) + x4 = x4.permute(1, 0, 2) + + return x2, x3, x4, y2, y3, y4 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 10, 32) + y = torch.rand(1, 12, 25) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_RNN.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_RNN.pt inputshape=[1,10,32],[1,12,25]") + + # ncnn inference + import test_nn_RNN_ncnn + b = test_nn_RNN_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ReLU.py b/tools/pnnx/tests/ncnn/test_nn_ReLU.py new file mode 100644 index 000000000000..10a1b0d60f5d --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ReLU.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.ReLU() + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_ReLU.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ReLU.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_ReLU_ncnn + b = test_nn_ReLU_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ReLU6.py b/tools/pnnx/tests/ncnn/test_nn_ReLU6.py new file mode 100644 index 000000000000..52fd2dedafb9 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ReLU6.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.ReLU6() + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_ReLU6.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ReLU6.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_ReLU6_ncnn + b = test_nn_ReLU6_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ReflectionPad1d.py b/tools/pnnx/tests/ncnn/test_nn_ReflectionPad1d.py new file mode 100644 index 000000000000..39bcf8ce6324 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ReflectionPad1d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ReflectionPad1d(2) + self.pad_1 = nn.ReflectionPad1d(padding=(3,4)) + self.pad_2 = nn.ReflectionPad1d(padding=(1,0)) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ReflectionPad1d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ReflectionPad1d.pt inputshape=[1,12,13]") + + # ncnn inference + import test_nn_ReflectionPad1d_ncnn + b = test_nn_ReflectionPad1d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ReflectionPad2d.py b/tools/pnnx/tests/ncnn/test_nn_ReflectionPad2d.py new file mode 100644 index 000000000000..a2f9e1588084 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ReflectionPad2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ReflectionPad2d(2) + self.pad_1 = nn.ReflectionPad2d(padding=(3,4,5,6)) + self.pad_2 = nn.ReflectionPad2d(padding=(1,0,2,0)) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ReflectionPad2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ReflectionPad2d.pt inputshape=[1,12,13,13]") + + # ncnn inference + import test_nn_ReflectionPad2d_ncnn + b = test_nn_ReflectionPad2d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ReplicationPad1d.py b/tools/pnnx/tests/ncnn/test_nn_ReplicationPad1d.py new file mode 100644 index 000000000000..4ba17a9a8b2b --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ReplicationPad1d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ReplicationPad1d(2) + self.pad_1 = nn.ReplicationPad1d(padding=(3,4)) + self.pad_2 = nn.ReplicationPad1d(padding=(1,0)) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ReplicationPad1d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ReplicationPad1d.pt inputshape=[1,12,13]") + + # ncnn inference + import test_nn_ReplicationPad1d_ncnn + b = test_nn_ReplicationPad1d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ReplicationPad2d.py b/tools/pnnx/tests/ncnn/test_nn_ReplicationPad2d.py new file mode 100644 index 000000000000..5cf663fb2986 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ReplicationPad2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ReplicationPad2d(2) + self.pad_1 = nn.ReplicationPad2d(padding=(3,4,5,6)) + self.pad_2 = nn.ReplicationPad2d(padding=(1,0,2,0)) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ReplicationPad2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ReplicationPad2d.pt inputshape=[1,12,13,13]") + + # ncnn inference + import test_nn_ReplicationPad2d_ncnn + b = test_nn_ReplicationPad2d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_SELU.py b/tools/pnnx/tests/ncnn/test_nn_SELU.py new file mode 100644 index 000000000000..3a08ee31ee5b --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_SELU.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.SELU() + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_SELU.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_SELU.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_SELU_ncnn + b = test_nn_SELU_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_SiLU.py b/tools/pnnx/tests/ncnn/test_nn_SiLU.py new file mode 100644 index 000000000000..ff58b6523b90 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_SiLU.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.SiLU() + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_SiLU.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_SiLU.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_SiLU_ncnn + b = test_nn_SiLU_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Sigmoid.py b/tools/pnnx/tests/ncnn/test_nn_Sigmoid.py new file mode 100644 index 000000000000..833c214ca899 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Sigmoid.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Sigmoid() + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_Sigmoid.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Sigmoid.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_Sigmoid_ncnn + b = test_nn_Sigmoid_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Softmax.py b/tools/pnnx/tests/ncnn/test_nn_Softmax.py new file mode 100644 index 000000000000..84d5f74db854 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Softmax.py @@ -0,0 +1,65 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Softmax(dim=1) + self.act_1 = nn.Softmax(dim=1) + self.act_2 = nn.Softmax(dim=2) + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_1(y) + z = self.act_2(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_Softmax.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Softmax.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_Softmax_ncnn + b = test_nn_Softmax_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Tanh.py b/tools/pnnx/tests/ncnn/test_nn_Tanh.py new file mode 100644 index 000000000000..6887ed36d7c3 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Tanh.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Tanh() + + def forward(self, x, y, z): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_Tanh.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Tanh.pt inputshape=[1,12],[1,12,64],[1,12,24,64]") + + # ncnn inference + import test_nn_Tanh_ncnn + b = test_nn_Tanh_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Upsample.py b/tools/pnnx/tests/ncnn/test_nn_Upsample.py new file mode 100644 index 000000000000..d5ecc13155b5 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Upsample.py @@ -0,0 +1,113 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.up_1d_0_0 = nn.Upsample(size=16) + self.up_1d_0_1 = nn.Upsample(scale_factor=2, mode='nearest') + self.up_1d_0_2 = nn.Upsample(size=(20), mode='nearest') + self.up_1d_0_3 = nn.Upsample(scale_factor=(4), mode='nearest') + self.up_1d_1_0 = nn.Upsample(size=16, mode='linear') + self.up_1d_1_1 = nn.Upsample(scale_factor=2, mode='linear') + self.up_1d_1_2 = nn.Upsample(size=(24), mode='linear', align_corners=True) + self.up_1d_1_3 = nn.Upsample(scale_factor=(3), mode='linear', align_corners=True) + + self.up_2d_0_0 = nn.Upsample(size=16) + self.up_2d_0_1 = nn.Upsample(scale_factor=2, mode='nearest') + self.up_2d_0_2 = nn.Upsample(size=(20,20), mode='nearest') + self.up_2d_0_3 = nn.Upsample(scale_factor=(4,4), mode='nearest') + self.up_2d_0_4 = nn.Upsample(size=(16,24), mode='nearest') + self.up_2d_0_5 = nn.Upsample(scale_factor=(2,3), mode='nearest') + self.up_2d_1_0 = nn.Upsample(size=16, mode='bilinear') + self.up_2d_1_1 = nn.Upsample(scale_factor=2, mode='bilinear') + self.up_2d_1_2 = nn.Upsample(size=(20,20), mode='bilinear', align_corners=False) + self.up_2d_1_3 = nn.Upsample(scale_factor=(4,4), mode='bilinear', align_corners=False) + self.up_2d_1_4 = nn.Upsample(size=(16,24), mode='bilinear', align_corners=True) + self.up_2d_1_5 = nn.Upsample(scale_factor=(2,3), mode='bilinear', align_corners=True) + self.up_2d_2_0 = nn.Upsample(size=16, mode='bicubic') + self.up_2d_2_1 = nn.Upsample(scale_factor=2, mode='bicubic') + self.up_2d_2_2 = nn.Upsample(size=(20,20), mode='bicubic', align_corners=False) + self.up_2d_2_3 = nn.Upsample(scale_factor=(4,4), mode='bicubic', align_corners=False) + self.up_2d_2_4 = nn.Upsample(size=(16,24), mode='bicubic', align_corners=True) + self.up_2d_2_5 = nn.Upsample(scale_factor=(2,3), mode='bicubic', align_corners=True) + + def forward(self, x, y): + x = self.up_1d_0_0(x) + x = self.up_1d_0_1(x) + x = self.up_1d_0_2(x) + x = self.up_1d_0_3(x) + x = self.up_1d_1_0(x) + x = self.up_1d_1_1(x) + x = self.up_1d_1_2(x) + x = self.up_1d_1_3(x) + + y = self.up_2d_0_0(y) + y = self.up_2d_0_1(y) + y = self.up_2d_0_2(y) + y = self.up_2d_0_3(y) + y = self.up_2d_0_4(y) + y = self.up_2d_0_5(y) + y = self.up_2d_1_0(y) + y = self.up_2d_1_1(y) + y = self.up_2d_1_2(y) + y = self.up_2d_1_3(y) + y = self.up_2d_1_4(y) + y = self.up_2d_1_5(y) + y = self.up_2d_2_0(y) + y = self.up_2d_2_1(y) + y = self.up_2d_2_2(y) + y = self.up_2d_2_3(y) + y = self.up_2d_2_4(y) + y = self.up_2d_2_5(y) + + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 32) + y = torch.rand(1, 3, 32, 32) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_Upsample.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Upsample.pt inputshape=[1,3,32],[1,3,32,32]") + + # ncnn inference + import test_nn_Upsample_ncnn + b = test_nn_Upsample_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_UpsamplingBilinear2d.py b/tools/pnnx/tests/ncnn/test_nn_UpsamplingBilinear2d.py new file mode 100644 index 000000000000..704b2c778d36 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_UpsamplingBilinear2d.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.up_0 = nn.UpsamplingBilinear2d(size=16) + self.up_1 = nn.UpsamplingBilinear2d(scale_factor=2) + self.up_2 = nn.UpsamplingBilinear2d(size=(20,20)) + self.up_3 = nn.UpsamplingBilinear2d(scale_factor=(4,4)) + self.up_4 = nn.UpsamplingBilinear2d(size=(16,24)) + self.up_5 = nn.UpsamplingBilinear2d(scale_factor=(2,3)) + + def forward(self, x): + x = self.up_0(x) + x = self.up_1(x) + x = self.up_2(x) + x = self.up_3(x) + x = self.up_4(x) + x = self.up_5(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 32, 32) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_UpsamplingBilinear2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_UpsamplingBilinear2d.pt inputshape=[1,3,32,32]") + + # ncnn inference + import test_nn_UpsamplingBilinear2d_ncnn + b = test_nn_UpsamplingBilinear2d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_UpsamplingNearest2d.py b/tools/pnnx/tests/ncnn/test_nn_UpsamplingNearest2d.py new file mode 100644 index 000000000000..c729b2a2e9bb --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_UpsamplingNearest2d.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.up_0 = nn.UpsamplingNearest2d(size=16) + self.up_1 = nn.UpsamplingNearest2d(scale_factor=2) + self.up_2 = nn.UpsamplingNearest2d(size=(20,20)) + self.up_3 = nn.UpsamplingNearest2d(scale_factor=(4,4)) + self.up_4 = nn.UpsamplingNearest2d(size=(16,24)) + self.up_5 = nn.UpsamplingNearest2d(scale_factor=(2,3)) + + def forward(self, x): + x = self.up_0(x) + x = self.up_1(x) + x = self.up_2(x) + x = self.up_3(x) + x = self.up_4(x) + x = self.up_5(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 32, 32) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_UpsamplingNearest2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_UpsamplingNearest2d.pt inputshape=[1,3,32,32]") + + # ncnn inference + import test_nn_UpsamplingNearest2d_ncnn + b = test_nn_UpsamplingNearest2d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_ZeroPad2d.py b/tools/pnnx/tests/ncnn/test_nn_ZeroPad2d.py new file mode 100644 index 000000000000..e8b4cfb171b9 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_ZeroPad2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ZeroPad2d(2) + self.pad_1 = nn.ZeroPad2d(padding=(3,4,5,6)) + self.pad_2 = nn.ZeroPad2d(padding=(1,0,2,0)) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ZeroPad2d.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_ZeroPad2d.pt inputshape=[1,12,13,13]") + + # ncnn inference + import test_nn_ZeroPad2d_ncnn + b = test_nn_ZeroPad2d_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_resnet18.py b/tools/pnnx/tests/ncnn/test_resnet18.py new file mode 100644 index 000000000000..46990baca75e --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_resnet18.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torchvision.models as models + +def test(): + net = models.resnet18() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 224, 224) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_resnet18.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_resnet18.pt inputshape=[1,3,224,224]") + + # ncnn inference + import test_resnet18_ncnn + b = test_resnet18_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_shufflenet_v2_x1_0.py b/tools/pnnx/tests/ncnn/test_shufflenet_v2_x1_0.py new file mode 100644 index 000000000000..608c056a0f22 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_shufflenet_v2_x1_0.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torchvision.models as models + +def test(): + net = models.shufflenet_v2_x1_0() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 224, 224) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_shufflenet_v2_x1_0.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_shufflenet_v2_x1_0.pt inputshape=[1,3,224,224]") + + # ncnn inference + import test_shufflenet_v2_x1_0_ncnn + b = test_shufflenet_v2_x1_0_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_squeezenet1_1.py b/tools/pnnx/tests/ncnn/test_squeezenet1_1.py new file mode 100644 index 000000000000..665b15eddf6a --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_squeezenet1_1.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torchvision.models as models + +def test(): + net = models.squeezenet1_1() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 224, 224) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_squeezenet1_1.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_squeezenet1_1.pt inputshape=[1,3,224,224]") + + # ncnn inference + import test_squeezenet1_1_ncnn + b = test_squeezenet1_1_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torch_permute.py b/tools/pnnx/tests/ncnn/test_torch_permute.py new file mode 100644 index 000000000000..46b65e26d848 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_permute.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + if torch.__version__ < '1.9': + x = x.permute(1, 0, 2) + x = x.permute(0, 2, 1) + x = x.permute(2, 0, 1) + y = y.permute(2, 3, 1, 0) + y = y.permute(3, 1, 0, 2) + else: + x = torch.permute(x, (1, 0, 2)) + x = torch.permute(x, (0, 2, 1)) + x = torch.permute(x, (2, 0, 1)) + y = torch.permute(y, (2, 3, 1, 0)) + y = torch.permute(y, (3, 1, 0, 2)) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_permute.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_permute.pt inputshape=[1,3,16],[1,5,9,11]") + + # ncnn inference + import test_torch_permute_ncnn + b = test_torch_permute_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torch_squeeze.py b/tools/pnnx/tests/ncnn/test_torch_squeeze.py new file mode 100644 index 000000000000..b4559180a6b3 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_squeeze.py @@ -0,0 +1,65 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = torch.squeeze(x, 1) + y = torch.squeeze(y, 2) + z = torch.squeeze(z) + w = torch.squeeze(w, 3) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 1, 16) + y = torch.rand(1, 3, 1) + z = torch.rand(1, 5, 1, 11) + w = torch.rand(1, 5, 9, 1) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torch_squeeze.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_squeeze.pt inputshape=[1,1,16],[1,3,1],[1,5,1,11],[1,5,9,1]") + + # ncnn inference + import test_torch_squeeze_ncnn + b = test_torch_squeeze_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + print(a0.shape) + print(b0.shape) + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torch_transpose.py b/tools/pnnx/tests/ncnn/test_torch_transpose.py new file mode 100644 index 000000000000..782f7006e51c --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_transpose.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = torch.transpose(x, 1, 2) + y = torch.transpose(y, 2, 3) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_transpose.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_transpose.pt inputshape=[1,3,16],[1,5,9,11]") + + # ncnn inference + import test_torch_transpose_ncnn + b = test_torch_transpose_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torch_unsqueeze.py b/tools/pnnx/tests/ncnn/test_torch_unsqueeze.py new file mode 100644 index 000000000000..c8f75a395bd4 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_unsqueeze.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x0 = torch.unsqueeze(x, 1) + x1 = torch.unsqueeze(x, 2) + y0 = torch.unsqueeze(y, 2) + y1 = torch.unsqueeze(y, -1) + return x0, x1, y0, y1 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 9, 11) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_unsqueeze.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_unsqueeze.pt inputshape=[1,16],[1,9,11]") + + # ncnn inference + import test_torch_unsqueeze_ncnn + b = test_torch_unsqueeze_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + print(a0.shape) + print(b0.shape) + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/run_test.cmake b/tools/pnnx/tests/run_test.cmake new file mode 100644 index 000000000000..dc089db50e5f --- /dev/null +++ b/tools/pnnx/tests/run_test.cmake @@ -0,0 +1,6 @@ + +set(ENV{PYTHONPATH} "ENV{PYTHONPATH}:${CMAKE_CURRENT_BINARY_DIR}") +execute_process(COMMAND ${PYTHON_EXECUTABLE} ${PYTHON_SCRIPT} RESULT_VARIABLE result) +if(NOT "${result}" STREQUAL "0") + message(FATAL_ERROR "Test failed with return value '${result}'") +endif() diff --git a/tools/pnnx/tests/test_F_adaptive_avg_pool1d.py b/tools/pnnx/tests/test_F_adaptive_avg_pool1d.py new file mode 100644 index 000000000000..ba31bc29b3bc --- /dev/null +++ b/tools/pnnx/tests/test_F_adaptive_avg_pool1d.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.adaptive_avg_pool1d(x, output_size=7) + x = F.adaptive_avg_pool1d(x, output_size=1) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_adaptive_avg_pool1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_adaptive_avg_pool1d.pt inputshape=[1,12,24]") + + # pnnx inference + import test_F_adaptive_avg_pool1d_pnnx + b = test_F_adaptive_avg_pool1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_adaptive_avg_pool2d.py b/tools/pnnx/tests/test_F_adaptive_avg_pool2d.py new file mode 100644 index 000000000000..3ed400926d9b --- /dev/null +++ b/tools/pnnx/tests/test_F_adaptive_avg_pool2d.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.adaptive_avg_pool2d(x, output_size=(7,6)) + x = F.adaptive_avg_pool2d(x, output_size=1) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_adaptive_avg_pool2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_adaptive_avg_pool2d.pt inputshape=[1,12,24,64]") + + # pnnx inference + import test_F_adaptive_avg_pool2d_pnnx + b = test_F_adaptive_avg_pool2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_adaptive_avg_pool3d.py b/tools/pnnx/tests/test_F_adaptive_avg_pool3d.py new file mode 100644 index 000000000000..4d3353ec7f97 --- /dev/null +++ b/tools/pnnx/tests/test_F_adaptive_avg_pool3d.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.adaptive_avg_pool3d(x, output_size=(7,6,5)) + x = F.adaptive_avg_pool3d(x, output_size=1) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 33, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_adaptive_avg_pool3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_adaptive_avg_pool3d.pt inputshape=[1,12,24,33,64]") + + # pnnx inference + import test_F_adaptive_avg_pool3d_pnnx + b = test_F_adaptive_avg_pool3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_adaptive_max_pool1d.py b/tools/pnnx/tests/test_F_adaptive_max_pool1d.py new file mode 100644 index 000000000000..e7c1fe28a7a4 --- /dev/null +++ b/tools/pnnx/tests/test_F_adaptive_max_pool1d.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x, indices = F.adaptive_max_pool1d(x, output_size=7, return_indices=True) + x = F.adaptive_max_pool1d(x, output_size=1) + return x, indices + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24) + + a0, a1 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_adaptive_max_pool1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_adaptive_max_pool1d.pt inputshape=[1,12,24]") + + # pnnx inference + import test_F_adaptive_max_pool1d_pnnx + b0, b1 = test_F_adaptive_max_pool1d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_adaptive_max_pool2d.py b/tools/pnnx/tests/test_F_adaptive_max_pool2d.py new file mode 100644 index 000000000000..56bc98dc3f99 --- /dev/null +++ b/tools/pnnx/tests/test_F_adaptive_max_pool2d.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x, indices = F.adaptive_max_pool2d(x, output_size=(7,6), return_indices=True) + x = F.adaptive_max_pool2d(x, output_size=1) + return x, indices + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 64) + + a0, a1 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_adaptive_max_pool2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_adaptive_max_pool2d.pt inputshape=[1,12,24,64]") + + # pnnx inference + import test_F_adaptive_max_pool2d_pnnx + b0, b1 = test_F_adaptive_max_pool2d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_adaptive_max_pool3d.py b/tools/pnnx/tests/test_F_adaptive_max_pool3d.py new file mode 100644 index 000000000000..34da5d381f98 --- /dev/null +++ b/tools/pnnx/tests/test_F_adaptive_max_pool3d.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x, indices = F.adaptive_max_pool3d(x, output_size=(7,6,5), return_indices=True) + x = F.adaptive_max_pool3d(x, output_size=1) + return x, indices + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 33, 64) + + a0, a1 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_adaptive_max_pool3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_adaptive_max_pool3d.pt inputshape=[1,12,24,33,64]") + + # pnnx inference + import test_F_adaptive_max_pool3d_pnnx + b0, b1 = test_F_adaptive_max_pool3d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_affine_grid.py b/tools/pnnx/tests/test_F_affine_grid.py new file mode 100644 index 000000000000..5ad70e3eb80d --- /dev/null +++ b/tools/pnnx/tests/test_F_affine_grid.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = F.affine_grid(x, torch.Size((32, 3, 24, 24)), align_corners=False) + + y = F.affine_grid(y, torch.Size((12, 3, 10, 20, 30)), align_corners=False) + + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(32, 2, 3) + y = torch.rand(12, 3, 4) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_F_affine_grid.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_affine_grid.pt inputshape=[32,2,3],[12,3,4]") + + # pnnx inference + import test_F_affine_grid_pnnx + b0, b1 = test_F_affine_grid_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_avg_pool1d.py b/tools/pnnx/tests/test_F_avg_pool1d.py new file mode 100644 index 000000000000..a45e5f2d4450 --- /dev/null +++ b/tools/pnnx/tests/test_F_avg_pool1d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.avg_pool1d(x, kernel_size=3) + x = F.avg_pool1d(x, kernel_size=4, stride=2, padding=2) + x = F.avg_pool1d(x, kernel_size=3, stride=1, padding=(0), ceil_mode=False, count_include_pad=True) + x = F.avg_pool1d(x, kernel_size=5, stride=2, padding=(2), ceil_mode=True, count_include_pad=False) + x = F.avg_pool1d(x, kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=True) + x = F.avg_pool1d(x, kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + x = F.avg_pool1d(x, kernel_size=4, stride=1, padding=2, ceil_mode=False, count_include_pad=False) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_avg_pool1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_avg_pool1d.pt inputshape=[1,12,128]") + + # pnnx inference + import test_F_avg_pool1d_pnnx + b = test_F_avg_pool1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_avg_pool2d.py b/tools/pnnx/tests/test_F_avg_pool2d.py new file mode 100644 index 000000000000..d00b2ecad79e --- /dev/null +++ b/tools/pnnx/tests/test_F_avg_pool2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.avg_pool2d(x, kernel_size=3) + x = F.avg_pool2d(x, kernel_size=4, stride=2, padding=2) + x = F.avg_pool2d(x, kernel_size=(1,3), stride=1, padding=(0,1), ceil_mode=False, count_include_pad=True) + x = F.avg_pool2d(x, kernel_size=(4,5), stride=(1,2), padding=(1,2), ceil_mode=True, count_include_pad=False) + x = F.avg_pool2d(x, kernel_size=(5,3), stride=(2,1), padding=1, ceil_mode=False, count_include_pad=True) + x = F.avg_pool2d(x, kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + x = F.avg_pool2d(x, kernel_size=(5,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=18) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 128, 127) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_avg_pool2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_avg_pool2d.pt inputshape=[1,12,128,127]") + + # pnnx inference + import test_F_avg_pool2d_pnnx + b = test_F_avg_pool2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_avg_pool3d.py b/tools/pnnx/tests/test_F_avg_pool3d.py new file mode 100644 index 000000000000..5867a502dfa0 --- /dev/null +++ b/tools/pnnx/tests/test_F_avg_pool3d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.avg_pool3d(x, kernel_size=3) + x = F.avg_pool3d(x, kernel_size=4, stride=2, padding=2) + x = F.avg_pool3d(x, kernel_size=(1,2,3), stride=1, padding=(0,1,1), ceil_mode=False, count_include_pad=True) + x = F.avg_pool3d(x, kernel_size=(3,4,5), stride=(1,2,2), padding=(1,1,2), ceil_mode=True, count_include_pad=False) + x = F.avg_pool3d(x, kernel_size=(5,4,3), stride=(2,1,1), padding=1, ceil_mode=False, count_include_pad=True) + x = F.avg_pool3d(x, kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + x = F.avg_pool3d(x, kernel_size=(5,4,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=77) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 96, 128, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_avg_pool3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_avg_pool3d.pt inputshape=[1,12,96,128,128]") + + # pnnx inference + import test_F_avg_pool3d_pnnx + b = test_F_avg_pool3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_batch_norm.py b/tools/pnnx/tests/test_F_batch_norm.py new file mode 100644 index 000000000000..35976392220d --- /dev/null +++ b/tools/pnnx/tests/test_F_batch_norm.py @@ -0,0 +1,75 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, m0, v0, w0, b0, m1, v1, w1, b1, m2, v2, w2, b2): + x = F.batch_norm(x, m0, v0, w0, b0) + x = F.batch_norm(x, m0, v0, None, None) + + y = F.batch_norm(y, m1, v1, w1, b1, eps=1e-3) + y = F.batch_norm(y, m1, v1, None, None) + + z = F.batch_norm(z, m2, v2, w2, b2) + z = F.batch_norm(z, m2, v2, None, None, eps=1e-2) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + m0 = torch.rand(16) + v0 = torch.rand(16) + w0 = torch.rand(16) + b0 = torch.rand(16) + m1 = torch.rand(2) + v1 = torch.rand(2) + w1 = torch.rand(2) + b1 = torch.rand(2) + m2 = torch.rand(3) + v2 = torch.rand(3) + w2 = torch.rand(3) + b2 = torch.rand(3) + + a0, a1, a2 = net(x, y, z, m0, v0, w0, b0, m1, v1, w1, b1, m2, v2, w2, b2) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, m0, v0, w0, b0, m1, v1, w1, b1, m2, v2, w2, b2)) + mod.save("test_F_batch_norm.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_batch_norm.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[16],[16],[16],[16],[2],[2],[2],[2],[3],[3],[3],[3]") + + # pnnx inference + import test_F_batch_norm_pnnx + b0, b1, b2 = test_F_batch_norm_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_celu.py b/tools/pnnx/tests/test_F_celu.py new file mode 100644 index 000000000000..49d9e7a8df14 --- /dev/null +++ b/tools/pnnx/tests/test_F_celu.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.celu(x) + y = F.celu(y, 0.8) + z = F.celu(z, -0.5) + w = F.celu(w, 2) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_celu.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_celu.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_celu_pnnx + b0, b1, b2, b3 = test_F_celu_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_conv1d.py b/tools/pnnx/tests/test_F_conv1d.py new file mode 100644 index 000000000000..f31d89f723c4 --- /dev/null +++ b/tools/pnnx/tests/test_F_conv1d.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, w0, w1, b1): + x = F.conv1d(x, w0, None, stride=2, padding=1) + if torch.__version__ < '1.9': + x = F.conv1d(x, w1, b1, stride=1, padding=1, dilation=2, groups=2) + else: + x = F.conv1d(x, w1, b1, stride=1, padding='same', dilation=2, groups=2) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 52) + w0 = torch.rand(16, 12, 3) + w1 = torch.rand(16, 8, 5) + b1 = torch.rand(16) + + a = net(x, w0, w1, b1) + + # export torchscript + mod = torch.jit.trace(net, (x, w0, w1, b1)) + mod.save("test_F_conv1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_conv1d.pt inputshape=[1,12,52],[16,12,3],[16,8,5],[16]") + + # pnnx inference + import test_F_conv1d_pnnx + b = test_F_conv1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_conv2d.py b/tools/pnnx/tests/test_F_conv2d.py new file mode 100644 index 000000000000..745eeec128c5 --- /dev/null +++ b/tools/pnnx/tests/test_F_conv2d.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, w0, w1, b1): + x = F.conv2d(x, w0, None, stride=(2,2), padding=(1,1)) + if torch.__version__ < '1.9': + x = F.conv2d(x, w1, b1, stride=(1,1), padding=(1,1), dilation=(2,1), groups=2) + else: + x = F.conv2d(x, w1, b1, stride=(1,1), padding='same', dilation=(2,1), groups=2) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 52, 64) + w0 = torch.rand(16, 12, 3, 3) + w1 = torch.rand(16, 8, 5, 5) + b1 = torch.rand(16) + + a = net(x, w0, w1, b1) + + # export torchscript + mod = torch.jit.trace(net, (x, w0, w1, b1)) + mod.save("test_F_conv2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_conv2d.pt inputshape=[1,12,52,64],[16,12,3,3],[16,8,5,5],[16]") + + # pnnx inference + import test_F_conv2d_pnnx + b = test_F_conv2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_conv3d.py b/tools/pnnx/tests/test_F_conv3d.py new file mode 100644 index 000000000000..82f2d8669eb4 --- /dev/null +++ b/tools/pnnx/tests/test_F_conv3d.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, w0, w1, b1): + x = F.conv3d(x, w0, None, stride=(2,2,2), padding=(1,0,1)) + if torch.__version__ < '1.9': + x = F.conv3d(x, w1, b1, stride=(1,1,1), padding=(1,1,1), dilation=(2,2,1), groups=2) + else: + x = F.conv3d(x, w1, b1, stride=(1,1,1), padding='same', dilation=(2,2,1), groups=2) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 20, 32, 40) + w0 = torch.rand(16, 12, 3, 2, 3) + w1 = torch.rand(16, 8, 5, 4, 5) + b1 = torch.rand(16) + + a = net(x, w0, w1, b1) + + # export torchscript + mod = torch.jit.trace(net, (x, w0, w1, b1)) + mod.save("test_F_conv3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_conv3d.pt inputshape=[1,12,20,32,40],[16,12,3,2,3],[16,8,5,4,5],[16]") + + # pnnx inference + import test_F_conv3d_pnnx + b = test_F_conv3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_conv_transpose1d.py b/tools/pnnx/tests/test_F_conv_transpose1d.py new file mode 100644 index 000000000000..b2db491fec45 --- /dev/null +++ b/tools/pnnx/tests/test_F_conv_transpose1d.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, w0, w1, b1): + x = F.conv_transpose1d(x, w0, None, stride=2, padding=1, output_padding=1) + x = F.conv_transpose1d(x, w1, b1, stride=1, padding=2, dilation=2, groups=2) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 22) + w0 = torch.rand(12, 16, 3) + w1 = torch.rand(16, 8, 5) + b1 = torch.rand(16) + + a = net(x, w0, w1, b1) + + # export torchscript + mod = torch.jit.trace(net, (x, w0, w1, b1)) + mod.save("test_F_conv_transpose1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_conv_transpose1d.pt inputshape=[1,12,22],[12,16,3],[16,8,5],[16]") + + # pnnx inference + import test_F_conv_transpose1d_pnnx + b = test_F_conv_transpose1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_conv_transpose2d.py b/tools/pnnx/tests/test_F_conv_transpose2d.py new file mode 100644 index 000000000000..090ef734ece1 --- /dev/null +++ b/tools/pnnx/tests/test_F_conv_transpose2d.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, w0, w1, b1): + x = F.conv_transpose2d(x, w0, None, stride=(2,2), padding=(1,1), output_padding=(1,1)) + x = F.conv_transpose2d(x, w1, b1, stride=(1,2), padding=(2,1), dilation=(2,1), groups=2) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 22, 32) + w0 = torch.rand(12, 16, 3, 3) + w1 = torch.rand(16, 8, 5, 5) + b1 = torch.rand(16) + + a = net(x, w0, w1, b1) + + # export torchscript + mod = torch.jit.trace(net, (x, w0, w1, b1)) + mod.save("test_F_conv_transpose2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_conv_transpose2d.pt inputshape=[1,12,22,32],[12,16,3,3],[16,8,5,5],[16]") + + # pnnx inference + import test_F_conv_transpose2d_pnnx + b = test_F_conv_transpose2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_conv_transpose3d.py b/tools/pnnx/tests/test_F_conv_transpose3d.py new file mode 100644 index 000000000000..24dc2c401dce --- /dev/null +++ b/tools/pnnx/tests/test_F_conv_transpose3d.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, w0, w1, b1): + x = F.conv_transpose3d(x, w0, None, stride=(2,2,2), padding=(1,0,1), output_padding=(1,1,0)) + x = F.conv_transpose3d(x, w1, b1, stride=(1,1,2), padding=(2,2,1), dilation=(2,2,1), groups=2) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 10, 12, 14) + w0 = torch.rand(12, 16, 3, 2, 3) + w1 = torch.rand(16, 8, 5, 4, 5) + b1 = torch.rand(16) + + a = net(x, w0, w1, b1) + + # export torchscript + mod = torch.jit.trace(net, (x, w0, w1, b1)) + mod.save("test_F_conv_transpose3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_conv_transpose3d.pt inputshape=[1,12,10,12,14],[12,16,3,2,3],[16,8,5,4,5],[16]") + + # pnnx inference + import test_F_conv_transpose3d_pnnx + b = test_F_conv_transpose3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_elu.py b/tools/pnnx/tests/test_F_elu.py new file mode 100644 index 000000000000..73047093eb88 --- /dev/null +++ b/tools/pnnx/tests/test_F_elu.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.elu(x) + y = F.elu(y, 1.2) + z = F.elu(z, -0.6) + w = F.elu(w, 0) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_elu.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_elu.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_elu_pnnx + b0, b1, b2, b3 = test_F_elu_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_gelu.py b/tools/pnnx/tests/test_F_gelu.py new file mode 100644 index 000000000000..a04b540d6d56 --- /dev/null +++ b/tools/pnnx/tests/test_F_gelu.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.gelu(x) + y = F.gelu(y) + z = F.gelu(z) + w = F.gelu(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_gelu.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_gelu.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_gelu_pnnx + b0, b1, b2, b3 = test_F_gelu_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_grid_sample.py b/tools/pnnx/tests/test_F_grid_sample.py new file mode 100644 index 000000000000..ae4ed354cdf3 --- /dev/null +++ b/tools/pnnx/tests/test_F_grid_sample.py @@ -0,0 +1,90 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, xg1, xg2, y, yg1, yg2): + x = F.grid_sample(x, xg1, mode='bilinear', padding_mode='zeros', align_corners=False) + x = F.grid_sample(x, xg2, mode='bilinear', padding_mode='border', align_corners=False) + x = F.grid_sample(x, xg1, mode='bilinear', padding_mode='reflection', align_corners=False) + x = F.grid_sample(x, xg2, mode='nearest', padding_mode='zeros', align_corners=False) + x = F.grid_sample(x, xg1, mode='nearest', padding_mode='border', align_corners=False) + x = F.grid_sample(x, xg2, mode='nearest', padding_mode='reflection', align_corners=False) + x = F.grid_sample(x, xg1, mode='bicubic', padding_mode='zeros', align_corners=False) + x = F.grid_sample(x, xg2, mode='bicubic', padding_mode='border', align_corners=False) + x = F.grid_sample(x, xg1, mode='bicubic', padding_mode='reflection', align_corners=False) + x = F.grid_sample(x, xg2, mode='bilinear', padding_mode='zeros', align_corners=True) + x = F.grid_sample(x, xg1, mode='bilinear', padding_mode='border', align_corners=True) + x = F.grid_sample(x, xg2, mode='bilinear', padding_mode='reflection', align_corners=True) + x = F.grid_sample(x, xg1, mode='nearest', padding_mode='zeros', align_corners=True) + x = F.grid_sample(x, xg2, mode='nearest', padding_mode='border', align_corners=True) + x = F.grid_sample(x, xg1, mode='nearest', padding_mode='reflection', align_corners=True) + x = F.grid_sample(x, xg2, mode='bicubic', padding_mode='zeros', align_corners=True) + x = F.grid_sample(x, xg1, mode='bicubic', padding_mode='border', align_corners=True) + x = F.grid_sample(x, xg2, mode='bicubic', padding_mode='reflection', align_corners=True) + + y = F.grid_sample(y, yg1, mode='bilinear', padding_mode='zeros', align_corners=False) + y = F.grid_sample(y, yg2, mode='bilinear', padding_mode='border', align_corners=False) + y = F.grid_sample(y, yg1, mode='bilinear', padding_mode='reflection', align_corners=False) + y = F.grid_sample(y, yg2, mode='nearest', padding_mode='zeros', align_corners=False) + y = F.grid_sample(y, yg1, mode='nearest', padding_mode='border', align_corners=False) + y = F.grid_sample(y, yg2, mode='nearest', padding_mode='reflection', align_corners=False) + y = F.grid_sample(y, yg1, mode='bilinear', padding_mode='zeros', align_corners=True) + y = F.grid_sample(y, yg2, mode='bilinear', padding_mode='border', align_corners=True) + y = F.grid_sample(y, yg1, mode='bilinear', padding_mode='reflection', align_corners=True) + y = F.grid_sample(y, yg2, mode='nearest', padding_mode='zeros', align_corners=True) + y = F.grid_sample(y, yg1, mode='nearest', padding_mode='border', align_corners=True) + y = F.grid_sample(y, yg2, mode='nearest', padding_mode='reflection', align_corners=True) + + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 12, 16) + xg1 = torch.rand(1, 21, 27, 2) + xg2 = torch.rand(1, 12, 16, 2) + y = torch.rand(1, 5, 10, 12, 16) + yg1 = torch.rand(1, 10, 21, 27, 3) + yg2 = torch.rand(1, 10, 12, 16, 3) + + a0, a1 = net(x, xg1, xg2, y, yg1, yg2) + + # export torchscript + mod = torch.jit.trace(net, (x, xg1, xg2, y, yg1, yg2)) + mod.save("test_F_grid_sample.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_grid_sample.pt inputshape=[1,3,12,16],[1,21,27,2],[1,12,16,2],[1,5,10,12,16],[1,10,21,27,3],[1,10,12,16,3]") + + # pnnx inference + import test_F_grid_sample_pnnx + b0, b1 = test_F_grid_sample_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_group_norm.py b/tools/pnnx/tests/test_F_group_norm.py new file mode 100644 index 000000000000..103422a00302 --- /dev/null +++ b/tools/pnnx/tests/test_F_group_norm.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w0, b0, w1, b1, w2, b2): + x = F.group_norm(x, 2, w0, b0) + x = F.group_norm(x, 1, None, None) + + y = F.group_norm(y, 3, w1, b1, eps=1e-4) + y = F.group_norm(y, 4, None, None) + + z = F.group_norm(z, 32, w2, b2) + z = F.group_norm(z, 4, None, None, eps=1e-2) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 12, 16) + z = torch.rand(1, 32, 12, 16) + w0 = torch.rand(16) + b0 = torch.rand(16) + w1 = torch.rand(12) + b1 = torch.rand(12) + w2 = torch.rand(32) + b2 = torch.rand(32) + + a0, a1, a2 = net(x, y, z, w0, b0, w1, b1, w2, b2) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w0, b0, w1, b1, w2, b2)) + mod.save("test_F_group_norm.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_group_norm.pt inputshape=[1,16],[12,12,16],[1,32,12,16],[16],[16],[12],[12],[32],[32]") + + # pnnx inference + import test_F_group_norm_pnnx + b0, b1, b2 = test_F_group_norm_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_hardshrink.py b/tools/pnnx/tests/test_F_hardshrink.py new file mode 100644 index 000000000000..3f9cee9ec9cf --- /dev/null +++ b/tools/pnnx/tests/test_F_hardshrink.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.hardshrink(x) + y = F.hardshrink(y, 0.1) + z = F.hardshrink(z, 0.22) + w = F.hardshrink(w, 0) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_hardshrink.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_hardshrink.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_hardshrink_pnnx + b0, b1, b2, b3 = test_F_hardshrink_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_hardsigmoid.py b/tools/pnnx/tests/test_F_hardsigmoid.py new file mode 100644 index 000000000000..5ec3b8d20e7b --- /dev/null +++ b/tools/pnnx/tests/test_F_hardsigmoid.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def hardsigmoid_forward_0(x): + return F.relu6(x + 3., True) / 6. + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.hardsigmoid(x) + y = F.hardsigmoid(y) + z = F.hardsigmoid(z) + w = hardsigmoid_forward_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_hardsigmoid.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_hardsigmoid.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_hardsigmoid_pnnx + b = test_F_hardsigmoid_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_hardswish.py b/tools/pnnx/tests/test_F_hardswish.py new file mode 100644 index 000000000000..50e96b3f32ae --- /dev/null +++ b/tools/pnnx/tests/test_F_hardswish.py @@ -0,0 +1,73 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def hardswish_forward_0(x): + return x * F.hardsigmoid(x) + +def hardswish_forward_1(x): + return x * F.hardtanh(x + 3, 0., 6.) / 6. + +def hardswish_forward_2(x): + out = F.relu6(x + 3., True) / 6. + return out * x + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.hardswish(x) + y = hardswish_forward_0(y) + z = hardswish_forward_1(z) + w = hardswish_forward_2(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_hardswish.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_hardswish.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_hardswish_pnnx + b = test_F_hardswish_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_hardtanh.py b/tools/pnnx/tests/test_F_hardtanh.py new file mode 100644 index 000000000000..54bcba6e122a --- /dev/null +++ b/tools/pnnx/tests/test_F_hardtanh.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.hardtanh(x) + y = F.hardtanh(y, -1, 1) + z = F.hardtanh(z, -0.1, 0.1) + w = F.hardtanh(w, 0.1, 0.3) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_hardtanh.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_hardtanh.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_hardtanh_pnnx + b0, b1, b2, b3 = test_F_hardtanh_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_instance_norm.py b/tools/pnnx/tests/test_F_instance_norm.py new file mode 100644 index 000000000000..cd3b0e568fb4 --- /dev/null +++ b/tools/pnnx/tests/test_F_instance_norm.py @@ -0,0 +1,75 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, m0, v0, w0, b0, m1, v1, w1, b1, m2, v2, w2, b2): + x = F.instance_norm(x, m0, v0, w0, b0) + x = F.instance_norm(x, m0, v0, None, None) + + y = F.instance_norm(y, m1, v1, w1, b1, eps=1e-3) + y = F.instance_norm(y, m1, v1, None, None) + + z = F.instance_norm(z, m2, v2, w2, b2) + z = F.instance_norm(z, m2, v2, None, None, eps=1e-2) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24) + y = torch.rand(2, 3, 12, 16) + z = torch.rand(1, 10, 12, 16, 24) + m0 = torch.rand(12) + v0 = torch.rand(12) + w0 = torch.rand(12) + b0 = torch.rand(12) + m1 = torch.rand(3) + v1 = torch.rand(3) + w1 = torch.rand(3) + b1 = torch.rand(3) + m2 = torch.rand(10) + v2 = torch.rand(10) + w2 = torch.rand(10) + b2 = torch.rand(10) + + a0, a1, a2 = net(x, y, z, m0, v0, w0, b0, m1, v1, w1, b1, m2, v2, w2, b2) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, m0, v0, w0, b0, m1, v1, w1, b1, m2, v2, w2, b2)) + mod.save("test_F_instance_norm.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_instance_norm.pt inputshape=[1,12,24],[2,3,12,16],[1,10,12,16,24],[12],[12],[12],[12],[3],[3],[3],[3],[10],[10],[10],[10]") + + # pnnx inference + import test_F_instance_norm_pnnx + b0, b1, b2 = test_F_instance_norm_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_interpolate.py b/tools/pnnx/tests/test_F_interpolate.py new file mode 100644 index 000000000000..12a826f08b14 --- /dev/null +++ b/tools/pnnx/tests/test_F_interpolate.py @@ -0,0 +1,110 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.interpolate(x, size=16) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = F.interpolate(x, size=(20), mode='nearest') + x = F.interpolate(x, scale_factor=(4), mode='nearest') + x = F.interpolate(x, size=16, mode='linear') + x = F.interpolate(x, scale_factor=2, mode='linear') + x = F.interpolate(x, size=(24), mode='linear', align_corners=True) + x = F.interpolate(x, scale_factor=(3), mode='linear', align_corners=True) + + x = F.interpolate(x, scale_factor=1.5, mode='nearest', recompute_scale_factor=True) + x = F.interpolate(x, scale_factor=1.2, mode='linear', align_corners=False, recompute_scale_factor=True) + x = F.interpolate(x, scale_factor=0.8, mode='linear', align_corners=True, recompute_scale_factor=True) + + y = F.interpolate(y, size=16) + y = F.interpolate(y, scale_factor=2, mode='nearest') + y = F.interpolate(y, size=(20,20), mode='nearest') + y = F.interpolate(y, scale_factor=(4,4), mode='nearest') + y = F.interpolate(y, size=(16,24), mode='nearest') + y = F.interpolate(y, scale_factor=(2,3), mode='nearest') + y = F.interpolate(y, size=16, mode='bilinear') + y = F.interpolate(y, scale_factor=2, mode='bilinear') + y = F.interpolate(y, size=(20,20), mode='bilinear', align_corners=False) + y = F.interpolate(y, scale_factor=(4,4), mode='bilinear', align_corners=False) + y = F.interpolate(y, size=(16,24), mode='bilinear', align_corners=True) + y = F.interpolate(y, scale_factor=(2,3), mode='bilinear', align_corners=True) + y = F.interpolate(y, size=16, mode='bicubic') + y = F.interpolate(y, scale_factor=2, mode='bicubic') + y = F.interpolate(y, size=(20,20), mode='bicubic', align_corners=False) + y = F.interpolate(y, scale_factor=(4,4), mode='bicubic', align_corners=False) + y = F.interpolate(y, size=(16,24), mode='bicubic', align_corners=True) + y = F.interpolate(y, scale_factor=(2,3), mode='bicubic', align_corners=True) + + y = F.interpolate(y, scale_factor=(1.7,2), mode='nearest', recompute_scale_factor=True) + y = F.interpolate(y, scale_factor=(2,1.2), mode='bilinear', align_corners=False, recompute_scale_factor=True) + y = F.interpolate(y, scale_factor=(0.5,0.4), mode='bilinear', align_corners=True, recompute_scale_factor=True) + y = F.interpolate(y, scale_factor=(0.8,0.9), mode='bicubic', align_corners=False, recompute_scale_factor=True) + y = F.interpolate(y, scale_factor=(1.1,0.5), mode='bicubic', align_corners=True, recompute_scale_factor=True) + + z = F.interpolate(z, size=16) + z = F.interpolate(z, scale_factor=2, mode='nearest') + z = F.interpolate(z, size=(20,20,20), mode='nearest') + z = F.interpolate(z, scale_factor=(4,4,4), mode='nearest') + z = F.interpolate(z, size=(16,24,20), mode='nearest') + z = F.interpolate(z, scale_factor=(2,3,4), mode='nearest') + z = F.interpolate(z, size=16, mode='trilinear') + z = F.interpolate(z, scale_factor=2, mode='trilinear') + z = F.interpolate(z, size=(20,20,20), mode='trilinear', align_corners=False) + z = F.interpolate(z, scale_factor=(4,4,4), mode='trilinear', align_corners=False) + z = F.interpolate(z, size=(16,24,20), mode='trilinear', align_corners=True) + z = F.interpolate(z, scale_factor=(2,3,4), mode='trilinear', align_corners=True) + + z = F.interpolate(z, scale_factor=(1.5,2.5,2), mode='nearest', recompute_scale_factor=True) + z = F.interpolate(z, scale_factor=(0.7,0.5,1), mode='trilinear', align_corners=False, recompute_scale_factor=True) + z = F.interpolate(z, scale_factor=(0.9,0.8,1.2), mode='trilinear', align_corners=True, recompute_scale_factor=True) + + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 32) + y = torch.rand(1, 3, 32, 32) + z = torch.rand(1, 3, 32, 32, 32) + + a0, a1, a2 = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_interpolate.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_interpolate.pt inputshape=[1,3,32],[1,3,32,32],[1,3,32,32,32]") + + # pnnx inference + import test_F_interpolate_pnnx + b0, b1, b2 = test_F_interpolate_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_layer_norm.py b/tools/pnnx/tests/test_F_layer_norm.py new file mode 100644 index 000000000000..32840f5cd2bb --- /dev/null +++ b/tools/pnnx/tests/test_F_layer_norm.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w0, b0, w1, b1, w2, b2): + x = F.layer_norm(x, (24,), w0, b0) + x = F.layer_norm(x, (12,24), None, None) + + y = F.layer_norm(y, (16,), None, None, eps=1e-3) + y = F.layer_norm(y, (12,16), w1, b1) + + z = F.layer_norm(z, (24,), w2, b2) + z = F.layer_norm(z, (12,16,24), None, None, eps=1e-2) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24) + y = torch.rand(2, 3, 12, 16) + z = torch.rand(1, 10, 12, 16, 24) + w0 = torch.rand(24) + b0 = torch.rand(24) + w1 = torch.rand(12, 16) + b1 = torch.rand(12, 16) + w2 = torch.rand(24) + b2 = torch.rand(24) + + a0, a1, a2 = net(x, y, z, w0, b0, w1, b1, w2, b2) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w0, b0, w1, b1, w2, b2)) + mod.save("test_F_layer_norm.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_layer_norm.pt inputshape=[1,12,24],[2,3,12,16],[1,10,12,16,24],[24],[24],[12,16],[12,16],[24],[24]") + + # pnnx inference + import test_F_layer_norm_pnnx + b0, b1, b2 = test_F_layer_norm_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_leaky_relu.py b/tools/pnnx/tests/test_F_leaky_relu.py new file mode 100644 index 000000000000..700d78dafa36 --- /dev/null +++ b/tools/pnnx/tests/test_F_leaky_relu.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.leaky_relu(x) + y = F.leaky_relu(y, 0.1) + z = F.leaky_relu(z, -0.22) + w = F.leaky_relu(w, 0) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_leaky_relu.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_leaky_relu.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_leaky_relu_pnnx + b0, b1, b2, b3 = test_F_leaky_relu_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_linear.py b/tools/pnnx/tests/test_F_linear.py new file mode 100644 index 000000000000..41c44feba779 --- /dev/null +++ b/tools/pnnx/tests/test_F_linear.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w0, w1, b1): + x = F.linear(x, w0, None) + x = F.linear(x, w1, b1) + + y = F.linear(y, w0, None) + y = F.linear(y, w1, b1) + + z = F.linear(z, w0, None) + z = F.linear(z, w1, b1) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w0 = torch.rand(12, 16) + w1 = torch.rand(32, 12) + b1 = torch.rand(32) + + a0, a1, a2 = net(x, y, z, w0, w1, b1) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w0, w1, b1)) + mod.save("test_F_linear.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_linear.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[12,16],[32,12],[32]") + + # pnnx inference + import test_F_linear_pnnx + b0, b1, b2 = test_F_linear_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_local_response_norm.py b/tools/pnnx/tests/test_F_local_response_norm.py new file mode 100644 index 000000000000..35969269ef66 --- /dev/null +++ b/tools/pnnx/tests/test_F_local_response_norm.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.local_response_norm(x, 3) + x = F.local_response_norm(x, size=5, alpha=0.001, beta=0.8, k=0.9) + + y = F.local_response_norm(y, 4) + y = F.local_response_norm(y, size=4, alpha=0.01, beta=0.2, k=1.9) + + z = F.local_response_norm(z, 5) + z = F.local_response_norm(z, size=3, alpha=0.1, beta=0.3, k=0.2) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24) + y = torch.rand(2, 3, 12, 16) + z = torch.rand(1, 10, 12, 16, 24) + + a0, a1, a2 = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_local_response_norm.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_local_response_norm.pt inputshape=[1,12,24],[2,3,12,16],[1,10,12,16,24]") + + # pnnx inference + import test_F_local_response_norm_pnnx + b0, b1, b2 = test_F_local_response_norm_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_log_softmax.py b/tools/pnnx/tests/test_F_log_softmax.py new file mode 100644 index 000000000000..5906a1be9967 --- /dev/null +++ b/tools/pnnx/tests/test_F_log_softmax.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.log_softmax(x, 1) + y = F.log_softmax(y, 0) + z = F.log_softmax(z, 2) + w = F.log_softmax(w, 3) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_log_softmax.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_log_softmax.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_log_softmax_pnnx + b0, b1, b2, b3 = test_F_log_softmax_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_logsigmoid.py b/tools/pnnx/tests/test_F_logsigmoid.py new file mode 100644 index 000000000000..096c2ab5254e --- /dev/null +++ b/tools/pnnx/tests/test_F_logsigmoid.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.logsigmoid(x) + y = F.logsigmoid(y) + z = F.logsigmoid(z) + w = F.logsigmoid(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_logsigmoid.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_logsigmoid.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_logsigmoid_pnnx + b0, b1, b2, b3 = test_F_logsigmoid_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_lp_pool1d.py b/tools/pnnx/tests/test_F_lp_pool1d.py new file mode 100644 index 000000000000..00ad63cbd73a --- /dev/null +++ b/tools/pnnx/tests/test_F_lp_pool1d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.lp_pool1d(x, norm_type=2, kernel_size=3) + x = F.lp_pool1d(x, norm_type=2, kernel_size=4, stride=2) + x = F.lp_pool1d(x, norm_type=1, kernel_size=3, stride=1, ceil_mode=False) + x = F.lp_pool1d(x, norm_type=1, kernel_size=5, stride=1, ceil_mode=True) + x = F.lp_pool1d(x, norm_type=1.2, kernel_size=3, stride=2, ceil_mode=False) + x = F.lp_pool1d(x, norm_type=0.5, kernel_size=2, stride=1, ceil_mode=True) + x = F.lp_pool1d(x, norm_type=0.1, kernel_size=4, stride=1, ceil_mode=False) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_lp_pool1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_lp_pool1d.pt inputshape=[1,12,128]") + + # pnnx inference + import test_F_lp_pool1d_pnnx + b = test_F_lp_pool1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_lp_pool2d.py b/tools/pnnx/tests/test_F_lp_pool2d.py new file mode 100644 index 000000000000..fd387c37e46c --- /dev/null +++ b/tools/pnnx/tests/test_F_lp_pool2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.lp_pool2d(x, norm_type=2, kernel_size=3) + x = F.lp_pool2d(x, norm_type=2, kernel_size=4, stride=2) + x = F.lp_pool2d(x, norm_type=1, kernel_size=(1,3), stride=1, ceil_mode=False) + x = F.lp_pool2d(x, norm_type=1, kernel_size=(4,5), stride=(1,2), ceil_mode=True) + x = F.lp_pool2d(x, norm_type=1.2, kernel_size=(5,3), stride=(2,1), ceil_mode=False) + x = F.lp_pool2d(x, norm_type=0.5, kernel_size=2, stride=1, ceil_mode=True) + x = F.lp_pool2d(x, norm_type=0.1, kernel_size=(5,4), stride=1, ceil_mode=False) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 128, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_lp_pool2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_lp_pool2d.pt inputshape=[1,12,128,128]") + + # pnnx inference + import test_F_lp_pool2d_pnnx + b = test_F_lp_pool2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_max_pool1d.py b/tools/pnnx/tests/test_F_max_pool1d.py new file mode 100644 index 000000000000..c3e3b02527ea --- /dev/null +++ b/tools/pnnx/tests/test_F_max_pool1d.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.max_pool1d(x, kernel_size=3) + x = F.max_pool1d(x, kernel_size=4, stride=2, padding=2, dilation=1) + x = F.max_pool1d(x, kernel_size=3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) + x = F.max_pool1d(x, kernel_size=5, stride=2, padding=2, dilation=1, return_indices=False, ceil_mode=True) + x = F.max_pool1d(x, kernel_size=3, stride=1, padding=1, dilation=2, return_indices=False, ceil_mode=False) + x = F.max_pool1d(x, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + x, indices1 = F.max_pool1d(x, kernel_size=2, padding=1, dilation=1, return_indices=True, ceil_mode=False) + x, indices2 = F.max_pool1d(x, kernel_size=5, stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=True) + return x, indices1, indices2 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 128) + + a0, a1, a2 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_max_pool1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_max_pool1d.pt inputshape=[1,12,128]") + + # pnnx inference + import test_F_max_pool1d_pnnx + b0, b1, b2 = test_F_max_pool1d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_max_pool2d.py b/tools/pnnx/tests/test_F_max_pool2d.py new file mode 100644 index 000000000000..5b7f9722d88b --- /dev/null +++ b/tools/pnnx/tests/test_F_max_pool2d.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.max_pool2d(x, kernel_size=3) + x = F.max_pool2d(x, kernel_size=4, stride=2, padding=2, dilation=1) + x = F.max_pool2d(x, kernel_size=(1,3), stride=1, padding=(0,1), dilation=1, return_indices=False, ceil_mode=False) + x = F.max_pool2d(x, kernel_size=(4,5), stride=(1,2), padding=(1,2), dilation=1, return_indices=False, ceil_mode=True) + x = F.max_pool2d(x, kernel_size=(2,3), stride=1, padding=1, dilation=(1,2), return_indices=False, ceil_mode=False) + x = F.max_pool2d(x, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + x, indices1 = F.max_pool2d(x, kernel_size=2, padding=1, dilation=1, return_indices=True, ceil_mode=False) + x, indices2 = F.max_pool2d(x, kernel_size=(5,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) + return x, indices1, indices2 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 128, 127) + + a0, a1, a2 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_max_pool2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_max_pool2d.pt inputshape=[1,12,128,127]") + + # pnnx inference + import test_F_max_pool2d_pnnx + b0, b1, b2 = test_F_max_pool2d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_max_pool3d.py b/tools/pnnx/tests/test_F_max_pool3d.py new file mode 100644 index 000000000000..d82087f00c64 --- /dev/null +++ b/tools/pnnx/tests/test_F_max_pool3d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.max_pool3d(x, kernel_size=3) + x = F.max_pool3d(x, kernel_size=4, stride=2, padding=2, dilation=1) + x = F.max_pool3d(x, kernel_size=(1,2,3), stride=1, padding=(0,0,1), dilation=1, return_indices=False, ceil_mode=False) + x = F.max_pool3d(x, kernel_size=(3,4,5), stride=(1,2,2), padding=(1,2,2), dilation=1, return_indices=False, ceil_mode=True) + x = F.max_pool3d(x, kernel_size=(2,3,3), stride=1, padding=1, dilation=(1,2,2), return_indices=False, ceil_mode=False) + x = F.max_pool3d(x, kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + x, indices = F.max_pool3d(x, kernel_size=(5,4,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) + return x, indices + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 96, 128, 128) + + a0, a1 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_max_pool3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_max_pool3d.pt inputshape=[1,12,96,128,128]") + + # pnnx inference + import test_F_max_pool3d_pnnx + b0, b1 = test_F_max_pool3d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_mish.py b/tools/pnnx/tests/test_F_mish.py new file mode 100644 index 000000000000..a4bf52c5631e --- /dev/null +++ b/tools/pnnx/tests/test_F_mish.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def mish_forward_0(x): + return x * F.softplus(x).tanh() + +def mish_forward_1(x): + return x.mul(torch.tanh(F.softplus(x))) + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.mish(x) + y = F.mish(y) + z = mish_forward_0(z) + w = mish_forward_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_mish.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_mish.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_mish_pnnx + b = test_F_mish_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_normalize.py b/tools/pnnx/tests/test_F_normalize.py new file mode 100644 index 000000000000..8c308600a769 --- /dev/null +++ b/tools/pnnx/tests/test_F_normalize.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.normalize(x) + x = F.normalize(x, eps=1e-3) + + y = F.normalize(y, p=1, dim=1) + y = F.normalize(y, dim=2) + + z = F.normalize(z) + z = F.normalize(z, dim=2, eps=1e-4) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 24, 64) + y = torch.rand(1, 12, 24, 64) + z = torch.rand(1, 12, 16, 24, 64) + + a0, a1, a2 = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_normalize.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_normalize.pt inputshape=[1,24,64],[1,12,24,64],[1,12,16,24,64]") + + # pnnx inference + import test_F_normalize_pnnx + b0, b1, b2 = test_F_normalize_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_pad.py b/tools/pnnx/tests/test_F_pad.py new file mode 100644 index 000000000000..3b61dd6014aa --- /dev/null +++ b/tools/pnnx/tests/test_F_pad.py @@ -0,0 +1,74 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.pad(x, (3,4), mode='constant', value=1.3) + x = F.pad(x, (2,2)) + + y = F.pad(y, (5,6), mode='reflect') + y = F.pad(y, (2,1), mode='replicate') + y = F.pad(y, (3,4), mode='constant', value=1.3) + y = F.pad(y, (1,1)) + + z = F.pad(z, (3,4,3,4), mode='reflect') + z = F.pad(z, (2,1,2,0), mode='replicate') + z = F.pad(z, (1,0,2,0), mode='constant', value=1.3) + z = F.pad(z, (3,3,3,3)) + + #w = F.pad(w, (1,2,3,4,5,6), mode='reflect') + w = F.pad(w, (5,0,1,2,0,2), mode='replicate') + w = F.pad(w, (0,2,2,1,3,4), mode='constant', value=1.3) + w = F.pad(w, (2,2,2,2,2,2)) + + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_pad.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_pad.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_pad_pnnx + b0, b1, b2, b3 = test_F_pad_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_pixel_shuffle.py b/tools/pnnx/tests/test_F_pixel_shuffle.py new file mode 100644 index 000000000000..4e5d5c023565 --- /dev/null +++ b/tools/pnnx/tests/test_F_pixel_shuffle.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.pixel_shuffle(x, 2) + x = F.pixel_shuffle(x, 4) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 128, 6, 7) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_pixel_shuffle.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_pixel_shuffle.pt inputshape=[1,128,6,7]") + + # pnnx inference + import test_F_pixel_shuffle_pnnx + b = test_F_pixel_shuffle_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_pixel_unshuffle.py b/tools/pnnx/tests/test_F_pixel_unshuffle.py new file mode 100644 index 000000000000..1d64e0eb6719 --- /dev/null +++ b/tools/pnnx/tests/test_F_pixel_unshuffle.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.pixel_unshuffle(x, 4) + x = F.pixel_unshuffle(x, 2) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 128, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_pixel_unshuffle.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_pixel_unshuffle.pt inputshape=[1,3,128,128]") + + # pnnx inference + import test_F_pixel_unshuffle_pnnx + b = test_F_pixel_unshuffle_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_prelu.py b/tools/pnnx/tests/test_F_prelu.py new file mode 100644 index 000000000000..c6595f9259fb --- /dev/null +++ b/tools/pnnx/tests/test_F_prelu.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w, w0, w1, w2, w3): + x = F.prelu(x, w0) + y = F.prelu(y, w1) + z = F.prelu(z, w2) + w = F.prelu(w, w3) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + w0 = torch.rand(16) + w1 = torch.rand(2) + w2 = torch.rand(3) + w3 = torch.rand(1) + + a0, a1, a2, a3 = net(x, y, z, w, w0, w1, w2, w3) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w, w0, w1, w2, w3)) + mod.save("test_F_prelu.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_prelu.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11],[16],[2],[3],[1]") + + # pnnx inference + import test_F_prelu_pnnx + b0, b1, b2, b3 = test_F_prelu_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_relu.py b/tools/pnnx/tests/test_F_relu.py new file mode 100644 index 000000000000..0319948f7f40 --- /dev/null +++ b/tools/pnnx/tests/test_F_relu.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.relu(x) + y = F.relu(y) + z = F.relu(z) + w = F.relu(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_relu.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_relu.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_relu_pnnx + b0, b1, b2, b3 = test_F_relu_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_relu6.py b/tools/pnnx/tests/test_F_relu6.py new file mode 100644 index 000000000000..147d25002b70 --- /dev/null +++ b/tools/pnnx/tests/test_F_relu6.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.relu6(x) + y = F.relu6(y) + z = F.relu6(z) + w = F.relu6(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_relu6.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_relu6.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_relu6_pnnx + b0, b1, b2, b3 = test_F_relu6_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_rrelu.py b/tools/pnnx/tests/test_F_rrelu.py new file mode 100644 index 000000000000..3dee3fe5e668 --- /dev/null +++ b/tools/pnnx/tests/test_F_rrelu.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.rrelu(x) + y = F.rrelu(y, 0.01) + z = F.rrelu(z, 0.125, 0.3333) + w = F.rrelu(w, 0.4, 0.4) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_rrelu.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_rrelu.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_rrelu_pnnx + b0, b1, b2, b3 = test_F_rrelu_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_selu.py b/tools/pnnx/tests/test_F_selu.py new file mode 100644 index 000000000000..10e3bc4bc57e --- /dev/null +++ b/tools/pnnx/tests/test_F_selu.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.selu(x) + y = F.selu(y) + z = F.selu(z) + w = F.selu(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_selu.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_selu.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_selu_pnnx + b0, b1, b2, b3 = test_F_selu_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_sigmoid.py b/tools/pnnx/tests/test_F_sigmoid.py new file mode 100644 index 000000000000..282f09ec865e --- /dev/null +++ b/tools/pnnx/tests/test_F_sigmoid.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.sigmoid(x) + y = F.sigmoid(y) + z = F.sigmoid(z) + w = F.sigmoid(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_sigmoid.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_sigmoid.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_sigmoid_pnnx + b0, b1, b2, b3 = test_F_sigmoid_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_silu.py b/tools/pnnx/tests/test_F_silu.py new file mode 100644 index 000000000000..21a124a8e92b --- /dev/null +++ b/tools/pnnx/tests/test_F_silu.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def silu_forward_0(x): + return x * torch.sigmoid(x) + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.silu(x) + y = F.silu(y) + z = F.silu(z) + w = silu_forward_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_silu.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_silu.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_silu_pnnx + b = test_F_silu_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_softmax.py b/tools/pnnx/tests/test_F_softmax.py new file mode 100644 index 000000000000..c3110f316ef6 --- /dev/null +++ b/tools/pnnx/tests/test_F_softmax.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.softmax(x, 1) + y = F.softmax(y, 0) + z = F.softmax(z, 2) + w = F.softmax(w, 3) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_softmax.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_softmax.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_softmax_pnnx + b0, b1, b2, b3 = test_F_softmax_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_softmin.py b/tools/pnnx/tests/test_F_softmin.py new file mode 100644 index 000000000000..98e4a9e2a219 --- /dev/null +++ b/tools/pnnx/tests/test_F_softmin.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.softmin(x, 1) + y = F.softmin(y, 0) + z = F.softmin(z, 2) + w = F.softmin(w, 3) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_softmin.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_softmin.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_softmin_pnnx + b0, b1, b2, b3 = test_F_softmin_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_softplus.py b/tools/pnnx/tests/test_F_softplus.py new file mode 100644 index 000000000000..dd2986b7e71b --- /dev/null +++ b/tools/pnnx/tests/test_F_softplus.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.softplus(x) + y = F.softplus(y, 2, 1.2) + z = F.softplus(z, -0.7, 15) + w = F.softplus(w, 0.1, 0.3) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_softplus.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_softplus.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_softplus_pnnx + b0, b1, b2, b3 = test_F_softplus_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_softshrink.py b/tools/pnnx/tests/test_F_softshrink.py new file mode 100644 index 000000000000..3bf8443c6704 --- /dev/null +++ b/tools/pnnx/tests/test_F_softshrink.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.softshrink(x) + y = F.softshrink(y, 0.1) + z = F.softshrink(z, 0.22) + w = F.softshrink(w, 0) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_softshrink.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_softshrink.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_softshrink_pnnx + b0, b1, b2, b3 = test_F_softshrink_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_softsign.py b/tools/pnnx/tests/test_F_softsign.py new file mode 100644 index 000000000000..eb4b7c41d522 --- /dev/null +++ b/tools/pnnx/tests/test_F_softsign.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.softsign(x) + y = F.softsign(y) + z = F.softsign(z) + w = F.softsign(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_softsign.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_softsign.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_softsign_pnnx + b0, b1, b2, b3 = test_F_softsign_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_tanh.py b/tools/pnnx/tests/test_F_tanh.py new file mode 100644 index 000000000000..800d558caa6d --- /dev/null +++ b/tools/pnnx/tests/test_F_tanh.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.tanh(x) + y = F.tanh(y) + z = F.tanh(z) + w = F.tanh(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_tanh.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_tanh.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_tanh_pnnx + b0, b1, b2, b3 = test_F_tanh_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_tanhshrink.py b/tools/pnnx/tests/test_F_tanhshrink.py new file mode 100644 index 000000000000..da2602f9de5f --- /dev/null +++ b/tools/pnnx/tests/test_F_tanhshrink.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.tanhshrink(x) + y = F.tanhshrink(y) + z = F.tanhshrink(z) + w = F.tanhshrink(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_tanhshrink.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_tanhshrink.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_tanhshrink_pnnx + b0, b1, b2, b3 = test_F_tanhshrink_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_threshold.py b/tools/pnnx/tests/test_F_threshold.py new file mode 100644 index 000000000000..6ad2b8a89131 --- /dev/null +++ b/tools/pnnx/tests/test_F_threshold.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = F.threshold(x, 0.1, 20) + y = F.threshold(y, 0.3, 0.4) + z = F.threshold(z, 0.1, 20) + w = F.threshold(w, 0.3, 0.4) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_F_threshold.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_threshold.pt inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_threshold_pnnx + b0, b1, b2, b3 = test_F_threshold_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_upsample.py b/tools/pnnx/tests/test_F_upsample.py new file mode 100644 index 000000000000..cff972e4573b --- /dev/null +++ b/tools/pnnx/tests/test_F_upsample.py @@ -0,0 +1,96 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = F.upsample(x, size=16) + x = F.upsample(x, scale_factor=2, mode='nearest') + x = F.upsample(x, size=(20), mode='nearest') + x = F.upsample(x, scale_factor=(4), mode='nearest') + x = F.upsample(x, size=16, mode='linear') + x = F.upsample(x, scale_factor=2, mode='linear') + x = F.upsample(x, size=(24), mode='linear', align_corners=True) + x = F.upsample(x, scale_factor=(3), mode='linear', align_corners=True) + + y = F.upsample(y, size=16) + y = F.upsample(y, scale_factor=2, mode='nearest') + y = F.upsample(y, size=(20,20), mode='nearest') + y = F.upsample(y, scale_factor=(4,4), mode='nearest') + y = F.upsample(y, size=(16,24), mode='nearest') + y = F.upsample(y, scale_factor=(2,3), mode='nearest') + y = F.upsample(y, size=16, mode='bilinear') + y = F.upsample(y, scale_factor=2, mode='bilinear') + y = F.upsample(y, size=(20,20), mode='bilinear', align_corners=False) + y = F.upsample(y, scale_factor=(4,4), mode='bilinear', align_corners=False) + y = F.upsample(y, size=(16,24), mode='bilinear', align_corners=True) + y = F.upsample(y, scale_factor=(2,3), mode='bilinear', align_corners=True) + y = F.upsample(y, size=16, mode='bicubic') + y = F.upsample(y, scale_factor=2, mode='bicubic') + y = F.upsample(y, size=(20,20), mode='bicubic', align_corners=False) + y = F.upsample(y, scale_factor=(4,4), mode='bicubic', align_corners=False) + y = F.upsample(y, size=(16,24), mode='bicubic', align_corners=True) + y = F.upsample(y, scale_factor=(2,3), mode='bicubic', align_corners=True) + + z = F.upsample(z, size=16) + z = F.upsample(z, scale_factor=2, mode='nearest') + z = F.upsample(z, size=(20,20,20), mode='nearest') + z = F.upsample(z, scale_factor=(4,4,4), mode='nearest') + z = F.upsample(z, size=(16,24,20), mode='nearest') + z = F.upsample(z, scale_factor=(2,3,4), mode='nearest') + z = F.upsample(z, size=16, mode='trilinear') + z = F.upsample(z, scale_factor=2, mode='trilinear') + z = F.upsample(z, size=(20,20,20), mode='trilinear', align_corners=False) + z = F.upsample(z, scale_factor=(4,4,4), mode='trilinear', align_corners=False) + z = F.upsample(z, size=(16,24,20), mode='trilinear', align_corners=True) + z = F.upsample(z, scale_factor=(2,3,4), mode='trilinear', align_corners=True) + + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 32) + y = torch.rand(1, 3, 32, 32) + z = torch.rand(1, 3, 32, 32, 32) + + a0, a1, a2 = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_F_upsample.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_upsample.pt inputshape=[1,3,32],[1,3,32,32],[1,3,32,32,32]") + + # pnnx inference + import test_F_upsample_pnnx + b0, b1, b2 = test_F_upsample_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_upsample_bilinear.py b/tools/pnnx/tests/test_F_upsample_bilinear.py new file mode 100644 index 000000000000..de2e3d4a5bc0 --- /dev/null +++ b/tools/pnnx/tests/test_F_upsample_bilinear.py @@ -0,0 +1,55 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + x = F.upsample_bilinear(x, size=(12,12)) + x = F.upsample_bilinear(x, scale_factor=2) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_F_upsample_bilinear.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_upsample_bilinear.pt inputshape=[1,12,24,64]") + + # pnnx inference + import test_F_upsample_bilinear_pnnx + b = test_F_upsample_bilinear_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_upsample_nearest.py b/tools/pnnx/tests/test_F_upsample_nearest.py new file mode 100644 index 000000000000..77509c9d7063 --- /dev/null +++ b/tools/pnnx/tests/test_F_upsample_nearest.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = F.upsample_nearest(x, size=(12,12)) + x = F.upsample_nearest(x, scale_factor=2) + + y = F.upsample_nearest(y, size=(8,10,9)) + y = F.upsample_nearest(y, scale_factor=3) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 64) + y = torch.rand(1, 4, 10, 24, 32) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_F_upsample_nearest.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_upsample_nearest.pt inputshape=[1,12,24,64],[1,4,10,24,32]") + + # pnnx inference + import test_F_upsample_nearest_pnnx + b0, b1 = test_F_upsample_nearest_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_Tensor_contiguous.py b/tools/pnnx/tests/test_Tensor_contiguous.py new file mode 100644 index 000000000000..38b059717b7e --- /dev/null +++ b/tools/pnnx/tests/test_Tensor_contiguous.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = x.contiguous(memory_format=torch.contiguous_format) + y = y.contiguous(memory_format=torch.channels_last) + z = z.contiguous(memory_format=torch.preserve_format) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_Tensor_contiguous.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_Tensor_contiguous.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_Tensor_contiguous_pnnx + b = test_Tensor_contiguous_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_Tensor_new_empty.py b/tools/pnnx/tests/test_Tensor_new_empty.py new file mode 100644 index 000000000000..bc75a0059c94 --- /dev/null +++ b/tools/pnnx/tests/test_Tensor_new_empty.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + out0 = x.new_empty((2,2)) + out1 = x.new_empty(3) + out2 = x.new_empty((4,5,6,7,8)) + out3 = x.new_empty((1,2,1)) + out4 = x.new_empty((3,3,3,3)) + return out0, out1, out2, out3, out4 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_Tensor_new_empty.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_Tensor_new_empty.pt inputshape=[1,16]") + + # pnnx inference + import test_Tensor_new_empty_pnnx + b = test_Tensor_new_empty_pnnx.test_inference() + + # test shape only for uninitialized data + for a0, b0 in zip(a, b): + if not a0.shape == b0.shape: + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_Tensor_repeat.py b/tools/pnnx/tests/test_Tensor_repeat.py new file mode 100644 index 000000000000..d20de55b8f28 --- /dev/null +++ b/tools/pnnx/tests/test_Tensor_repeat.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = x.repeat(1, 2, 3) + x = x.repeat(2, 3, 4) + y = y.repeat(1, 2, 1, 4) + y = y.repeat(3, 4, 5, 1) + z = z.repeat(1, 2, 3, 1, 5) + z = z.repeat(2, 3, 3, 1, 1) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_Tensor_repeat.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_Tensor_repeat.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_Tensor_repeat_pnnx + b = test_Tensor_repeat_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_Tensor_reshape.py b/tools/pnnx/tests/test_Tensor_reshape.py new file mode 100644 index 000000000000..9a082f97c1e1 --- /dev/null +++ b/tools/pnnx/tests/test_Tensor_reshape.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = x.reshape(1, 2, 24) + x = x.reshape(48) + y = y.reshape(1, 11, 5, 9) + y = y.reshape(99, 5) + z = z.reshape(4, 3, 30, 10, 14) + z = z.reshape(15, 2, 10, 7, 8, 3) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_Tensor_reshape.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_Tensor_reshape.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_Tensor_reshape_pnnx + b = test_Tensor_reshape_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_Tensor_select.py b/tools/pnnx/tests/test_Tensor_select.py new file mode 100644 index 000000000000..399555bace5f --- /dev/null +++ b/tools/pnnx/tests/test_Tensor_select.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = x.select(1, 1) + y = y.select(2, 4) + z = z.select(0, 10) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_Tensor_select.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_Tensor_select.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_Tensor_select_pnnx + b = test_Tensor_select_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_Tensor_slice.py b/tools/pnnx/tests/test_Tensor_slice.py new file mode 100644 index 000000000000..273f7222ee65 --- /dev/null +++ b/tools/pnnx/tests/test_Tensor_slice.py @@ -0,0 +1,67 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = x[:,:12,1:14:2] + x = x[...,1:] + x = x[:,:,:x.size(2)-1] + y = y[0:,1:,5:,3:] + y = y[:,:,1:13:2,:14] + y = y[:1,:y.size(1):,:,:] + z = z[4:] + z = z[:2,:,:,:,2:-2:3] + z = z[:,:,z.size(3)-3:,:,:] + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 13, 26) + y = torch.rand(1, 15, 19, 21) + z = torch.rand(14, 18, 15, 19, 20) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_Tensor_slice.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_Tensor_slice.pt inputshape=[1,13,26],[1,15,19,21],[14,18,15,19,20]") + + # pnnx inference + import test_Tensor_slice_pnnx + b = test_Tensor_slice_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_Tensor_view.py b/tools/pnnx/tests/test_Tensor_view.py new file mode 100644 index 000000000000..f897910ce8b2 --- /dev/null +++ b/tools/pnnx/tests/test_Tensor_view.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = x.view(1, 2, 24) + x = x.view(48) + y = y.view(1, 11, 5, 9) + y = y.view(99, 5) + z = z.view(4, 3, 30, 10, 14) + z = z.view(15, 2, 10, 7, 8, 3) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_Tensor_view.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_Tensor_view.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_Tensor_view_pnnx + b = test_Tensor_view_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_mobilenet_v2.py b/tools/pnnx/tests/test_mobilenet_v2.py new file mode 100644 index 000000000000..5becd50ba3e9 --- /dev/null +++ b/tools/pnnx/tests/test_mobilenet_v2.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torchvision.models as models + +def test(): + net = models.mobilenet_v2() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 224, 224) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_mobilenet_v2.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_mobilenet_v2.pt inputshape=[1,3,224,224]") + + # pnnx inference + import test_mobilenet_v2_pnnx + b = test_mobilenet_v2_pnnx.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_mobilenet_v3_small.py b/tools/pnnx/tests/test_mobilenet_v3_small.py new file mode 100644 index 000000000000..f8d766b11c2c --- /dev/null +++ b/tools/pnnx/tests/test_mobilenet_v3_small.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torchvision.models as models + +def test(): + net = models.mobilenet_v3_small() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 224, 224) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_mobilenet_v3_small.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_mobilenet_v3_small.pt inputshape=[1,3,224,224]") + + # pnnx inference + import test_mobilenet_v3_small_pnnx + b = test_mobilenet_v3_small_pnnx.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_AdaptiveAvgPool1d.py b/tools/pnnx/tests/test_nn_AdaptiveAvgPool1d.py new file mode 100644 index 000000000000..0e491c397307 --- /dev/null +++ b/tools/pnnx/tests/test_nn_AdaptiveAvgPool1d.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.AdaptiveAvgPool1d(output_size=(7)) + self.pool_1 = nn.AdaptiveAvgPool1d(output_size=1) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 128, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_AdaptiveAvgPool1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_AdaptiveAvgPool1d.pt inputshape=[1,128,13]") + + # pnnx inference + import test_nn_AdaptiveAvgPool1d_pnnx + b = test_nn_AdaptiveAvgPool1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_AdaptiveAvgPool2d.py b/tools/pnnx/tests/test_nn_AdaptiveAvgPool2d.py new file mode 100644 index 000000000000..c3daf5579684 --- /dev/null +++ b/tools/pnnx/tests/test_nn_AdaptiveAvgPool2d.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.AdaptiveAvgPool2d(output_size=(7,6)) + self.pool_1 = nn.AdaptiveAvgPool2d(output_size=1) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 128, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_AdaptiveAvgPool2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_AdaptiveAvgPool2d.pt inputshape=[1,128,13,13]") + + # pnnx inference + import test_nn_AdaptiveAvgPool2d_pnnx + b = test_nn_AdaptiveAvgPool2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_AdaptiveAvgPool3d.py b/tools/pnnx/tests/test_nn_AdaptiveAvgPool3d.py new file mode 100644 index 000000000000..9101c5dbb250 --- /dev/null +++ b/tools/pnnx/tests/test_nn_AdaptiveAvgPool3d.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.AdaptiveAvgPool3d(output_size=(7,6,5)) + self.pool_1 = nn.AdaptiveAvgPool3d(output_size=1) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 128, 13, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_AdaptiveAvgPool3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_AdaptiveAvgPool3d.pt inputshape=[1,128,13,13,13]") + + # pnnx inference + import test_nn_AdaptiveAvgPool3d_pnnx + b = test_nn_AdaptiveAvgPool3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_AdaptiveMaxPool1d.py b/tools/pnnx/tests/test_nn_AdaptiveMaxPool1d.py new file mode 100644 index 000000000000..87a8e26ffc65 --- /dev/null +++ b/tools/pnnx/tests/test_nn_AdaptiveMaxPool1d.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.AdaptiveMaxPool1d(output_size=(7), return_indices=True) + self.pool_1 = nn.AdaptiveMaxPool1d(output_size=1) + + def forward(self, x): + x, indices = self.pool_0(x) + x = self.pool_1(x) + return x, indices + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 128, 13) + + a0, a1 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_AdaptiveMaxPool1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_AdaptiveMaxPool1d.pt inputshape=[1,128,13]") + + # pnnx inference + import test_nn_AdaptiveMaxPool1d_pnnx + b0, b1 = test_nn_AdaptiveMaxPool1d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_AdaptiveMaxPool2d.py b/tools/pnnx/tests/test_nn_AdaptiveMaxPool2d.py new file mode 100644 index 000000000000..1c5ed189107c --- /dev/null +++ b/tools/pnnx/tests/test_nn_AdaptiveMaxPool2d.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.AdaptiveMaxPool2d(output_size=(7,6), return_indices=True) + self.pool_1 = nn.AdaptiveMaxPool2d(output_size=1) + + def forward(self, x): + x, indices = self.pool_0(x) + x = self.pool_1(x) + return x, indices + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 128, 13, 13) + + a0, a1 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_AdaptiveMaxPool2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_AdaptiveMaxPool2d.pt inputshape=[1,128,13,13]") + + # pnnx inference + import test_nn_AdaptiveMaxPool2d_pnnx + b0, b1 = test_nn_AdaptiveMaxPool2d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_AdaptiveMaxPool3d.py b/tools/pnnx/tests/test_nn_AdaptiveMaxPool3d.py new file mode 100644 index 000000000000..83f5eff2f93d --- /dev/null +++ b/tools/pnnx/tests/test_nn_AdaptiveMaxPool3d.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + #self.pool_0 = nn.AdaptiveAvgPool3d(output_size=(7,6,5), return_indices=True) + self.pool_0 = nn.AdaptiveAvgPool3d(output_size=(7,6,5)) + self.pool_1 = nn.AdaptiveAvgPool3d(output_size=1) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 128, 13, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_AdaptiveAvgPool3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_AdaptiveAvgPool3d.pt inputshape=[1,128,13,13,13]") + + # pnnx inference + import test_nn_AdaptiveAvgPool3d_pnnx + b = test_nn_AdaptiveAvgPool3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_AvgPool1d.py b/tools/pnnx/tests/test_nn_AvgPool1d.py new file mode 100644 index 000000000000..cb784dc238ee --- /dev/null +++ b/tools/pnnx/tests/test_nn_AvgPool1d.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.AvgPool1d(kernel_size=3) + self.pool_1 = nn.AvgPool1d(kernel_size=4, stride=2, padding=2) + self.pool_2 = nn.AvgPool1d(kernel_size=3, stride=1, padding=(0), ceil_mode=False, count_include_pad=True) + self.pool_3 = nn.AvgPool1d(kernel_size=5, stride=2, padding=(2), ceil_mode=True, count_include_pad=False) + self.pool_4 = nn.AvgPool1d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=True) + self.pool_5 = nn.AvgPool1d(kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + self.pool_6 = nn.AvgPool1d(kernel_size=4, stride=1, padding=2, ceil_mode=False, count_include_pad=False) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + x = self.pool_2(x) + x = self.pool_3(x) + x = self.pool_4(x) + x = self.pool_5(x) + x = self.pool_6(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_AvgPool1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_AvgPool1d.pt inputshape=[1,12,128]") + + # pnnx inference + import test_nn_AvgPool1d_pnnx + b = test_nn_AvgPool1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_AvgPool2d.py b/tools/pnnx/tests/test_nn_AvgPool2d.py new file mode 100644 index 000000000000..458463a3d438 --- /dev/null +++ b/tools/pnnx/tests/test_nn_AvgPool2d.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.AvgPool2d(kernel_size=3) + self.pool_1 = nn.AvgPool2d(kernel_size=4, stride=2, padding=2) + self.pool_2 = nn.AvgPool2d(kernel_size=(1,3), stride=1, padding=(0,1), ceil_mode=False, count_include_pad=True) + self.pool_3 = nn.AvgPool2d(kernel_size=(4,5), stride=(1,2), padding=(1,2), ceil_mode=True, count_include_pad=False) + self.pool_4 = nn.AvgPool2d(kernel_size=(5,3), stride=(2,1), padding=1, ceil_mode=False, count_include_pad=True) + self.pool_5 = nn.AvgPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + self.pool_6 = nn.AvgPool2d(kernel_size=(5,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=18) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + x = self.pool_2(x) + x = self.pool_3(x) + x = self.pool_4(x) + x = self.pool_5(x) + x = self.pool_6(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 128, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_AvgPool2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_AvgPool2d.pt inputshape=[1,12,128,128]") + + # pnnx inference + import test_nn_AvgPool2d_pnnx + b = test_nn_AvgPool2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_AvgPool3d.py b/tools/pnnx/tests/test_nn_AvgPool3d.py new file mode 100644 index 000000000000..486f0e603742 --- /dev/null +++ b/tools/pnnx/tests/test_nn_AvgPool3d.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.AvgPool3d(kernel_size=3) + self.pool_1 = nn.AvgPool3d(kernel_size=4, stride=2, padding=2) + self.pool_2 = nn.AvgPool3d(kernel_size=(1,2,3), stride=1, padding=(0,1,1), ceil_mode=False, count_include_pad=True) + self.pool_3 = nn.AvgPool3d(kernel_size=(3,4,5), stride=(1,2,2), padding=(1,1,2), ceil_mode=True, count_include_pad=False) + self.pool_4 = nn.AvgPool3d(kernel_size=(5,4,3), stride=(2,1,1), padding=1, ceil_mode=False, count_include_pad=True) + self.pool_5 = nn.AvgPool3d(kernel_size=2, stride=1, padding=0, ceil_mode=True, count_include_pad=True) + self.pool_6 = nn.AvgPool3d(kernel_size=(5,4,4), stride=1, padding=2, ceil_mode=False, count_include_pad=False, divisor_override=77) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + x = self.pool_2(x) + x = self.pool_3(x) + x = self.pool_4(x) + x = self.pool_5(x) + x = self.pool_6(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 96, 128, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_AvgPool3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_AvgPool3d.pt inputshape=[1,12,96,128,128]") + + # pnnx inference + import test_nn_AvgPool3d_pnnx + b = test_nn_AvgPool3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_BatchNorm1d.py b/tools/pnnx/tests/test_nn_BatchNorm1d.py new file mode 100644 index 000000000000..409660d02a05 --- /dev/null +++ b/tools/pnnx/tests/test_nn_BatchNorm1d.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.bn_0 = nn.BatchNorm1d(num_features=32) + self.bn_1 = nn.BatchNorm1d(num_features=32, eps=1e-1, affine=False) + self.bn_2 = nn.BatchNorm1d(num_features=11, affine=True) + + def forward(self, x, y): + x = self.bn_0(x) + x = self.bn_1(x) + + y = self.bn_2(y) + + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 32, 64) + y = torch.rand(1, 11, 1) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_BatchNorm1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_BatchNorm1d.pt inputshape=[1,32,64],[1,11,1]") + + # pnnx inference + import test_nn_BatchNorm1d_pnnx + b0, b1 = test_nn_BatchNorm1d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_BatchNorm2d.py b/tools/pnnx/tests/test_nn_BatchNorm2d.py new file mode 100644 index 000000000000..f76fb707703f --- /dev/null +++ b/tools/pnnx/tests/test_nn_BatchNorm2d.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.bn_0 = nn.BatchNorm2d(num_features=32) + self.bn_1 = nn.BatchNorm2d(num_features=32, eps=1e-1, affine=False) + self.bn_2 = nn.BatchNorm2d(num_features=11, affine=True) + + def forward(self, x, y): + x = self.bn_0(x) + x = self.bn_1(x) + + y = self.bn_2(y) + + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 32, 12, 64) + y = torch.rand(1, 11, 1, 1) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_BatchNorm2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_BatchNorm2d.pt inputshape=[1,32,12,64],[1,11,1,1]") + + # pnnx inference + import test_nn_BatchNorm2d_pnnx + b0, b1 = test_nn_BatchNorm2d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_BatchNorm3d.py b/tools/pnnx/tests/test_nn_BatchNorm3d.py new file mode 100644 index 000000000000..2875a49b71da --- /dev/null +++ b/tools/pnnx/tests/test_nn_BatchNorm3d.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.bn_0 = nn.BatchNorm3d(num_features=32) + self.bn_1 = nn.BatchNorm3d(num_features=32, eps=1e-1, affine=False) + self.bn_2 = nn.BatchNorm3d(num_features=11, affine=True) + + def forward(self, x, y): + x = self.bn_0(x) + x = self.bn_1(x) + + y = self.bn_2(y) + + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 32, 12, 5, 64) + y = torch.rand(1, 11, 3, 1, 1) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_BatchNorm3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_BatchNorm3d.pt inputshape=[1,32,12,5,64],[1,11,3,1,1]") + + # pnnx inference + import test_nn_BatchNorm3d_pnnx + b0, b1 = test_nn_BatchNorm3d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_CELU.py b/tools/pnnx/tests/test_nn_CELU.py new file mode 100644 index 000000000000..5eefc655124b --- /dev/null +++ b/tools/pnnx/tests/test_nn_CELU.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.CELU() + self.act_1 = nn.CELU(alpha=2.0) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_CELU.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_CELU.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_CELU_pnnx + b0, b1, b2, b3 = test_nn_CELU_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ChannelShuffle.py b/tools/pnnx/tests/test_nn_ChannelShuffle.py new file mode 100644 index 000000000000..cc2936d942e7 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ChannelShuffle.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.shuffle_0 = nn.ChannelShuffle(2) + self.shuffle_1 = nn.ChannelShuffle(16) + + def forward(self, x, y): + x = self.shuffle_0(x) + x = self.shuffle_1(x) + + y = self.shuffle_0(y) + y = self.shuffle_1(y) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 64, 6, 8) + y = torch.rand(1, 96, 7, 9) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_ChannelShuffle.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ChannelShuffle.pt inputshape=[1,64,6,8],[1,96,7,9]") + + # pnnx inference + import test_nn_ChannelShuffle_pnnx + b0, b1 = test_nn_ChannelShuffle_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ConstantPad1d.py b/tools/pnnx/tests/test_nn_ConstantPad1d.py new file mode 100644 index 000000000000..5878f1d1d7dd --- /dev/null +++ b/tools/pnnx/tests/test_nn_ConstantPad1d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ConstantPad1d(2, 0.1) + self.pad_1 = nn.ConstantPad1d(padding=(3,4), value=-1) + self.pad_2 = nn.ConstantPad1d(padding=(1,0), value=123) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ConstantPad1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ConstantPad1d.pt inputshape=[1,12,13]") + + # pnnx inference + import test_nn_ConstantPad1d_pnnx + b = test_nn_ConstantPad1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ConstantPad2d.py b/tools/pnnx/tests/test_nn_ConstantPad2d.py new file mode 100644 index 000000000000..70ad1520e1e1 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ConstantPad2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ConstantPad2d(2, 0.1) + self.pad_1 = nn.ConstantPad2d(padding=(3,4,5,6), value=-2) + self.pad_2 = nn.ConstantPad2d(padding=(1,0,2,0), value=0) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ConstantPad2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ConstantPad2d.pt inputshape=[1,12,13,13]") + + # pnnx inference + import test_nn_ConstantPad2d_pnnx + b = test_nn_ConstantPad2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ConstantPad3d.py b/tools/pnnx/tests/test_nn_ConstantPad3d.py new file mode 100644 index 000000000000..a81cb0b5e891 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ConstantPad3d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ConstantPad3d(2, 0.3) + self.pad_1 = nn.ConstantPad3d(padding=(1,2,3,4,5,6), value=0) + self.pad_2 = nn.ConstantPad3d(padding=(1,0,2,0,0,3), value=1.1) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ConstantPad3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ConstantPad3d.pt inputshape=[1,12,13,13,13]") + + # pnnx inference + import test_nn_ConstantPad3d_pnnx + b = test_nn_ConstantPad3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Conv1d.py b/tools/pnnx/tests/test_nn_Conv1d.py new file mode 100644 index 000000000000..34cf417bf0b2 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Conv1d.py @@ -0,0 +1,75 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.conv_0 = nn.Conv1d(in_channels=12, out_channels=16, kernel_size=3) + self.conv_1 = nn.Conv1d(in_channels=16, out_channels=20, kernel_size=2, stride=2, padding=2, dilation=1) + self.conv_2 = nn.Conv1d(in_channels=20, out_channels=24, kernel_size=3, stride=1, padding=(4), dilation=1, groups=1, bias=False) + if torch.__version__ < '1.9': + self.conv_3 = nn.Conv1d(in_channels=24, out_channels=28, kernel_size=5, stride=1, padding=0, dilation=1, groups=4, bias=True) + self.conv_4 = nn.Conv1d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=2, groups=2, bias=False, padding_mode='zeros') + else: + self.conv_3 = nn.Conv1d(in_channels=24, out_channels=28, kernel_size=5, stride=1, padding='valid', dilation=1, groups=4, bias=True) + self.conv_4 = nn.Conv1d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=2, groups=2, bias=False, padding_mode='zeros') + self.conv_5 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect') + self.conv_6 = nn.Conv1d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate') + #self.conv_7 = nn.Conv1d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular') + + def forward(self, x): + x = self.conv_0(x) + x = self.conv_1(x) + x = self.conv_2(x) + x = self.conv_3(x) + x = self.conv_4(x) + x = self.conv_5(x) + x = self.conv_6(x) + #x = self.conv_7(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_Conv1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Conv1d.pt inputshape=[1,12,64]") + + # pnnx inference + import test_nn_Conv1d_pnnx + b = test_nn_Conv1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Conv2d.py b/tools/pnnx/tests/test_nn_Conv2d.py new file mode 100644 index 000000000000..6020dab00125 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Conv2d.py @@ -0,0 +1,75 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.conv_0 = nn.Conv2d(in_channels=12, out_channels=16, kernel_size=3) + self.conv_1 = nn.Conv2d(in_channels=16, out_channels=20, kernel_size=(2,4), stride=(2,1), padding=2, dilation=1) + self.conv_2 = nn.Conv2d(in_channels=20, out_channels=24, kernel_size=(1,3), stride=1, padding=(2,4), dilation=1, groups=1, bias=False) + if torch.__version__ < '1.9': + self.conv_3 = nn.Conv2d(in_channels=24, out_channels=28, kernel_size=(5,4), stride=1, padding=0, dilation=1, groups=4, bias=True) + self.conv_4 = nn.Conv2d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=(1,2), groups=2, bias=False, padding_mode='zeros') + else: + self.conv_3 = nn.Conv2d(in_channels=24, out_channels=28, kernel_size=(5,4), stride=1, padding='valid', dilation=1, groups=4, bias=True) + self.conv_4 = nn.Conv2d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=(1,2), groups=2, bias=False, padding_mode='zeros') + self.conv_5 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect') + self.conv_6 = nn.Conv2d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate') + #self.conv_7 = nn.Conv2d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular') + + def forward(self, x): + x = self.conv_0(x) + x = self.conv_1(x) + x = self.conv_2(x) + x = self.conv_3(x) + x = self.conv_4(x) + x = self.conv_5(x) + x = self.conv_6(x) + #x = self.conv_7(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 64, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_Conv2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Conv2d.pt inputshape=[1,12,64,64]") + + # pnnx inference + import test_nn_Conv2d_pnnx + b = test_nn_Conv2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Conv3d.py b/tools/pnnx/tests/test_nn_Conv3d.py new file mode 100644 index 000000000000..6e739a8b9e5e --- /dev/null +++ b/tools/pnnx/tests/test_nn_Conv3d.py @@ -0,0 +1,75 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.conv_0 = nn.Conv3d(in_channels=12, out_channels=16, kernel_size=3) + self.conv_1 = nn.Conv3d(in_channels=16, out_channels=20, kernel_size=(2,3,4), stride=(2,2,1), padding=2, dilation=1) + self.conv_2 = nn.Conv3d(in_channels=20, out_channels=24, kernel_size=(1,2,3), stride=1, padding=(2,4,1), dilation=1, groups=1, bias=False) + if torch.__version__ < '1.9': + self.conv_3 = nn.Conv3d(in_channels=24, out_channels=28, kernel_size=(5,4,3), stride=1, padding=0, dilation=1, groups=4, bias=True) + self.conv_4 = nn.Conv3d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=(1,2,2), groups=2, bias=False, padding_mode='zeros') + else: + self.conv_3 = nn.Conv3d(in_channels=24, out_channels=28, kernel_size=(5,4,3), stride=1, padding='valid', dilation=1, groups=4, bias=True) + self.conv_4 = nn.Conv3d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=(1,2,2), groups=2, bias=False, padding_mode='zeros') + #self.conv_5 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect') + #self.conv_6 = nn.Conv3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate') + #self.conv_7 = nn.Conv3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular') + + def forward(self, x): + x = self.conv_0(x) + x = self.conv_1(x) + x = self.conv_2(x) + x = self.conv_3(x) + x = self.conv_4(x) + #x = self.conv_5(x) + #x = self.conv_6(x) + #x = self.conv_7(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 48, 48, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_Conv3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Conv3d.pt inputshape=[1,12,48,48,64]") + + # pnnx inference + import test_nn_Conv3d_pnnx + b = test_nn_Conv3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ConvTranspose1d.py b/tools/pnnx/tests/test_nn_ConvTranspose1d.py new file mode 100644 index 000000000000..2219accbc9e6 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ConvTranspose1d.py @@ -0,0 +1,71 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.deconv_0 = nn.ConvTranspose1d(in_channels=12, out_channels=16, kernel_size=3) + self.deconv_1 = nn.ConvTranspose1d(in_channels=16, out_channels=20, kernel_size=2, stride=2, padding=2, output_padding=0) + self.deconv_2 = nn.ConvTranspose1d(in_channels=20, out_channels=24, kernel_size=3, stride=1, padding=(2), output_padding=(0), dilation=1, groups=1, bias=False) + self.deconv_3 = nn.ConvTranspose1d(in_channels=24, out_channels=28, kernel_size=5, stride=2, padding=0, output_padding=(1), dilation=1, groups=4, bias=True) + self.deconv_4 = nn.ConvTranspose1d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, output_padding=0, dilation=2, groups=2, bias=False) + self.deconv_5 = nn.ConvTranspose1d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, output_padding=1, dilation=1, groups=32, bias=True) + self.deconv_6 = nn.ConvTranspose1d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) + self.deconv_7 = nn.ConvTranspose1d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(6), output_padding=(1), dilation=2, groups=1, bias=True) + + def forward(self, x): + x = self.deconv_0(x) + x = self.deconv_1(x) + x = self.deconv_2(x) + x = self.deconv_3(x) + x = self.deconv_4(x) + x = self.deconv_5(x) + x = self.deconv_6(x) + x = self.deconv_7(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 10) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ConvTranspose1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ConvTranspose1d.pt inputshape=[1,12,10]") + + # pnnx inference + import test_nn_ConvTranspose1d_pnnx + b = test_nn_ConvTranspose1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ConvTranspose2d.py b/tools/pnnx/tests/test_nn_ConvTranspose2d.py new file mode 100644 index 000000000000..5b3acd808129 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ConvTranspose2d.py @@ -0,0 +1,71 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.deconv_0 = nn.ConvTranspose2d(in_channels=12, out_channels=16, kernel_size=3) + self.deconv_1 = nn.ConvTranspose2d(in_channels=16, out_channels=20, kernel_size=(2,4), stride=(2,1), padding=2, output_padding=0) + self.deconv_2 = nn.ConvTranspose2d(in_channels=20, out_channels=24, kernel_size=(1,3), stride=1, padding=(2,4), output_padding=(0,0), dilation=1, groups=1, bias=False) + self.deconv_3 = nn.ConvTranspose2d(in_channels=24, out_channels=28, kernel_size=(5,4), stride=2, padding=0, output_padding=(0,1), dilation=1, groups=4, bias=True) + self.deconv_4 = nn.ConvTranspose2d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, output_padding=0, dilation=(1,2), groups=2, bias=False) + self.deconv_5 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, output_padding=1, dilation=1, groups=32, bias=True) + self.deconv_6 = nn.ConvTranspose2d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) + self.deconv_7 = nn.ConvTranspose2d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), output_padding=(1,0), dilation=2, groups=1, bias=True) + + def forward(self, x): + x = self.deconv_0(x) + x = self.deconv_1(x) + x = self.deconv_2(x) + x = self.deconv_3(x) + x = self.deconv_4(x) + x = self.deconv_5(x) + x = self.deconv_6(x) + x = self.deconv_7(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 10, 10) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ConvTranspose2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ConvTranspose2d.pt inputshape=[1,12,10,10]") + + # pnnx inference + import test_nn_ConvTranspose2d_pnnx + b = test_nn_ConvTranspose2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ConvTranspose3d.py b/tools/pnnx/tests/test_nn_ConvTranspose3d.py new file mode 100644 index 000000000000..de727228e9cc --- /dev/null +++ b/tools/pnnx/tests/test_nn_ConvTranspose3d.py @@ -0,0 +1,71 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.deconv_0 = nn.ConvTranspose3d(in_channels=12, out_channels=16, kernel_size=3) + self.deconv_1 = nn.ConvTranspose3d(in_channels=16, out_channels=20, kernel_size=(2,3,4), stride=(2,2,1), padding=2, output_padding=0) + self.deconv_2 = nn.ConvTranspose3d(in_channels=20, out_channels=24, kernel_size=(1,2,3), stride=1, padding=(2,3,4), output_padding=(0,0,0), dilation=1, groups=1, bias=False) + self.deconv_3 = nn.ConvTranspose3d(in_channels=24, out_channels=28, kernel_size=(5,4,3), stride=2, padding=0, output_padding=(0,1,1), dilation=1, groups=4, bias=True) + self.deconv_4 = nn.ConvTranspose3d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, output_padding=0, dilation=(1,2,2), groups=2, bias=False) + self.deconv_5 = nn.ConvTranspose3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, output_padding=1, dilation=1, groups=32, bias=True) + self.deconv_6 = nn.ConvTranspose3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) + self.deconv_7 = nn.ConvTranspose3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6,7), output_padding=(1,0,1), dilation=2, groups=1, bias=True) + + def forward(self, x): + x = self.deconv_0(x) + x = self.deconv_1(x) + x = self.deconv_2(x) + x = self.deconv_3(x) + x = self.deconv_4(x) + x = self.deconv_5(x) + x = self.deconv_6(x) + x = self.deconv_7(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 7, 7, 10) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ConvTranspose3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ConvTranspose3d.pt inputshape=[1,12,7,7,10]") + + # pnnx inference + import test_nn_ConvTranspose3d_pnnx + b = test_nn_ConvTranspose3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ELU.py b/tools/pnnx/tests/test_nn_ELU.py new file mode 100644 index 000000000000..3d5d4a0ebd37 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ELU.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.ELU() + self.act_1 = nn.ELU(alpha=1.3) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_ELU.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ELU.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_ELU_pnnx + b0, b1, b2, b3 = test_nn_ELU_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_GELU.py b/tools/pnnx/tests/test_nn_GELU.py new file mode 100644 index 000000000000..c11f34e6a65c --- /dev/null +++ b/tools/pnnx/tests/test_nn_GELU.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.GELU() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_GELU.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_GELU.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_GELU_pnnx + b0, b1, b2, b3 = test_nn_GELU_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_GRU.py b/tools/pnnx/tests/test_nn_GRU.py new file mode 100644 index 000000000000..b22343a65bca --- /dev/null +++ b/tools/pnnx/tests/test_nn_GRU.py @@ -0,0 +1,76 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.gru_0_0 = nn.GRU(input_size=32, hidden_size=16) + self.gru_0_1 = nn.GRU(input_size=16, hidden_size=16, num_layers=3, bias=False) + self.gru_0_2 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + self.gru_0_3 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + + self.gru_1_0 = nn.GRU(input_size=25, hidden_size=16, batch_first=True) + self.gru_1_1 = nn.GRU(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True) + self.gru_1_2 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + self.gru_1_3 = nn.GRU(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + + def forward(self, x, y): + x0, h0 = self.gru_0_0(x) + x1, h1 = self.gru_0_1(x0) + x2, h2 = self.gru_0_2(x1) + x3, h3 = self.gru_0_3(x1, h2) + + y0, h4 = self.gru_1_0(y) + y1, h5 = self.gru_1_1(y0) + y2, h6 = self.gru_1_2(y1) + y3, h7 = self.gru_1_3(y1, h6) + return x2, x3, h0, h1, h2, h3, y2, y3, h4, h5, h6, h7 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(10, 1, 32) + y = torch.rand(1, 12, 25) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_GRU.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_GRU.pt inputshape=[10,1,32],[1,12,25]") + + # pnnx inference + import test_nn_GRU_pnnx + b = test_nn_GRU_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_GroupNorm.py b/tools/pnnx/tests/test_nn_GroupNorm.py new file mode 100644 index 000000000000..bbfb59dfab62 --- /dev/null +++ b/tools/pnnx/tests/test_nn_GroupNorm.py @@ -0,0 +1,70 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.gn_0 = nn.GroupNorm(num_groups=4, num_channels=12) + self.gn_1 = nn.GroupNorm(num_groups=12, num_channels=12, eps=1e-2, affine=True) + self.gn_2 = nn.GroupNorm(num_groups=1, num_channels=12, eps=1e-4, affine=True) + + def forward(self, x, y, z): + x = self.gn_0(x) + x = self.gn_1(x) + x = self.gn_2(x) + + y = self.gn_0(y) + y = self.gn_1(y) + y = self.gn_2(y) + + z = self.gn_0(z) + z = self.gn_1(z) + z = self.gn_2(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 64) + y = torch.rand(1, 12, 24, 64) + z = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2 = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_GroupNorm.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_GroupNorm.pt inputshape=[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_GroupNorm_pnnx + b0, b1, b2 = test_nn_GroupNorm_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Hardshrink.py b/tools/pnnx/tests/test_nn_Hardshrink.py new file mode 100644 index 000000000000..c6a42c1e9347 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Hardshrink.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Hardshrink() + self.act_1 = nn.Hardshrink(lambd=0.3) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Hardshrink.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Hardshrink.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Hardshrink_pnnx + b0, b1, b2, b3 = test_nn_Hardshrink_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Hardsigmoid.py b/tools/pnnx/tests/test_nn_Hardsigmoid.py new file mode 100644 index 000000000000..f6ab4ef327b4 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Hardsigmoid.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Hardsigmoid() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Hardsigmoid.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Hardsigmoid.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Hardsigmoid_pnnx + b0, b1, b2, b3 = test_nn_Hardsigmoid_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Hardswish.py b/tools/pnnx/tests/test_nn_Hardswish.py new file mode 100644 index 000000000000..73e901f7044b --- /dev/null +++ b/tools/pnnx/tests/test_nn_Hardswish.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Hardswish() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Hardswish.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Hardswish.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Hardswish_pnnx + b0, b1, b2, b3 = test_nn_Hardswish_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Hardtanh.py b/tools/pnnx/tests/test_nn_Hardtanh.py new file mode 100644 index 000000000000..69886857a0c7 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Hardtanh.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Hardtanh() + self.act_1 = nn.Hardtanh(-0.2, 0.2) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Hardtanh.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Hardtanh.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Hardtanh_pnnx + b0, b1, b2, b3 = test_nn_Hardtanh_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_InstanceNorm1d.py b/tools/pnnx/tests/test_nn_InstanceNorm1d.py new file mode 100644 index 000000000000..1889ba6da3c8 --- /dev/null +++ b/tools/pnnx/tests/test_nn_InstanceNorm1d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.in_0 = nn.InstanceNorm1d(num_features=12, affine=True) + self.in_1 = nn.InstanceNorm1d(num_features=12, eps=1e-2, affine=True) + self.in_2 = nn.InstanceNorm1d(num_features=12, eps=1e-4, affine=True, track_running_stats=True) + + def forward(self, x): + x = self.in_0(x) + x = self.in_1(x) + x = self.in_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_InstanceNorm1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_InstanceNorm1d.pt inputshape=[1,12,24]") + + # pnnx inference + import test_nn_InstanceNorm1d_pnnx + b = test_nn_InstanceNorm1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_InstanceNorm2d.py b/tools/pnnx/tests/test_nn_InstanceNorm2d.py new file mode 100644 index 000000000000..079670312215 --- /dev/null +++ b/tools/pnnx/tests/test_nn_InstanceNorm2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.in_0 = nn.InstanceNorm2d(num_features=12, affine=True) + self.in_1 = nn.InstanceNorm2d(num_features=12, eps=1e-2, affine=True) + self.in_2 = nn.InstanceNorm2d(num_features=12, eps=1e-4, affine=True, track_running_stats=True) + + def forward(self, x): + x = self.in_0(x) + x = self.in_1(x) + x = self.in_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_InstanceNorm2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_InstanceNorm2d.pt inputshape=[1,12,24,64]") + + # pnnx inference + import test_nn_InstanceNorm2d_pnnx + b = test_nn_InstanceNorm2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_InstanceNorm3d.py b/tools/pnnx/tests/test_nn_InstanceNorm3d.py new file mode 100644 index 000000000000..c71d1fe5e59f --- /dev/null +++ b/tools/pnnx/tests/test_nn_InstanceNorm3d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.in_0 = nn.InstanceNorm3d(num_features=12, affine=True) + self.in_1 = nn.InstanceNorm3d(num_features=12, eps=1e-2, affine=True) + self.in_2 = nn.InstanceNorm3d(num_features=12, eps=1e-4, affine=True, track_running_stats=True) + + def forward(self, x): + x = self.in_0(x) + x = self.in_1(x) + x = self.in_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24, 32, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_InstanceNorm3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_InstanceNorm3d.pt inputshape=[1,12,24,32,64]") + + # pnnx inference + import test_nn_InstanceNorm3d_pnnx + b = test_nn_InstanceNorm3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_LPPool1d.py b/tools/pnnx/tests/test_nn_LPPool1d.py new file mode 100644 index 000000000000..39ed9210e294 --- /dev/null +++ b/tools/pnnx/tests/test_nn_LPPool1d.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.LPPool1d(norm_type=2, kernel_size=3) + self.pool_1 = nn.LPPool1d(norm_type=2, kernel_size=4, stride=2) + self.pool_2 = nn.LPPool1d(norm_type=1, kernel_size=3, stride=1, ceil_mode=False) + self.pool_3 = nn.LPPool1d(norm_type=1, kernel_size=5, stride=1, ceil_mode=True) + self.pool_4 = nn.LPPool1d(norm_type=1.2, kernel_size=3, stride=2, ceil_mode=False) + self.pool_5 = nn.LPPool1d(norm_type=0.5, kernel_size=2, stride=1, ceil_mode=True) + self.pool_6 = nn.LPPool1d(norm_type=0.1, kernel_size=4, stride=1, ceil_mode=False) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + x = self.pool_2(x) + x = self.pool_3(x) + x = self.pool_4(x) + x = self.pool_5(x) + x = self.pool_6(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_LPPool1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_LPPool1d.pt inputshape=[1,12,128]") + + # pnnx inference + import test_nn_LPPool1d_pnnx + b = test_nn_LPPool1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_LPPool2d.py b/tools/pnnx/tests/test_nn_LPPool2d.py new file mode 100644 index 000000000000..c7e86b149d14 --- /dev/null +++ b/tools/pnnx/tests/test_nn_LPPool2d.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.LPPool2d(norm_type=2, kernel_size=3) + self.pool_1 = nn.LPPool2d(norm_type=2, kernel_size=4, stride=2) + self.pool_2 = nn.LPPool2d(norm_type=1, kernel_size=(1,3), stride=1, ceil_mode=False) + self.pool_3 = nn.LPPool2d(norm_type=1, kernel_size=(4,5), stride=(1,2), ceil_mode=True) + self.pool_4 = nn.LPPool2d(norm_type=1.2, kernel_size=(5,3), stride=(2,1), ceil_mode=False) + self.pool_5 = nn.LPPool2d(norm_type=0.5, kernel_size=2, stride=1, ceil_mode=True) + self.pool_6 = nn.LPPool2d(norm_type=0.1, kernel_size=(5,4), stride=1, ceil_mode=False) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + x = self.pool_2(x) + x = self.pool_3(x) + x = self.pool_4(x) + x = self.pool_5(x) + x = self.pool_6(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 128, 128) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_LPPool2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_LPPool2d.pt inputshape=[1,12,128,128]") + + # pnnx inference + import test_nn_LPPool2d_pnnx + b = test_nn_LPPool2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_LSTM.py b/tools/pnnx/tests/test_nn_LSTM.py new file mode 100644 index 000000000000..36274c8fc387 --- /dev/null +++ b/tools/pnnx/tests/test_nn_LSTM.py @@ -0,0 +1,76 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.lstm_0_0 = nn.LSTM(input_size=32, hidden_size=16) + self.lstm_0_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False) + self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + self.lstm_0_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + + self.lstm_1_0 = nn.LSTM(input_size=25, hidden_size=16, batch_first=True) + self.lstm_1_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True) + self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + self.lstm_1_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + + def forward(self, x, y): + x0, (h0, c0) = self.lstm_0_0(x) + x1, (h1, c1) = self.lstm_0_1(x0) + x2, (h2, c2) = self.lstm_0_2(x1) + x3, (h3, c3) = self.lstm_0_3(x1, (h2, c2)) + + y0, (h4, c4) = self.lstm_1_0(y) + y1, (h5, c5) = self.lstm_1_1(y0) + y2, (h6, c6) = self.lstm_1_2(y1) + y3, (h7, c7) = self.lstm_1_3(y1, (h6, c6)) + return x2, x3, h0, h1, h2, h3, c0, c1, c2, c3, y2, y3, h4, h5, h6, h7, c4, c5, c6, c7 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(10, 1, 32) + y = torch.rand(1, 12, 25) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_LSTM.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_LSTM.pt inputshape=[10,1,32],[1,12,25]") + + # pnnx inference + import test_nn_LSTM_pnnx + b = test_nn_LSTM_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_LayerNorm.py b/tools/pnnx/tests/test_nn_LayerNorm.py new file mode 100644 index 000000000000..d5c2ce6f13cc --- /dev/null +++ b/tools/pnnx/tests/test_nn_LayerNorm.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.ln_0 = nn.LayerNorm(64) + self.ln_1 = nn.LayerNorm(normalized_shape=(24,64), eps=1e-2, elementwise_affine=False) + + def forward(self, x, y, z): + x = self.ln_0(x) + x = self.ln_1(x) + + y = self.ln_0(y) + y = self.ln_1(y) + + z = self.ln_0(z) + z = self.ln_1(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 24, 64) + y = torch.rand(1, 12, 24, 64) + z = torch.rand(1, 12, 16, 24, 64) + + a0, a1, a2 = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_LayerNorm.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_LayerNorm.pt inputshape=[1,24,64],[1,12,24,64],[1,12,16,24,64]") + + # pnnx inference + import test_nn_LayerNorm_pnnx + b0, b1, b2 = test_nn_LayerNorm_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_LeakyReLU.py b/tools/pnnx/tests/test_nn_LeakyReLU.py new file mode 100644 index 000000000000..77621712c215 --- /dev/null +++ b/tools/pnnx/tests/test_nn_LeakyReLU.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.LeakyReLU() + self.act_1 = nn.LeakyReLU(negative_slope=-0.24) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_LeakyReLU.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_LeakyReLU.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_LeakyReLU_pnnx + b0, b1, b2, b3 = test_nn_LeakyReLU_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Linear.py b/tools/pnnx/tests/test_nn_Linear.py new file mode 100644 index 000000000000..4e445d5a13aa --- /dev/null +++ b/tools/pnnx/tests/test_nn_Linear.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.linear_0 = nn.Linear(in_features=64, out_features=16, bias=False) + self.linear_1 = nn.Linear(in_features=16, out_features=3, bias=True) + + def forward(self, x, y, z): + x = self.linear_0(x) + x = self.linear_1(x) + + y = self.linear_0(y) + y = self.linear_1(y) + + z = self.linear_0(z) + z = self.linear_1(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 64) + y = torch.rand(12, 64) + z = torch.rand(1, 3, 12, 64) + + a0, a1, a2 = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_Linear.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Linear.pt inputshape=[1,64],[12,64],[1,3,12,64]") + + # pnnx inference + import test_nn_Linear_pnnx + b0, b1, b2 = test_nn_Linear_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_LocalResponseNorm.py b/tools/pnnx/tests/test_nn_LocalResponseNorm.py new file mode 100644 index 000000000000..046fc5e4caed --- /dev/null +++ b/tools/pnnx/tests/test_nn_LocalResponseNorm.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.ln_0 = nn.LocalResponseNorm(3) + self.ln_1 = nn.LocalResponseNorm(size=5, alpha=0.001, beta=0.8, k=0.9) + + def forward(self, x, y, z): + x = self.ln_0(x) + x = self.ln_1(x) + + y = self.ln_0(y) + y = self.ln_1(y) + + z = self.ln_0(z) + z = self.ln_1(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 24, 64) + y = torch.rand(1, 12, 24, 64) + z = torch.rand(1, 12, 16, 24, 64) + + a0, a1, a2 = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_LocalResponseNorm.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_LocalResponseNorm.pt inputshape=[1,24,64],[1,12,24,64],[1,12,16,24,64]") + + # pnnx inference + import test_nn_LocalResponseNorm_pnnx + b0, b1, b2 = test_nn_LocalResponseNorm_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_LogSigmoid.py b/tools/pnnx/tests/test_nn_LogSigmoid.py new file mode 100644 index 000000000000..d01b30c86123 --- /dev/null +++ b/tools/pnnx/tests/test_nn_LogSigmoid.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.LogSigmoid() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_LogSigmoid.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_LogSigmoid.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_LogSigmoid_pnnx + b0, b1, b2, b3 = test_nn_LogSigmoid_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_LogSoftmax.py b/tools/pnnx/tests/test_nn_LogSoftmax.py new file mode 100644 index 000000000000..bff38ed1603b --- /dev/null +++ b/tools/pnnx/tests/test_nn_LogSoftmax.py @@ -0,0 +1,65 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.LogSoftmax(dim=1) + self.act_1 = nn.LogSoftmax(dim=1) + self.act_2 = nn.LogSoftmax(dim=0) + self.act_3 = nn.LogSoftmax(dim=2) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_1(y) + z = self.act_2(z) + w = self.act_3(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_LogSoftmax.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_LogSoftmax.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_LogSoftmax_pnnx + b0, b1, b2, b3 = test_nn_LogSoftmax_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_MaxPool1d.py b/tools/pnnx/tests/test_nn_MaxPool1d.py new file mode 100644 index 000000000000..8dbcd3d858c6 --- /dev/null +++ b/tools/pnnx/tests/test_nn_MaxPool1d.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.MaxPool1d(kernel_size=3) + self.pool_1 = nn.MaxPool1d(kernel_size=4, stride=2, padding=2, dilation=1) + self.pool_2 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) + self.pool_3 = nn.MaxPool1d(kernel_size=5, stride=2, padding=2, dilation=1, return_indices=False, ceil_mode=True) + self.pool_4 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=2, return_indices=False, ceil_mode=False) + self.pool_5 = nn.MaxPool1d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + self.pool_6 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + x = self.pool_2(x) + x = self.pool_3(x) + x = self.pool_4(x) + x = self.pool_5(x) + x, indices = self.pool_6(x) + return x, indices + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 64) + + a0, a1 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_MaxPool1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_MaxPool1d.pt inputshape=[1,12,64]") + + # pnnx inference + import test_nn_MaxPool1d_pnnx + b0, b1 = test_nn_MaxPool1d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_MaxPool2d.py b/tools/pnnx/tests/test_nn_MaxPool2d.py new file mode 100644 index 000000000000..497f05a6aeb4 --- /dev/null +++ b/tools/pnnx/tests/test_nn_MaxPool2d.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.MaxPool2d(kernel_size=3) + self.pool_1 = nn.MaxPool2d(kernel_size=4, stride=2, padding=2, dilation=1) + self.pool_2 = nn.MaxPool2d(kernel_size=(1,3), stride=1, padding=(0,1), dilation=1, return_indices=False, ceil_mode=False) + self.pool_3 = nn.MaxPool2d(kernel_size=(4,5), stride=(1,2), padding=(1,2), dilation=1, return_indices=False, ceil_mode=True) + self.pool_4 = nn.MaxPool2d(kernel_size=(2,3), stride=1, padding=1, dilation=(1,2), return_indices=False, ceil_mode=False) + self.pool_5 = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + self.pool_6 = nn.MaxPool2d(kernel_size=(5,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + x = self.pool_2(x) + x = self.pool_3(x) + x = self.pool_4(x) + x = self.pool_5(x) + x, indices = self.pool_6(x) + return x, indices + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 64, 64) + + a0, a1 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_MaxPool2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_MaxPool2d.pt inputshape=[1,12,64,64]") + + # pnnx inference + import test_nn_MaxPool2d_pnnx + b0, b1 = test_nn_MaxPool2d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_MaxPool3d.py b/tools/pnnx/tests/test_nn_MaxPool3d.py new file mode 100644 index 000000000000..22918594e47a --- /dev/null +++ b/tools/pnnx/tests/test_nn_MaxPool3d.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pool_0 = nn.MaxPool3d(kernel_size=3) + self.pool_1 = nn.MaxPool3d(kernel_size=4, stride=2, padding=2, dilation=1) + self.pool_2 = nn.MaxPool3d(kernel_size=(1,2,3), stride=1, padding=(0,0,1), dilation=1, return_indices=False, ceil_mode=False) + self.pool_3 = nn.MaxPool3d(kernel_size=(3,4,5), stride=(1,2,2), padding=(1,2,2), dilation=1, return_indices=False, ceil_mode=True) + self.pool_4 = nn.MaxPool3d(kernel_size=(2,3,3), stride=1, padding=1, dilation=(1,2,2), return_indices=False, ceil_mode=False) + self.pool_5 = nn.MaxPool3d(kernel_size=2, stride=1, padding=0, dilation=1, return_indices=False, ceil_mode=True) + self.pool_6 = nn.MaxPool3d(kernel_size=(5,4,4), stride=1, padding=2, dilation=1, return_indices=True, ceil_mode=False) + + def forward(self, x): + x = self.pool_0(x) + x = self.pool_1(x) + x = self.pool_2(x) + x = self.pool_3(x) + x = self.pool_4(x) + x = self.pool_5(x) + x, indices = self.pool_6(x) + return x, indices + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 64, 64, 64) + + a0, a1 = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_MaxPool3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_MaxPool3d.pt inputshape=[1,12,64,64,64]") + + # pnnx inference + import test_nn_MaxPool3d_pnnx + b0, b1 = test_nn_MaxPool3d_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Mish.py b/tools/pnnx/tests/test_nn_Mish.py new file mode 100644 index 000000000000..e06ea91d94ad --- /dev/null +++ b/tools/pnnx/tests/test_nn_Mish.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Mish() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Mish.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Mish.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Mish_pnnx + b0, b1, b2, b3 = test_nn_Mish_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_MultiheadAttention.py b/tools/pnnx/tests/test_nn_MultiheadAttention.py new file mode 100644 index 000000000000..afb553ddb9f3 --- /dev/null +++ b/tools/pnnx/tests/test_nn_MultiheadAttention.py @@ -0,0 +1,81 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.attention_0_0 = nn.MultiheadAttention(embed_dim=64, num_heads=4) + self.attention_0_1 = nn.MultiheadAttention(embed_dim=64, num_heads=8, bias=False, add_bias_kv=False, add_zero_attn=False) + self.attention_0_2 = nn.MultiheadAttention(embed_dim=64, num_heads=16, bias=True, add_bias_kv=True, add_zero_attn=True) + + if torch.__version__ >= '1.9': + self.attention_1_0 = nn.MultiheadAttention(embed_dim=40, num_heads=4, batch_first=True) + self.attention_1_1 = nn.MultiheadAttention(embed_dim=40, num_heads=8, bias=False, add_bias_kv=False, add_zero_attn=False, batch_first=True) + self.attention_1_2 = nn.MultiheadAttention(embed_dim=40, num_heads=10, bias=True, add_bias_kv=True, add_zero_attn=True, batch_first=True) + + def forward(self, xq, xk, xv, yq, yk, yv): + x0, x0w = self.attention_0_0(xq, xk, xv) + x1, x1w = self.attention_0_1(xq, xk, xv) + x2, x2w = self.attention_0_2(xq, xk, xv) + + if torch.__version__ < '1.9': + return x0, x0w, x1, x1w, x2, x2w + + y0, y0w = self.attention_1_0(yq, yk, yv) + y1, y1w = self.attention_1_1(yq, yk, yv) + y2, y2w = self.attention_1_2(yq, yk, yv) + + return x0, x0w, x1, x1w, x2, x2w, y0, y0w, y1, y1w, y2, y2w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + xq = torch.rand(20, 1, 64) + xk = torch.rand(20, 1, 64) + xv = torch.rand(20, 1, 64) + yq = torch.rand(1, 15, 40) + yk = torch.rand(1, 24, 40) + yv = torch.rand(1, 24, 40) + + a = net(xq, xk, xv, yq, yk, yv) + + # export torchscript + mod = torch.jit.trace(net, (xq, xk, xv, yq, yk, yv)) + mod.save("test_nn_MultiheadAttention.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_MultiheadAttention.pt inputshape=[20,1,64],[20,1,64],[20,1,64],[1,15,40],[1,24,40],[1,24,40]") + + # pnnx inference + import test_nn_MultiheadAttention_pnnx + b = test_nn_MultiheadAttention_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_PReLU.py b/tools/pnnx/tests/test_nn_PReLU.py new file mode 100644 index 000000000000..847521194322 --- /dev/null +++ b/tools/pnnx/tests/test_nn_PReLU.py @@ -0,0 +1,70 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.prelu_0 = nn.PReLU(num_parameters=12) + self.prelu_1 = nn.PReLU(num_parameters=1, init=0.12) + + def forward(self, x, y, z, w): + x = self.prelu_0(x) + x = self.prelu_1(x) + + y = self.prelu_0(y) + y = self.prelu_1(y) + + z = self.prelu_0(z) + z = self.prelu_1(z) + + w = self.prelu_0(w) + w = self.prelu_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_PReLU.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_PReLU.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_PReLU_pnnx + b0, b1, b2, b3 = test_nn_PReLU_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_PixelShuffle.py b/tools/pnnx/tests/test_nn_PixelShuffle.py new file mode 100644 index 000000000000..e53ba5234d01 --- /dev/null +++ b/tools/pnnx/tests/test_nn_PixelShuffle.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.up_0 = nn.PixelShuffle(4) + self.up_1 = nn.PixelShuffle(2) + + def forward(self, x, y): + x = self.up_0(x) + x = self.up_1(x) + + y = self.up_0(y) + y = self.up_1(y) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 128, 6, 8) + y = torch.rand(1, 12, 192, 7, 9) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_PixelShuffle.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_PixelShuffle.pt inputshape=[1,128,6,8],[1,12,192,7,9]") + + # pnnx inference + import test_nn_PixelShuffle_pnnx + b0, b1 = test_nn_PixelShuffle_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_PixelUnshuffle.py b/tools/pnnx/tests/test_nn_PixelUnshuffle.py new file mode 100644 index 000000000000..b21a98a87db4 --- /dev/null +++ b/tools/pnnx/tests/test_nn_PixelUnshuffle.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.down_0 = nn.PixelUnshuffle(2) + self.down_1 = nn.PixelUnshuffle(4) + + def forward(self, x, y): + x = self.down_0(x) + x = self.down_1(x) + + y = self.down_0(y) + y = self.down_1(y) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 128, 128) + y = torch.rand(1, 12, 4, 192, 192) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_PixelUnshuffle.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_PixelUnshuffle.pt inputshape=[1,3,128,128],[1,12,4,192,192]") + + # pnnx inference + import test_nn_PixelUnshuffle_pnnx + b0, b1 = test_nn_PixelUnshuffle_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_RNN.py b/tools/pnnx/tests/test_nn_RNN.py new file mode 100644 index 000000000000..d0188efa7f59 --- /dev/null +++ b/tools/pnnx/tests/test_nn_RNN.py @@ -0,0 +1,76 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.rnn_0_0 = nn.RNN(input_size=32, hidden_size=16) + self.rnn_0_1 = nn.RNN(input_size=16, hidden_size=16, num_layers=3, nonlinearity='tanh', bias=False) + self.rnn_0_2 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='relu', bias=True, bidirectional=True) + self.rnn_0_3 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='tanh', bias=True, bidirectional=True) + + self.rnn_1_0 = nn.RNN(input_size=25, hidden_size=16, batch_first=True) + self.rnn_1_1 = nn.RNN(input_size=16, hidden_size=16, num_layers=3, nonlinearity='tanh', bias=False, batch_first=True) + self.rnn_1_2 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='relu', bias=True, batch_first=True, bidirectional=True) + self.rnn_1_3 = nn.RNN(input_size=16, hidden_size=16, num_layers=4, nonlinearity='tanh', bias=True, batch_first=True, bidirectional=True) + + def forward(self, x, y): + x0, h0 = self.rnn_0_0(x) + x1, h1 = self.rnn_0_1(x0) + x2, h2 = self.rnn_0_2(x1) + x3, h3 = self.rnn_0_3(x1, h2) + + y0, h4 = self.rnn_1_0(y) + y1, h5 = self.rnn_1_1(y0) + y2, h6 = self.rnn_1_2(y1) + y3, h7 = self.rnn_1_3(y1, h6) + return x2, x3, h0, h1, h2, h3, y2, y3, h4, h5, h6, h7 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(10, 1, 32) + y = torch.rand(1, 12, 25) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_RNN.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_RNN.pt inputshape=[10,1,32],[1,12,25]") + + # pnnx inference + import test_nn_RNN_pnnx + b = test_nn_RNN_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_RReLU.py b/tools/pnnx/tests/test_nn_RReLU.py new file mode 100644 index 000000000000..f8929054348a --- /dev/null +++ b/tools/pnnx/tests/test_nn_RReLU.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.RReLU() + self.act_1 = nn.RReLU(lower=0.1, upper=0.42) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_RReLU.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_RReLU.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_RReLU_pnnx + b0, b1, b2, b3 = test_nn_RReLU_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ReLU.py b/tools/pnnx/tests/test_nn_ReLU.py new file mode 100644 index 000000000000..5dddab517f4d --- /dev/null +++ b/tools/pnnx/tests/test_nn_ReLU.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.ReLU() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_ReLU.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ReLU.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_ReLU_pnnx + b0, b1, b2, b3 = test_nn_ReLU_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ReLU6.py b/tools/pnnx/tests/test_nn_ReLU6.py new file mode 100644 index 000000000000..92e08c1a27ed --- /dev/null +++ b/tools/pnnx/tests/test_nn_ReLU6.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.ReLU6() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_ReLU6.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ReLU6.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_ReLU6_pnnx + b0, b1, b2, b3 = test_nn_ReLU6_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ReflectionPad1d.py b/tools/pnnx/tests/test_nn_ReflectionPad1d.py new file mode 100644 index 000000000000..b2c8c823c760 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ReflectionPad1d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ReflectionPad1d(2) + self.pad_1 = nn.ReflectionPad1d(padding=(3,4)) + self.pad_2 = nn.ReflectionPad1d(padding=(1,0)) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ReflectionPad1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ReflectionPad1d.pt inputshape=[1,12,13]") + + # pnnx inference + import test_nn_ReflectionPad1d_pnnx + b = test_nn_ReflectionPad1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ReflectionPad2d.py b/tools/pnnx/tests/test_nn_ReflectionPad2d.py new file mode 100644 index 000000000000..59984cf23726 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ReflectionPad2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ReflectionPad2d(2) + self.pad_1 = nn.ReflectionPad2d(padding=(3,4,5,6)) + self.pad_2 = nn.ReflectionPad2d(padding=(1,0,2,0)) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ReflectionPad2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ReflectionPad2d.pt inputshape=[1,12,13,13]") + + # pnnx inference + import test_nn_ReflectionPad2d_pnnx + b = test_nn_ReflectionPad2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ReplicationPad1d.py b/tools/pnnx/tests/test_nn_ReplicationPad1d.py new file mode 100644 index 000000000000..bd084aae7377 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ReplicationPad1d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ReplicationPad1d(2) + self.pad_1 = nn.ReplicationPad1d(padding=(3,4)) + self.pad_2 = nn.ReplicationPad1d(padding=(1,0)) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ReplicationPad1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ReplicationPad1d.pt inputshape=[1,12,13]") + + # pnnx inference + import test_nn_ReplicationPad1d_pnnx + b = test_nn_ReplicationPad1d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ReplicationPad2d.py b/tools/pnnx/tests/test_nn_ReplicationPad2d.py new file mode 100644 index 000000000000..8a806f3580e8 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ReplicationPad2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ReplicationPad2d(2) + self.pad_1 = nn.ReplicationPad2d(padding=(3,4,5,6)) + self.pad_2 = nn.ReplicationPad2d(padding=(1,0,2,0)) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ReplicationPad2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ReplicationPad2d.pt inputshape=[1,12,13,13]") + + # pnnx inference + import test_nn_ReplicationPad2d_pnnx + b = test_nn_ReplicationPad2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ReplicationPad3d.py b/tools/pnnx/tests/test_nn_ReplicationPad3d.py new file mode 100644 index 000000000000..988d3a582470 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ReplicationPad3d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ReplicationPad3d(2) + self.pad_1 = nn.ReplicationPad3d(padding=(1,2,3,4,5,6)) + self.pad_2 = nn.ReplicationPad3d(padding=(1,0,2,0,0,3)) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ReplicationPad3d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ReplicationPad3d.pt inputshape=[1,12,13,13,13]") + + # pnnx inference + import test_nn_ReplicationPad3d_pnnx + b = test_nn_ReplicationPad3d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_SELU.py b/tools/pnnx/tests/test_nn_SELU.py new file mode 100644 index 000000000000..273b0e588b45 --- /dev/null +++ b/tools/pnnx/tests/test_nn_SELU.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.SELU() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_SELU.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_SELU.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_SELU_pnnx + b0, b1, b2, b3 = test_nn_SELU_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_SiLU.py b/tools/pnnx/tests/test_nn_SiLU.py new file mode 100644 index 000000000000..c1a98711b67d --- /dev/null +++ b/tools/pnnx/tests/test_nn_SiLU.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.SiLU() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_SiLU.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_SiLU.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_SiLU_pnnx + b0, b1, b2, b3 = test_nn_SiLU_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Sigmoid.py b/tools/pnnx/tests/test_nn_Sigmoid.py new file mode 100644 index 000000000000..28c922c44b7d --- /dev/null +++ b/tools/pnnx/tests/test_nn_Sigmoid.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Sigmoid() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Sigmoid.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Sigmoid.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Sigmoid_pnnx + b0, b1, b2, b3 = test_nn_Sigmoid_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Softmax.py b/tools/pnnx/tests/test_nn_Softmax.py new file mode 100644 index 000000000000..475385d259c3 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Softmax.py @@ -0,0 +1,65 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Softmax(dim=1) + self.act_1 = nn.Softmax(dim=1) + self.act_2 = nn.Softmax(dim=0) + self.act_3 = nn.Softmax(dim=2) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_1(y) + z = self.act_2(z) + w = self.act_3(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Softmax.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Softmax.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Softmax_pnnx + b0, b1, b2, b3 = test_nn_Softmax_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Softmin.py b/tools/pnnx/tests/test_nn_Softmin.py new file mode 100644 index 000000000000..8560aef1dd57 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Softmin.py @@ -0,0 +1,65 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Softmin(dim=1) + self.act_1 = nn.Softmin(dim=1) + self.act_2 = nn.Softmin(dim=0) + self.act_3 = nn.Softmin(dim=2) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_1(y) + z = self.act_2(z) + w = self.act_3(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Softmin.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Softmin.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Softmin_pnnx + b0, b1, b2, b3 = test_nn_Softmin_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Softplus.py b/tools/pnnx/tests/test_nn_Softplus.py new file mode 100644 index 000000000000..b95f6826bbb3 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Softplus.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Softplus() + self.act_1 = nn.Softplus(beta=0.7, threshold=15) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Softplus.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Softplus.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Softplus_pnnx + b0, b1, b2, b3 = test_nn_Softplus_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Softshrink.py b/tools/pnnx/tests/test_nn_Softshrink.py new file mode 100644 index 000000000000..db0ad788e98a --- /dev/null +++ b/tools/pnnx/tests/test_nn_Softshrink.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Softshrink() + self.act_1 = nn.Softshrink(lambd=1.3) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Softshrink.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Softshrink.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Softshrink_pnnx + b0, b1, b2, b3 = test_nn_Softshrink_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Softsign.py b/tools/pnnx/tests/test_nn_Softsign.py new file mode 100644 index 000000000000..088933895ff0 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Softsign.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Softsign() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Softsign.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Softsign.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Softsign_pnnx + b0, b1, b2, b3 = test_nn_Softsign_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Tanh.py b/tools/pnnx/tests/test_nn_Tanh.py new file mode 100644 index 000000000000..f9ec1babbebe --- /dev/null +++ b/tools/pnnx/tests/test_nn_Tanh.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Tanh() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Tanh.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Tanh.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Tanh_pnnx + b0, b1, b2, b3 = test_nn_Tanh_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Tanhshrink.py b/tools/pnnx/tests/test_nn_Tanhshrink.py new file mode 100644 index 000000000000..4d1611b9adbc --- /dev/null +++ b/tools/pnnx/tests/test_nn_Tanhshrink.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Tanhshrink() + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Tanhshrink.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Tanhshrink.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Tanhshrink_pnnx + b0, b1, b2, b3 = test_nn_Tanhshrink_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Threshold.py b/tools/pnnx/tests/test_nn_Threshold.py new file mode 100644 index 000000000000..329b3c3440a0 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Threshold.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Threshold(0.1, 20) + self.act_1 = nn.Threshold(0.3, 0.4) + + def forward(self, x, y, z, w): + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a0, a1, a2, a3 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_nn_Threshold.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Threshold.pt inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Threshold_pnnx + b0, b1, b2, b3 = test_nn_Threshold_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Upsample.py b/tools/pnnx/tests/test_nn_Upsample.py new file mode 100644 index 000000000000..ee97980431f2 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Upsample.py @@ -0,0 +1,137 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.up_1d_0_0 = nn.Upsample(size=16) + self.up_1d_0_1 = nn.Upsample(scale_factor=2, mode='nearest') + self.up_1d_0_2 = nn.Upsample(size=(20), mode='nearest') + self.up_1d_0_3 = nn.Upsample(scale_factor=(4), mode='nearest') + self.up_1d_1_0 = nn.Upsample(size=16, mode='linear') + self.up_1d_1_1 = nn.Upsample(scale_factor=2, mode='linear') + self.up_1d_1_2 = nn.Upsample(size=(24), mode='linear', align_corners=True) + self.up_1d_1_3 = nn.Upsample(scale_factor=(3), mode='linear', align_corners=True) + + self.up_2d_0_0 = nn.Upsample(size=16) + self.up_2d_0_1 = nn.Upsample(scale_factor=2, mode='nearest') + self.up_2d_0_2 = nn.Upsample(size=(20,20), mode='nearest') + self.up_2d_0_3 = nn.Upsample(scale_factor=(4,4), mode='nearest') + self.up_2d_0_4 = nn.Upsample(size=(16,24), mode='nearest') + self.up_2d_0_5 = nn.Upsample(scale_factor=(2,3), mode='nearest') + self.up_2d_1_0 = nn.Upsample(size=16, mode='bilinear') + self.up_2d_1_1 = nn.Upsample(scale_factor=2, mode='bilinear') + self.up_2d_1_2 = nn.Upsample(size=(20,20), mode='bilinear', align_corners=False) + self.up_2d_1_3 = nn.Upsample(scale_factor=(4,4), mode='bilinear', align_corners=False) + self.up_2d_1_4 = nn.Upsample(size=(16,24), mode='bilinear', align_corners=True) + self.up_2d_1_5 = nn.Upsample(scale_factor=(2,3), mode='bilinear', align_corners=True) + self.up_2d_2_0 = nn.Upsample(size=16, mode='bicubic') + self.up_2d_2_1 = nn.Upsample(scale_factor=2, mode='bicubic') + self.up_2d_2_2 = nn.Upsample(size=(20,20), mode='bicubic', align_corners=False) + self.up_2d_2_3 = nn.Upsample(scale_factor=(4,4), mode='bicubic', align_corners=False) + self.up_2d_2_4 = nn.Upsample(size=(16,24), mode='bicubic', align_corners=True) + self.up_2d_2_5 = nn.Upsample(scale_factor=(2,3), mode='bicubic', align_corners=True) + + self.up_3d_0_0 = nn.Upsample(size=16) + self.up_3d_0_1 = nn.Upsample(scale_factor=2, mode='nearest') + self.up_3d_0_2 = nn.Upsample(size=(20,20,20), mode='nearest') + self.up_3d_0_3 = nn.Upsample(scale_factor=(4,4,4), mode='nearest') + self.up_3d_0_4 = nn.Upsample(size=(16,24,20), mode='nearest') + self.up_3d_0_5 = nn.Upsample(scale_factor=(2,3,4), mode='nearest') + self.up_3d_1_0 = nn.Upsample(size=16, mode='trilinear') + self.up_3d_1_1 = nn.Upsample(scale_factor=2, mode='trilinear') + self.up_3d_1_2 = nn.Upsample(size=(20,20,20), mode='trilinear', align_corners=False) + self.up_3d_1_3 = nn.Upsample(scale_factor=(4,4,4), mode='trilinear', align_corners=False) + self.up_3d_1_4 = nn.Upsample(size=(16,24,20), mode='trilinear', align_corners=True) + self.up_3d_1_5 = nn.Upsample(scale_factor=(2,3,4), mode='trilinear', align_corners=True) + + def forward(self, x, y, z): + x = self.up_1d_0_0(x) + x = self.up_1d_0_1(x) + x = self.up_1d_0_2(x) + x = self.up_1d_0_3(x) + x = self.up_1d_1_0(x) + x = self.up_1d_1_1(x) + x = self.up_1d_1_2(x) + x = self.up_1d_1_3(x) + + y = self.up_2d_0_0(y) + y = self.up_2d_0_1(y) + y = self.up_2d_0_2(y) + y = self.up_2d_0_3(y) + y = self.up_2d_0_4(y) + y = self.up_2d_0_5(y) + y = self.up_2d_1_0(y) + y = self.up_2d_1_1(y) + y = self.up_2d_1_2(y) + y = self.up_2d_1_3(y) + y = self.up_2d_1_4(y) + y = self.up_2d_1_5(y) + y = self.up_2d_2_0(y) + y = self.up_2d_2_1(y) + y = self.up_2d_2_2(y) + y = self.up_2d_2_3(y) + y = self.up_2d_2_4(y) + y = self.up_2d_2_5(y) + + z = self.up_3d_0_0(z) + z = self.up_3d_0_1(z) + z = self.up_3d_0_2(z) + z = self.up_3d_0_3(z) + z = self.up_3d_0_4(z) + z = self.up_3d_0_5(z) + z = self.up_3d_1_0(z) + z = self.up_3d_1_1(z) + z = self.up_3d_1_2(z) + z = self.up_3d_1_3(z) + z = self.up_3d_1_4(z) + z = self.up_3d_1_5(z) + + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 32) + y = torch.rand(1, 3, 32, 32) + z = torch.rand(1, 3, 32, 32, 32) + + a0, a1, a2 = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_nn_Upsample.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Upsample.pt inputshape=[1,3,32],[1,3,32,32],[1,3,32,32,32]") + + # pnnx inference + import test_nn_Upsample_pnnx + b0, b1, b2 = test_nn_Upsample_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_UpsamplingBilinear2d.py b/tools/pnnx/tests/test_nn_UpsamplingBilinear2d.py new file mode 100644 index 000000000000..600fa61e0e91 --- /dev/null +++ b/tools/pnnx/tests/test_nn_UpsamplingBilinear2d.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.up_0 = nn.UpsamplingBilinear2d(size=16) + self.up_1 = nn.UpsamplingBilinear2d(scale_factor=2) + self.up_2 = nn.UpsamplingBilinear2d(size=(20,20)) + self.up_3 = nn.UpsamplingBilinear2d(scale_factor=(4,4)) + self.up_4 = nn.UpsamplingBilinear2d(size=(16,24)) + self.up_5 = nn.UpsamplingBilinear2d(scale_factor=(2,3)) + + def forward(self, x): + x = self.up_0(x) + x = self.up_1(x) + x = self.up_2(x) + x = self.up_3(x) + x = self.up_4(x) + x = self.up_5(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 32, 32) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_UpsamplingBilinear2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_UpsamplingBilinear2d.pt inputshape=[1,3,32,32]") + + # pnnx inference + import test_nn_UpsamplingBilinear2d_pnnx + b = test_nn_UpsamplingBilinear2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_UpsamplingNearest2d.py b/tools/pnnx/tests/test_nn_UpsamplingNearest2d.py new file mode 100644 index 000000000000..c9975e073110 --- /dev/null +++ b/tools/pnnx/tests/test_nn_UpsamplingNearest2d.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.up_0 = nn.UpsamplingNearest2d(size=16) + self.up_1 = nn.UpsamplingNearest2d(scale_factor=2) + self.up_2 = nn.UpsamplingNearest2d(size=(20,20)) + self.up_3 = nn.UpsamplingNearest2d(scale_factor=(4,4)) + self.up_4 = nn.UpsamplingNearest2d(size=(16,24)) + self.up_5 = nn.UpsamplingNearest2d(scale_factor=(2,3)) + + def forward(self, x): + x = self.up_0(x) + x = self.up_1(x) + x = self.up_2(x) + x = self.up_3(x) + x = self.up_4(x) + x = self.up_5(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 32, 32) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_UpsamplingNearest2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_UpsamplingNearest2d.pt inputshape=[1,3,32,32]") + + # pnnx inference + import test_nn_UpsamplingNearest2d_pnnx + b = test_nn_UpsamplingNearest2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_ZeroPad2d.py b/tools/pnnx/tests/test_nn_ZeroPad2d.py new file mode 100644 index 000000000000..002102a212b8 --- /dev/null +++ b/tools/pnnx/tests/test_nn_ZeroPad2d.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.pad_0 = nn.ZeroPad2d(2) + self.pad_1 = nn.ZeroPad2d(padding=(3,4,5,6)) + self.pad_2 = nn.ZeroPad2d(padding=(1,0,2,0)) + + def forward(self, x): + x = self.pad_0(x) + x = self.pad_1(x) + x = self.pad_2(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 13, 13) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_ZeroPad2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_ZeroPad2d.pt inputshape=[1,12,13,13]") + + # pnnx inference + import test_nn_ZeroPad2d_pnnx + b = test_nn_ZeroPad2d_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_pnnx_fuse_conv2d_batchnorm2d.py b/tools/pnnx/tests/test_pnnx_fuse_conv2d_batchnorm2d.py new file mode 100644 index 000000000000..5082e3feff5b --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_conv2d_batchnorm2d.py @@ -0,0 +1,93 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.conv_0 = nn.Conv2d(in_channels=12, out_channels=16, kernel_size=3) + self.bn_0 = nn.BatchNorm2d(num_features=16) + self.conv_1 = nn.Conv2d(in_channels=16, out_channels=20, kernel_size=(2,4), stride=(2,1), padding=2, dilation=1) + self.bn_1 = nn.BatchNorm2d(num_features=20) + self.conv_2 = nn.Conv2d(in_channels=20, out_channels=24, kernel_size=(1,3), stride=1, padding=(2,4), dilation=1, groups=1, bias=False) + self.bn_2 = nn.BatchNorm2d(num_features=24) + if torch.__version__ < '1.9': + self.conv_3 = nn.Conv2d(in_channels=24, out_channels=28, kernel_size=(5,4), stride=1, padding=0, dilation=1, groups=4, bias=True) + else: + self.conv_3 = nn.Conv2d(in_channels=24, out_channels=28, kernel_size=(5,4), stride=1, padding='valid', dilation=1, groups=4, bias=True) + self.bn_3 = nn.BatchNorm2d(num_features=28) + if torch.__version__ < '1.9': + self.conv_4 = nn.Conv2d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=(1,2), groups=2, bias=False, padding_mode='zeros') + else: + self.conv_4 = nn.Conv2d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=(1,2), groups=2, bias=False, padding_mode='zeros') + self.bn_4 = nn.BatchNorm2d(num_features=32) + self.conv_5 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect') + self.bn_5 = nn.BatchNorm2d(num_features=32) + self.conv_6 = nn.Conv2d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate') + self.bn_6 = nn.BatchNorm2d(num_features=28) + #self.conv_7 = nn.Conv2d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular') + #self.bn_7 = nn.BatchNorm2d(num_features=24) + + def forward(self, x): + x = self.conv_0(x) + x = self.bn_0(x) + x = self.conv_1(x) + x = self.bn_1(x) + x = self.conv_2(x) + x = self.bn_2(x) + x = self.conv_3(x) + x = self.bn_3(x) + x = self.conv_4(x) + x = self.bn_4(x) + x = self.conv_5(x) + x = self.bn_5(x) + x = self.conv_6(x) + x = self.bn_6(x) + #x = self.conv_7(x) + #x = self.bn_7(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 64, 64) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_pnnx_fuse_conv2d_batchnorm2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_conv2d_batchnorm2d.pt inputshape=[1,12,64,64]") + + # pnnx inference + import test_pnnx_fuse_conv2d_batchnorm2d_pnnx + b = test_pnnx_fuse_conv2d_batchnorm2d_pnnx.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_pnnx_fuse_convtranspose2d_batchnorm2d.py b/tools/pnnx/tests/test_pnnx_fuse_convtranspose2d_batchnorm2d.py new file mode 100644 index 000000000000..906a8ef72c70 --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_convtranspose2d_batchnorm2d.py @@ -0,0 +1,87 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.deconv_0 = nn.ConvTranspose2d(in_channels=12, out_channels=16, kernel_size=3) + self.bn_0 = nn.BatchNorm2d(num_features=16) + self.deconv_1 = nn.ConvTranspose2d(in_channels=16, out_channels=20, kernel_size=(2,4), stride=(2,1), padding=2, output_padding=0) + self.bn_1 = nn.BatchNorm2d(num_features=20) + self.deconv_2 = nn.ConvTranspose2d(in_channels=20, out_channels=24, kernel_size=(1,3), stride=1, padding=(2,4), output_padding=(0,0), dilation=1, groups=1, bias=False) + self.bn_2 = nn.BatchNorm2d(num_features=24, eps=1e-1, affine=False) + self.deconv_3 = nn.ConvTranspose2d(in_channels=24, out_channels=28, kernel_size=(5,4), stride=2, padding=0, output_padding=(0,1), dilation=1, groups=4, bias=True) + self.bn_3 = nn.BatchNorm2d(num_features=28, eps=1e-1, affine=False) + self.deconv_4 = nn.ConvTranspose2d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding=1, output_padding=0, dilation=(1,2), groups=2, bias=False) + self.bn_4 = nn.BatchNorm2d(num_features=32) + self.deconv_5 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, output_padding=1, dilation=1, groups=32, bias=True) + self.bn_5 = nn.BatchNorm2d(num_features=32) + self.deconv_6 = nn.ConvTranspose2d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) + self.bn_6 = nn.BatchNorm2d(num_features=28, affine=True) + self.deconv_7 = nn.ConvTranspose2d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), output_padding=(1,0), dilation=2, groups=1, bias=True) + self.bn_7 = nn.BatchNorm2d(num_features=24, affine=True) + + def forward(self, x): + x = self.deconv_0(x) + x = self.bn_0(x) + x = self.deconv_1(x) + x = self.bn_1(x) + x = self.deconv_2(x) + x = self.bn_2(x) + x = self.deconv_3(x) + x = self.bn_3(x) + x = self.deconv_4(x) + x = self.bn_4(x) + x = self.deconv_5(x) + x = self.bn_5(x) + x = self.deconv_6(x) + x = self.bn_6(x) + x = self.deconv_7(x) + x = self.bn_7(x) + + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 10, 10) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_pnnx_fuse_convtranspose2d_batchnorm2d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_convtranspose2d_batchnorm2d.pt inputshape=[1,12,10,10]") + + # pnnx inference + import test_pnnx_fuse_convtranspose2d_batchnorm2d_pnnx + b = test_pnnx_fuse_convtranspose2d_batchnorm2d_pnnx.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_pnnx_fuse_linear_batchnorm1d.py b/tools/pnnx/tests/test_pnnx_fuse_linear_batchnorm1d.py new file mode 100644 index 000000000000..a96baed4900b --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_linear_batchnorm1d.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.linear_0 = nn.Linear(in_features=64, out_features=16, bias=False) + self.bn_0 = nn.BatchNorm1d(num_features=16) + self.linear_1 = nn.Linear(in_features=16, out_features=3, bias=True) + self.bn_1 = nn.BatchNorm1d(num_features=3) + + def forward(self, x, y): + x = self.linear_0(x) + x = self.bn_0(x) + x = self.linear_1(x) + x = self.bn_1(x) + + y = self.linear_0(y) + y = self.bn_0(y) + y = self.linear_1(y) + y = self.bn_1(y) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 64) + y = torch.rand(12, 64) + + a0, a1 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_pnnx_fuse_linear_batchnorm1d.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_linear_batchnorm1d.pt inputshape=[1,64],[12,64]") + + # pnnx inference + import test_pnnx_fuse_linear_batchnorm1d_pnnx + b0, b1 = test_pnnx_fuse_linear_batchnorm1d_pnnx.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_quantization_shufflenet_v2_x1_0.py b/tools/pnnx/tests/test_quantization_shufflenet_v2_x1_0.py new file mode 100644 index 000000000000..f223e7e0a270 --- /dev/null +++ b/tools/pnnx/tests/test_quantization_shufflenet_v2_x1_0.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torchvision.models as models + +def test(): + net = models.quantization.shufflenet_v2_x1_0(quantize=True) + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 224, 224) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_quantization_shufflenet_v2_x1_0.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_quantization_shufflenet_v2_x1_0.pt inputshape=[1,3,224,224]") + + # pnnx inference + import test_quantization_shufflenet_v2_x1_0_pnnx + b = test_quantization_shufflenet_v2_x1_0_pnnx.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_resnet18.py b/tools/pnnx/tests/test_resnet18.py new file mode 100644 index 000000000000..2dd9dcdd6b19 --- /dev/null +++ b/tools/pnnx/tests/test_resnet18.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torchvision.models as models + +def test(): + net = models.resnet18() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 224, 224) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_resnet18.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_resnet18.pt inputshape=[1,3,224,224]") + + # pnnx inference + import test_resnet18_pnnx + b = test_resnet18_pnnx.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_shufflenet_v2_x1_0.py b/tools/pnnx/tests/test_shufflenet_v2_x1_0.py new file mode 100644 index 000000000000..cbea072b54bc --- /dev/null +++ b/tools/pnnx/tests/test_shufflenet_v2_x1_0.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torchvision.models as models + +def test(): + net = models.shufflenet_v2_x1_0() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 224, 224) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_shufflenet_v2_x1_0.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_shufflenet_v2_x1_0.pt inputshape=[1,3,224,224]") + + # pnnx inference + import test_shufflenet_v2_x1_0_pnnx + b = test_shufflenet_v2_x1_0_pnnx.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_squeezenet1_1.py b/tools/pnnx/tests/test_squeezenet1_1.py new file mode 100644 index 000000000000..b38790745f1a --- /dev/null +++ b/tools/pnnx/tests/test_squeezenet1_1.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torchvision.models as models + +def test(): + net = models.squeezenet1_1() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 224, 224) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_squeezenet1_1.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_squeezenet1_1.pt inputshape=[1,3,224,224]") + + # pnnx inference + import test_squeezenet1_1_pnnx + b = test_squeezenet1_1_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_cat.py b/tools/pnnx/tests/test_torch_cat.py new file mode 100644 index 000000000000..98e6f8c16c04 --- /dev/null +++ b/tools/pnnx/tests/test_torch_cat.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + out0 = torch.cat((x, y), dim=1) + out1 = torch.cat((z, w), dim=3) + out2 = torch.cat((w, w), dim=2) + return out0, out1, out2 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 2, 16) + z = torch.rand(1, 5, 9, 11) + w = torch.rand(1, 5, 9, 3) + + a0, a1, a2 = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torch_cat.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_cat.pt inputshape=[1,3,16],[1,2,16],[1,5,9,11],[1,5,9,3]") + + # pnnx inference + import test_torch_cat_pnnx + b0, b1, b2 = test_torch_cat_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_chunk.py b/tools/pnnx/tests/test_torch_chunk.py new file mode 100644 index 000000000000..36ba8ba1a3ad --- /dev/null +++ b/tools/pnnx/tests/test_torch_chunk.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x0, x1 = torch.chunk(x, chunks=2, dim=1) + y0, y1, y2 = torch.chunk(y, chunks=3, dim=2) + z0, z1, z2, z3, z4 = torch.chunk(z, chunks=5, dim=0) + return x0, x1, y0, y1, y2, z0, z1, z2, z3, z4 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_chunk.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_chunk.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_chunk_pnnx + b = test_torch_chunk_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_clamp.py b/tools/pnnx/tests/test_torch_clamp.py new file mode 100644 index 000000000000..70506287cd5a --- /dev/null +++ b/tools/pnnx/tests/test_torch_clamp.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.clamp(x, max=2) + y = torch.clamp(y, min=0) + z = torch.clamp(z, min=-1, max=1) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_clamp.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_clamp.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_clamp_pnnx + b = test_torch_clamp_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_flatten.py b/tools/pnnx/tests/test_torch_flatten.py new file mode 100644 index 000000000000..57fd387eaa35 --- /dev/null +++ b/tools/pnnx/tests/test_torch_flatten.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.flatten(x) + y = torch.flatten(y, start_dim=1, end_dim=-1) + z = torch.flatten(z, start_dim=3, end_dim=4) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_flatten.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_flatten.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_flatten_pnnx + b = test_torch_flatten_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_mean.py b/tools/pnnx/tests/test_torch_mean.py new file mode 100644 index 000000000000..9b6d6ffcc686 --- /dev/null +++ b/tools/pnnx/tests/test_torch_mean.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.mean(x, dim=1, keepdim=False) + y = torch.mean(y, dim=(2,3), keepdim=False) + z = torch.mean(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_mean.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_mean.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_mean_pnnx + b = test_torch_mean_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_permute.py b/tools/pnnx/tests/test_torch_permute.py new file mode 100644 index 000000000000..9098f1ebc4d5 --- /dev/null +++ b/tools/pnnx/tests/test_torch_permute.py @@ -0,0 +1,72 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + if torch.__version__ < '1.9': + x = x.permute(1, 0, 2) + x = x.permute(0, 1, 2) + y = y.permute(2, 3, 1, 0) + y = y.permute(3, 1, 0, 2) + z = z.permute(1, 3, 0, 4, 2) + z = z.permute(0, 2, 4, 3, 1) + else: + x = torch.permute(x, (1, 0, 2)) + x = torch.permute(x, (0, 1, 2)) + y = torch.permute(y, (2, 3, 1, 0)) + y = torch.permute(y, (3, 1, 0, 2)) + z = torch.permute(z, (1, 3, 0, 4, 2)) + z = torch.permute(z, (0, 2, 4, 3, 1)) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_permute.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_permute.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_permute_pnnx + b = test_torch_permute_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_split.py b/tools/pnnx/tests/test_torch_split.py new file mode 100644 index 000000000000..e89ce8036600 --- /dev/null +++ b/tools/pnnx/tests/test_torch_split.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x0, x1 = torch.split(x, split_size_or_sections=2, dim=1) + y0, y1, y2 = torch.split(y, split_size_or_sections=[1,3,5], dim=2) + z0, z1, z2, z3, z4 = torch.split(z, split_size_or_sections=3, dim=0) + return x0, x1, y0, y1, y2, z0, z1, z2, z3, z4 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_split.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_split.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_split_pnnx + b = test_torch_split_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_squeeze.py b/tools/pnnx/tests/test_torch_squeeze.py new file mode 100644 index 000000000000..3d0430d63de4 --- /dev/null +++ b/tools/pnnx/tests/test_torch_squeeze.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.squeeze(x, 1) + y = torch.squeeze(y) + z = torch.squeeze(z, 4) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 1, 16) + y = torch.rand(1, 5, 1, 11) + z = torch.rand(14, 8, 5, 9, 1) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_squeeze.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_squeeze.pt inputshape=[1,1,16],[1,5,1,11],[14,8,5,9,1]") + + # pnnx inference + import test_torch_squeeze_pnnx + b = test_torch_squeeze_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_sum.py b/tools/pnnx/tests/test_torch_sum.py new file mode 100644 index 000000000000..fe3ca6e66eef --- /dev/null +++ b/tools/pnnx/tests/test_torch_sum.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.sum(x, dim=1, keepdim=False) + y = torch.sum(y, dim=(2,3), keepdim=False) + z = torch.sum(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_sum.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_sum.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_sum_pnnx + b = test_torch_sum_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_transpose.py b/tools/pnnx/tests/test_torch_transpose.py new file mode 100644 index 000000000000..f0f18dc45d18 --- /dev/null +++ b/tools/pnnx/tests/test_torch_transpose.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.transpose(x, 1, 2) + y = torch.transpose(y, 2, 3) + z = torch.transpose(z, 1, 3) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_transpose.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_transpose.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_transpose_pnnx + b = test_torch_transpose_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_unsqueeze.py b/tools/pnnx/tests/test_torch_unsqueeze.py new file mode 100644 index 000000000000..6857b1d66edb --- /dev/null +++ b/tools/pnnx/tests/test_torch_unsqueeze.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.unsqueeze(x, 0) + x = torch.unsqueeze(x, 1) + y = torch.unsqueeze(y, 2) + y = torch.unsqueeze(y, -1) + z = torch.unsqueeze(z, -2) + z = torch.unsqueeze(z, 3) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_unsqueeze.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_unsqueeze.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_unsqueeze_pnnx + b = test_torch_unsqueeze_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)