# dynamic.py
# Example script: export an LSTM model to TFLite with TinyNN's hybrid quantization.
import argparse
import os
import sys
CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(1, os.path.join(CURRENT_PATH, '../../'))
import torch
import torch.nn as nn
from tinynn.converter import TFLiteConverter
class SimpleLSTM(nn.Module):
    """An LSTM followed by a linear layer and a ReLU.

    The LSTM is built without ``batch_first``, so ``torch.nn.LSTM``'s default
    ``(seq_len, batch, features)`` input layout applies.

    Args:
        in_dim: Number of features per input step (LSTM ``input_size``).
        out_dim: LSTM ``hidden_size``.
        layers: Number of stacked LSTM layers.
        num_classes: Number of output features of the linear layer.
        bidirectional: Whether the LSTM runs in both directions.
    """

    def __init__(self, in_dim, out_dim, layers, num_classes, bidirectional):
        super(SimpleLSTM, self).__init__()

        # A bidirectional LSTM concatenates the outputs of both directions,
        # so the linear layer has to accept twice the hidden size.
        num_directions = 2 if bidirectional else 1

        self.lstm = torch.nn.LSTM(in_dim, out_dim, layers, bidirectional=bidirectional)
        self.fc = torch.nn.Linear(out_dim * num_directions, num_classes)
        self.relu = torch.nn.ReLU()

    def forward(self, inputs):
        """Apply LSTM -> linear -> ReLU to ``inputs`` and return the result.

        Only the per-step LSTM outputs are used; the final hidden and cell
        states are discarded.
        """
        out, _ = self.lstm(inputs)
        out = self.fc(out)
        out = self.relu(out)
        return out
def main_worker(args):
    """Build a ``SimpleLSTM`` from ``args`` and convert it to a TFLite model.

    The converter is configured with hybrid quantization enabled and writes to
    ``out/dynamic_quant_model.tflite`` (a path relative to the current working
    directory).

    Args:
        args: Object providing ``input_size``, ``hidden_size``, ``num_layers``,
            ``num_classes``, ``bidirectional``, ``steps`` and ``batch_size``.
    """
    model = SimpleLSTM(args.input_size, args.hidden_size, args.num_layers, args.num_classes, args.bidirectional)

    # Provide a viable input for the model
    dummy_input = torch.rand((args.steps, args.batch_size, args.input_size))

    print(model)

    # NOTE(review): the conversion is kept inside the no_grad block together
    # with the eval/cpu calls; confirm this nesting against the upstream file.
    with torch.no_grad():
        model.eval()
        model.cpu()

        # The code section below is used to convert the model to the TFLite format
        converter = TFLiteConverter(
            model,
            dummy_input,
            tflite_path='out/dynamic_quant_model.tflite',
            strict_symmetric_check=True,
            quantize_target_type='int8',
            # Enable hybrid quantization
            hybrid_quantization_from_float=True,
            # Enable hybrid per-channel quantization (lower q-loss, but slower)
            hybrid_per_channel=False,
            # Use asymmetric inputs for hybrid quantization (probably lower q-loss, but a bit slower)
            hybrid_asymmetric_inputs=True,
            # Enable hybrid per-channel quantization for `Conv2d` and `DepthwiseConv2d`
            hybrid_conv=True,
            # Enable rewrite for BidirectionLSTMs to UnidirectionalLSTMs
            map_bilstm_to_lstm=False,
        )
        converter.convert()
if __name__ == '__main__':
    # Command-line entry point: every option has a default, so the script can
    # be run without arguments.
    parser = argparse.ArgumentParser()

    # Shape of the dummy input: (steps, batch-size, input-size)
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--batch-size', type=int, default=1)

    # Model hyper-parameters
    parser.add_argument('--hidden-size', type=int, default=512)
    parser.add_argument('--input-size', type=int, default=128)
    parser.add_argument('--num-layers', type=int, default=1)
    parser.add_argument('--num-classes', type=int, default=10)
    # Flag only: the LSTM is unidirectional unless this is passed
    parser.add_argument('--bidirectional', action='store_true')

    args = parser.parse_args()
    main_worker(args)