v0.0.1 #1

Merged 2 commits on Feb 5, 2024
42 changes: 42 additions & 0 deletions .github/workflows/lint.yml
@@ -0,0 +1,42 @@
name: Lint

on:
  push:
    branches:
      - main
  pull_request:

jobs:
  lint-python:
    name: Pylint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: "3.9"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pylint
          python -m pip install -e .
      - name: Analysing the code with pylint
        run: |
          pylint --output-format=colorized $(git ls-files '*.py')

  lint-python-format:
    name: Python format
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: "3.9"
      - uses: psf/black@stable
        with:
          options: "--check --diff"
      - uses: isort/isort-action@master
        with:
          configuration:
            --check
            --diff
5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
__pycache__/
*.egg-info/
*.egg

.idea*
55 changes: 55 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,55 @@
repos:
  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
      - id: isort
        args:
          [
            "--force-single-line-imports",
            "--ensure-newline-before-comments",
            "--line-length=120",
          ]
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.8.0
    hooks:
      - id: pyupgrade
  - repo: https://github.com/PyCQA/docformatter
    rev: v1.7.3
    hooks:
      - id: docformatter
        additional_dependencies: [tomli]
        args:
          [
            "--in-place",
            "--config",
            "pyproject.toml",
          ]
  - repo: https://github.com/executablebooks/mdformat
    rev: 0.7.16
    hooks:
      - id: mdformat
        additional_dependencies:
          - mdformat-gfm
          - mdformat-black
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-yaml
      - id: check-toml
      - id: check-json
      - id: check-ast
      - id: fix-byte-order-marker
      - id: end-of-file-fixer
      - id: trailing-whitespace
      - id: check-added-large-files
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: detect-private-key
      - id: end-of-file-fixer
      - id: detect-private-key
      - id: no-commit-to-branch
        args: ["-b=main"]
25 changes: 24 additions & 1 deletion README.md
@@ -1 +1,24 @@
# ti-vit
# TI-ViT

This repository contains a script for exporting a PyTorch ViT model to ONNX format in a form that is compatible with
[edgeai-tidl-tools](https://github.com/TexasInstruments/edgeai-tidl-tools) (version 8.6.0.5).

## Installation

To install the export script, run the following command:
```commandline
pip3 install git+https://github.com/ENOT-AutoDL/ti-vit.git@main
```

## Examples

To export the model version with maximum performance, run the following command:
```commandline
export-ti-vit -o npu-max-perf.onnx -t npu-max-perf
```

To export the model version with minimal loss of accuracy, run the following command:
```commandline
export-ti-vit -o npu-max-acc.onnx -t npu-max-acc
```
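
For scripted exports, the two commands above can also be assembled from Python. The sketch below is illustrative only:
`build_export_command` and `export_all` are hypothetical helper names that are not part of this repository. The only
things taken from the examples above are the `export-ti-vit` command, its `-o` and `-t` options, and the two export
types; the tool may accept other values that are not shown here.

```python
import subprocess
from typing import List

# Export types that appear in the README examples above.
EXPORT_TYPES = ("npu-max-perf", "npu-max-acc")


def build_export_command(export_type: str) -> List[str]:
    """Build the argv list for one export (hypothetical helper, not part of ti-vit)."""
    # Follow the README convention of naming the output file after the export type.
    return ["export-ti-vit", "-o", f"{export_type}.onnx", "-t", export_type]


def export_all() -> None:
    """Run both exports in turn; assumes export-ti-vit is installed and on PATH."""
    for export_type in EXPORT_TYPES:
        # check=True raises CalledProcessError if an export fails.
        subprocess.run(build_export_command(export_type), check=True)


if __name__ == "__main__":
    # Print the commands instead of running them, so the sketch has no side effects.
    for export_type in EXPORT_TYPES:
        print(" ".join(build_export_command(export_type)))
```

Passing the command as a list avoids shell quoting issues; call `export_all()` only in an environment where the
package has been installed.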

62 changes: 62 additions & 0 deletions pyproject.toml
@@ -0,0 +1,62 @@
[project]
name = 'ti-vit'
version = '0.0.1'
dependencies = [
    'torch==1.13.1',
    'torchvision==0.14.1',
]

[project.scripts]
export-ti-vit = "ti_vit.export:export_ti_compatible_vit"

[tool.black]
line-length = 120
target-version = ["py38", "py39"]
include = '\.pyi?$'

[tool.isort]
profile = "black"
line_length = 120
ensure_newline_before_comments = true
force_single_line = true

[tool.nbqa.mutate]
pyupgrade = 1

[tool.nbqa.addopts]
pyupgrade = ["--py38-plus"]

[tool.docformatter]
recursive = true
wrap-summaries = 0
wrap-descriptions = 0
blank = true
black = true
pre-summary-newline = true

[tool.pylint.format]
max-line-length = 120

[tool.pylint.design]
max-args = 12
max-locals = 30
max-attributes = 20
min-public-methods = 0

[tool.pylint.typecheck]
generated-members = ["torch.*"]

[tool.pylint.messages_control]
disable = [
    "logging-fstring-interpolation",
    "missing-module-docstring",
    "unnecessary-pass",
]

[tool.pylint.BASIC]
good-names = ["B", "N", "C"]

[tool.pyright]
reportMissingImports = false
reportMissingTypeStubs = false
reportWildcardImportFromLibrary = false
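
The `[project.scripts]` table above maps the `export-ti-vit` command to the string
`ti_vit.export:export_ti_compatible_vit`, which follows the `module:attribute` convention for console scripts. The
sketch below shows roughly how such a string is resolved to a callable. It is a simplified illustration, not the code
that pip actually generates, and `resolve_entry_point` is a hypothetical helper name. A standard-library target is used
so that the example runs without ti-vit installed.

```python
import importlib
from typing import Any


def resolve_entry_point(spec: str) -> Any:
    """Resolve a "package.module:attribute" string to the object it names."""
    module_name, _, attr_path = spec.partition(":")
    obj: Any = importlib.import_module(module_name)
    if attr_path:
        # The attribute part may itself be dotted, e.g. "Class.method".
        for attr in attr_path.split("."):
            obj = getattr(obj, attr)
    return obj


if __name__ == "__main__":
    # Standard-library stand-in; with ti-vit installed the spec would be
    # "ti_vit.export:export_ti_compatible_vit".
    dumps = resolve_entry_point("json:dumps")
    print(dumps({"ok": True}))  # prints {"ok": true}
```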
2 changes: 2 additions & 0 deletions src/ti_vit/__init__.py
@@ -0,0 +1,2 @@
from ti_vit.model import TICompatibleVitOrtMaxAcc
from ti_vit.model import TICompatibleVitOrtMaxPerf
167 changes: 167 additions & 0 deletions src/ti_vit/attention.py
@@ -0,0 +1,167 @@
import typing
from enum import Enum
from typing import Tuple

import torch
from torch import nn

from ti_vit.common import copy_weights
from ti_vit.common import sync_device_and_mode


class AttentionType(Enum):
    """
    Type of attention block.

    - CONV_CONV - qkv projection and output projection are convolutions with 1x1 kernel
    - CONV_LINEAR - qkv projection is a convolution with 1x1 kernel, output projection is linear

    """

    CONV_CONV = "CONV_CONV"
    CONV_LINEAR = "CONV_LINEAR"


class TICompatibleAttention(nn.Module):
    """TI compatible attention block."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attention_type: AttentionType = AttentionType.CONV_LINEAR,
    ):
        """
        Parameters
        ----------
        dim : int
            Total dimension of the model.
        num_heads : int
            Number of parallel attention heads.
        qkv_bias : bool
            If True, adds a learnable bias to the qkv projection. Default value is False.
        attention_type : AttentionType
            Type of attention block (see ``AttentionType`` enum documentation).
        """
        super().__init__()

        if dim % num_heads != 0:
            raise ValueError(f'"dim"={dim} should be divisible by "num_heads"={num_heads}')

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        if attention_type == AttentionType.CONV_CONV:
            self.qkv_proj = nn.Conv2d(in_channels=dim, out_channels=dim * 3, kernel_size=(1, 1), bias=qkv_bias)
            self.out_proj = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(1, 1))
        elif attention_type == AttentionType.CONV_LINEAR:
            self.qkv_proj = nn.Conv2d(in_channels=dim, out_channels=dim * 3, kernel_size=(1, 1), bias=qkv_bias)
            self.out_proj = nn.Linear(in_features=dim, out_features=dim)
        else:
            raise ValueError(f'Got unknown attention_type "{attention_type}"')

        self._attention_type = attention_type

    def forward(  # pylint: disable=missing-function-docstring
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        need_weights: bool = True,
    ) -> Tuple[torch.Tensor, None]:
        del key, value  # self-attention only: q, k and v are all computed from "query"

        assert not need_weights, "attention weights are not returned by this block"

        x = query
        B, N, C = x.shape

        # (B, N, C) -> (B, N, C, 1) -> (B, C, N, 1)
        x = x.unsqueeze(3).permute(0, 2, 1, 3)

        qkv = self.qkv_proj(x)  # (B, C, N, 1) -> (B, 3*C, N, 1)
        qkv = qkv.reshape(B, 3, C, N)
        q, k, v = qkv.split(1, dim=1)

        # (B, 1, C, N) -> (B, H, C//H, N) -> (B, H, N, C//H)
        q = q.reshape(B, self.num_heads, C // self.num_heads, N).permute(0, 1, 3, 2)
        # (B, 1, C, N) -> (B, H, C//H, N)
        k = k.reshape(B, self.num_heads, C // self.num_heads, N)
        # (B, 1, C, N) -> (B, H, C//H, N) -> (B, H, N, C//H)
        v = v.reshape(B, self.num_heads, C // self.num_heads, N).permute(0, 1, 3, 2)

        attn = (q @ k) * self.scale  # (B, H, N, N)
        attn = attn.softmax(dim=-1)

        x = attn @ v  # (B, H, N, C//H)

        if self._attention_type == AttentionType.CONV_CONV:
            # (B, H, N, C//H) -> (B, H, C//H, N) -> (B, C, N, 1)
            x = x.permute(0, 1, 3, 2).reshape(B, C, N, 1)
            x = self.out_proj(x)
            x = x.permute(0, 2, 1, 3)  # (B, C, N, 1) -> (B, N, C, 1)
            x = x.squeeze(3)  # (B, N, C, 1) -> (B, N, C)
        else:
            # (B, H, N, C//H) -> (B, N, H, C//H) -> (B, N, C)
            x = x.permute(0, 2, 1, 3).reshape(B, N, C)
            x = self.out_proj(x)

        return x, None

    @classmethod
    def from_module(
        cls,
        vit_attn: nn.Module,
        attention_type: AttentionType = AttentionType.CONV_CONV,
    ) -> "TICompatibleAttention":
        """
        Create TI compatible attention block from common ViT attention block.

        Parameters
        ----------
        vit_attn : nn.Module
            Source block.
        attention_type : AttentionType
            Attention type (see ``AttentionType`` enum documentation).

        Returns
        -------
        TICompatibleAttention
            Instance of ``TICompatibleAttention`` with appropriate weights, device and training mode.

        """
        if hasattr(vit_attn, "qkv"):  # source block exposes fused "qkv" and "proj" linear layers
            qkv_proj = typing.cast(nn.Linear, vit_attn.qkv)
            out_proj = typing.cast(nn.Linear, vit_attn.proj)
        else:  # source block exposes raw "in_proj_weight"/"in_proj_bias" (as nn.MultiheadAttention does)
            in_proj_weight = typing.cast(nn.Parameter, vit_attn.in_proj_weight)
            out_features, in_features = in_proj_weight.shape
            qkv_proj = nn.Linear(
                in_features=in_features,
                out_features=out_features,
                bias=hasattr(vit_attn, "in_proj_bias"),
                device=in_proj_weight.device,
                dtype=in_proj_weight.dtype,
            )
            qkv_proj.weight = in_proj_weight
            qkv_proj.bias = vit_attn.in_proj_bias  # pyright: ignore[reportAttributeAccessIssue]

            out_proj = typing.cast(nn.Linear, vit_attn.out_proj)

        ti_compatible_attn = cls(
            dim=qkv_proj.in_features,
            num_heads=typing.cast(int, vit_attn.num_heads),
            qkv_bias=qkv_proj.bias is not None,
            attention_type=attention_type,
        )
        sync_device_and_mode(src=vit_attn, dst=ti_compatible_attn)

        copy_weights(src=qkv_proj, dst=ti_compatible_attn.qkv_proj)
        copy_weights(src=out_proj, dst=ti_compatible_attn.out_proj)

        if hasattr(vit_attn, "scale"):
            ti_compatible_attn.scale = vit_attn.scale

        return ti_compatible_attn
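
The shape comments in `forward` above can be checked without PyTorch. The following plain-Python sketch walks through
the same sequence of reshapes and matrix products while tracking shapes only. `matmul_shape` and `attention_shapes` are
hypothetical helpers written for this illustration and are not part of ti-vit; the sizes in the usage example are
arbitrary values chosen for illustration, not values taken from this PR.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def matmul_shape(a: Shape, b: Shape) -> Shape:
    """Shape of a batched matrix product a @ b with identical batch dimensions."""
    if a[:-2] != b[:-2] or a[-1] != b[-2]:
        raise ValueError(f"Incompatible shapes for matmul: {a} and {b}")
    return a[:-2] + (a[-2], b[-1])


def attention_shapes(B: int, N: int, C: int, H: int, conv_out_proj: bool = True) -> Dict[str, Shape]:
    """Trace tensor shapes through the forward pass above (shapes only, no math)."""
    if C % H != 0:
        raise ValueError(f"C={C} should be divisible by H={H}")
    D = C // H  # per-head dimension

    shapes: Dict[str, Shape] = {}
    shapes["input"] = (B, N, C)
    shapes["conv_input"] = (B, C, N, 1)  # unsqueeze(3) + permute(0, 2, 1, 3)
    shapes["qkv"] = (B, 3 * C, N, 1)  # the 1x1 convolution triples the channel count
    shapes["q"] = (B, H, N, D)  # reshape + permute(0, 1, 3, 2)
    shapes["k"] = (B, H, D, N)  # left transposed, so q @ k needs no extra permute
    shapes["v"] = (B, H, N, D)  # reshape + permute(0, 1, 3, 2)
    shapes["attn"] = matmul_shape(shapes["q"], shapes["k"])  # (B, H, N, N)
    shapes["context"] = matmul_shape(shapes["attn"], shapes["v"])  # (B, H, N, D)
    # CONV_CONV feeds a 4D tensor to the output projection, CONV_LINEAR a 3D one.
    shapes["out_proj_input"] = (B, C, N, 1) if conv_out_proj else (B, N, C)
    shapes["output"] = (B, N, C)
    return shapes


if __name__ == "__main__":
    # Arbitrary illustrative sizes: batch 1, 197 tokens, 768 channels, 12 heads.
    for name, shape in attention_shapes(B=1, N=197, C=768, H=12).items():
        print(f"{name:>15}: {shape}")
```

The trace makes one design point visible: `k` is kept in `(B, H, C//H, N)` layout, so the score matrix is obtained
with a single `q @ k` and no additional transpose.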