From 328653fd9ed643dc22c1311fc4e1cb94bb824132 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pastel=EF=BC=81?= <1627301104@qq.com> Date: Sat, 1 Jun 2024 11:06:42 +0800 Subject: [PATCH 1/7] Update aten.py --- tinynn/converter/operators/torch/aten.py | 172 +++++++++++++++++------ 1 file changed, 127 insertions(+), 45 deletions(-) diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py index 7dc242da..bd5b07b3 100644 --- a/tinynn/converter/operators/torch/aten.py +++ b/tinynn/converter/operators/torch/aten.py @@ -554,7 +554,7 @@ def parse_common( ): assert is_train in (False, 0) self.unroll_rnn = True - self.separated_rnn_gate_calc = True + self.separated_rnn_gate_calc = False expected_num_params = 2 * num_layers params_step = 2 @@ -685,73 +685,155 @@ def parse_common( for i, t in enumerate(input_ts[::stride]): input_mm_list = [] - for j, (w_i, b_i) in enumerate(zip(w_i_list, b_i_list)): + + if not self.separated_rnn_gate_calc: + + wir, wiz, win = w_i_list + whr, whz, whn = w_r_list + bir, biz, bin = b_i_list + bhr, bhz, bhn = b_r_list + + w_i = self.create_attr_tensor(np.concatenate([wir.tensor, wiz.tensor, win.tensor], 0)) + w_h = self.create_attr_tensor(np.concatenate([whr.tensor, whz.tensor, whn.tensor], 0)) + b_i = self.create_attr_tensor(np.concatenate([bir.tensor, biz.tensor, bin.tensor], 0)) + b_h = self.create_attr_tensor(np.concatenate([bhr.tensor, bhz.tensor, bhn.tensor], 0)) input_mm = self.create_transform_tensor( np.matmul(t.tensor, np.transpose(w_i.tensor, [1, 0])) + b_i.tensor ) - ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm])) - input_mm_list.append(input_mm) + hidden_mm = self.create_transform_tensor( + np.matmul(h.tensor, np.transpose(w_h.tensor, [1, 0])) + b_h.tensor + ) + + ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm])) + ops.append(tfl.FullyConnectedOperator([t, w_h, b_h], [hidden_mm])) + + left_in = np.split(input_mm.tensor, 3, axis=1) + dim_tensor = self.create_attr_tensor(np.array([1], dtype='int32')) + splited_left_in = [self.create_transform_tensor(t) for t in left_in] + + ops.append(tfl.SplitOperator([dim_tensor, input_mm], splited_left_in, 3)) + + right_in = np.split(hidden_mm.tensor, 3, axis=-1) + splited_right_in = [self.create_transform_tensor(t) for t in right_in] + + ops.append(tfl.SplitOperator([dim_tensor, hidden_mm], splited_right_in, 3)) + + rgate_left_in, zgate_left_in, ngate_left_in = splited_left_in + rgate_right_in, zgate_right_in, ngate_right_in_b = splited_right_in + + rgate_in = self.create_transform_tensor(rgate_left_in.tensor + rgate_right_in.tensor) + ops.append(tfl.AddOperator([rgate_left_in, rgate_right_in], [rgate_in])) + + rgate_out = self.create_transform_tensor( + torch.sigmoid(torch.from_numpy(rgate_in.tensor)).numpy() + ) + ops.append(tfl.LogisticOperator([rgate_in], [rgate_out])) + + zgate_in = self.create_transform_tensor(zgate_left_in.tensor + zgate_right_in.tensor) + ops.append(tfl.AddOperator([zgate_left_in, zgate_right_in], [zgate_in])) + + zgate_out = self.create_transform_tensor( + torch.sigmoid(torch.from_numpy(zgate_in.tensor)).numpy() + ) + ops.append(tfl.LogisticOperator([zgate_in], [zgate_out])) - if i != 0 or compute_h: + ngate_right_in = self.create_transform_tensor(rgate_out.tensor * ngate_right_in_b.tensor) + ops.append(tfl.MulOperator([rgate_out, ngate_right_in_b], [ngate_right_in])) - hidden_mm_list = [] - for j, (w_r, b_r) in enumerate(zip(w_r_list, b_r_list)): + ngate_in = self.create_transform_tensor(ngate_left_in.tensor + ngate_right_in.tensor) 
+ ops.append(tfl.AddOperator([ngate_left_in, ngate_right_in], [ngate_in])) - hidden_mm = self.create_transform_tensor( - np.matmul(h.tensor, np.transpose(w_r.tensor, [1, 0])) + b_r.tensor - ) - ops.append(tfl.FullyConnectedOperator([h, w_r, b_r], [hidden_mm])) - hidden_mm_list.append(hidden_mm) + ngate_out = self.create_transform_tensor(torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy()) + ops.append(tfl.TanhOperator([ngate_in], [ngate_out])) + + constant_tensor = self.create_attr_tensor(torch.tensor(1, dtype=torch.float32)) + + h_left_0 = self.create_transform_tensor(constant_tensor.tensor - zgate_out.tensor) + ops.append(tfl.SubOperator([constant_tensor, zgate_out], [h_left_0])) + + h_left = self.create_transform_tensor(h_left_0.tensor * ngate_out.tensor) + ops.append(tfl.MulOperator([h_left_0, ngate_out], [h_left])) + + if i != 0 or compute_h: + h_right = self.create_transform_tensor(zgate_out.tensor * h.tensor) + ops.append(tfl.MulOperator([zgate_out, h], [h_right])) + + h = self.create_transform_tensor(h_left.tensor + h_right.tensor) + ops.append(tfl.AddOperator([h_left, h_right], [h])) + + elif i == 0 and not compute_h: + h = h_left + + stacked_hs.append(h) + else: - hidden_mm_list = b_r_list + for j, (w_i, b_i) in enumerate(zip(w_i_list, b_i_list)): + + input_mm = self.create_transform_tensor( + np.matmul(t.tensor, np.transpose(w_i.tensor, [1, 0])) + b_i.tensor + ) + ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm])) + input_mm_list.append(input_mm) - # calculate r,z,n gates - rgate_in = self.create_transform_tensor(input_mm_list[0].tensor + hidden_mm_list[0].tensor) - ops.append(tfl.AddOperator([input_mm_list[0], hidden_mm_list[0]], [rgate_in])) + if i != 0 or compute_h: - zgate_in = self.create_transform_tensor(input_mm_list[1].tensor + hidden_mm_list[1].tensor) - ops.append(tfl.AddOperator([input_mm_list[1], hidden_mm_list[1]], [zgate_in])) + hidden_mm_list = [] + for j, (w_r, b_r) in enumerate(zip(w_r_list, b_r_list)): - zgate_out = self.create_transform_tensor( - torch.sigmoid(torch.from_numpy(zgate_in.tensor)).numpy() - ) + hidden_mm = self.create_transform_tensor( + np.matmul(h.tensor, np.transpose(w_r.tensor, [1, 0])) + b_r.tensor + ) + ops.append(tfl.FullyConnectedOperator([h, w_r, b_r], [hidden_mm])) + hidden_mm_list.append(hidden_mm) + else: + hidden_mm_list = b_r_list - ops.append(tfl.LogisticOperator([zgate_in], [zgate_out])) + # calculate r,z,n gates + rgate_in = self.create_transform_tensor(input_mm_list[0].tensor + hidden_mm_list[0].tensor) + ops.append(tfl.AddOperator([input_mm_list[0], hidden_mm_list[0]], [rgate_in])) - rgate_out = self.create_transform_tensor( - torch.sigmoid(torch.from_numpy(rgate_in.tensor)).numpy() - ) - ops.append(tfl.LogisticOperator([rgate_in], [rgate_out])) + zgate_in = self.create_transform_tensor(input_mm_list[1].tensor + hidden_mm_list[1].tensor) + ops.append(tfl.AddOperator([input_mm_list[1], hidden_mm_list[1]], [zgate_in])) + + zgate_out = self.create_transform_tensor( + torch.sigmoid(torch.from_numpy(zgate_in.tensor)).numpy() + ) + ops.append(tfl.LogisticOperator([zgate_in], [zgate_out])) - ngate_in_hside = self.create_transform_tensor(rgate_out.tensor * hidden_mm_list[2].tensor) - ops.append(tfl.MulOperator([rgate_out, hidden_mm_list[2]], [ngate_in_hside])) + rgate_out = self.create_transform_tensor( + torch.sigmoid(torch.from_numpy(rgate_in.tensor)).numpy() + ) + ops.append(tfl.LogisticOperator([rgate_in], [rgate_out])) - ngate_in = self.create_transform_tensor(input_mm_list[2].tensor + ngate_in_hside.tensor) - 
ops.append(tfl.AddOperator([input_mm_list[2], ngate_in_hside], [ngate_in])) + ngate_in_hside = self.create_transform_tensor(rgate_out.tensor * hidden_mm_list[2].tensor) + ops.append(tfl.MulOperator([rgate_out, hidden_mm_list[2]], [ngate_in_hside])) - ngate_out = self.create_transform_tensor(torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy()) - ops.append(tfl.TanhOperator([ngate_in], [ngate_out])) + ngate_in = self.create_transform_tensor(input_mm_list[2].tensor + ngate_in_hside.tensor) + ops.append(tfl.AddOperator([input_mm_list[2], ngate_in_hside], [ngate_in])) - constant_tensor = self.create_attr_tensor(torch.tensor(1, dtype=torch.float32)) + ngate_out = self.create_transform_tensor(torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy()) + ops.append(tfl.TanhOperator([ngate_in], [ngate_out])) - h_left_0 = self.create_transform_tensor(constant_tensor.tensor - zgate_out.tensor) - ops.append(tfl.SubOperator([constant_tensor, zgate_out], [h_left_0])) + constant_tensor = self.create_attr_tensor(torch.tensor(1, dtype=torch.float32)) - h_left = self.create_transform_tensor(h_left_0.tensor * ngate_out.tensor) - ops.append(tfl.MulOperator([h_left_0, ngate_out], [h_left])) + h_left_0 = self.create_transform_tensor(constant_tensor.tensor - zgate_out.tensor) + ops.append(tfl.SubOperator([constant_tensor, zgate_out], [h_left_0])) - if i != 0 or compute_h: - h_right = self.create_transform_tensor(zgate_out.tensor * h.tensor) - ops.append(tfl.MulOperator([zgate_out, h], [h_right])) + h_left = self.create_transform_tensor(h_left_0.tensor * ngate_out.tensor) + ops.append(tfl.MulOperator([h_left_0, ngate_out], [h_left])) - h = self.create_transform_tensor(h_left.tensor + h_right.tensor) - ops.append(tfl.AddOperator([h_left, h_right], [h])) + if i != 0 or compute_h: + h_right = self.create_transform_tensor(zgate_out.tensor * h.tensor) + ops.append(tfl.MulOperator([zgate_out, h], [h_right])) - elif i == 0 and not compute_h: - h = h_left + h = self.create_transform_tensor(h_left.tensor + h_right.tensor) + ops.append(tfl.AddOperator([h_left, h_right], [h])) - stacked_hs.append(h) + elif i == 0 and not compute_h: + h = h_left + + stacked_hs.append(h) tf_out_state_tensors[0].append(h) output_ts.extend(stacked_hs[::stride]) From 32a77f7d601fbc271baa4ec948e3ce3bf7ba73e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pastel=EF=BC=81?= <1627301104@qq.com> Date: Mon, 3 Jun 2024 11:09:40 +0800 Subject: [PATCH 2/7] Update aten.py --- tinynn/converter/operators/torch/aten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py index bd5b07b3..58da8a30 100644 --- a/tinynn/converter/operators/torch/aten.py +++ b/tinynn/converter/operators/torch/aten.py @@ -706,7 +706,7 @@ def parse_common( ) ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm])) - ops.append(tfl.FullyConnectedOperator([t, w_h, b_h], [hidden_mm])) + ops.append(tfl.FullyConnectedOperator([h, w_h, b_h], [hidden_mm])) left_in = np.split(input_mm.tensor, 3, axis=1) dim_tensor = self.create_attr_tensor(np.array([1], dtype='int32')) From 4b4932309a6895dd18e1ac181023380bcdd6e0cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pastel=EF=BC=81?= <1627301104@qq.com> Date: Mon, 3 Jun 2024 11:09:58 +0800 Subject: [PATCH 3/7] Update converter_op_test.py --- tests/converter_op_test.py | 186 +++++++++++++++++++++++++++++++++++-- 1 file changed, 176 insertions(+), 10 deletions(-) diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py index a1b284e0..af374ca4 
100644 --- a/tests/converter_op_test.py +++ b/tests/converter_op_test.py @@ -2951,6 +2951,30 @@ def forward(self, x): dummy_output = model(dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) + + def test_gru_unroll_unseparated(self): + dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.gru = nn.GRU(10, 20) + + def forward(self, x): + return self.gru(x)[0] + + model = Model() + model.eval() + + model_path = get_model_path() + converter = TFLiteConverter( + model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False + ) + converter.convert() + + dummy_output = model(dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output) def test_gru_batch_first_unroll_separated(self): dummy_input = torch.randn(1, 9, 10, dtype=torch.float32) @@ -2975,6 +2999,30 @@ def forward(self, x): dummy_output = model(dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output, check_stride=False) + + def test_gru_batch_first_unroll_unseparated(self): + dummy_input = torch.randn(1, 9, 10, dtype=torch.float32) + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.gru = nn.GRU(10, 20, batch_first=True) + + def forward(self, x): + return self.gru(x)[0] + + model = Model() + model.eval() + + model_path = get_model_path() + converter = TFLiteConverter( + model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False + ) + converter.convert() + + dummy_output = model(dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output, check_stride=False) def test_gru_with_state_tensor_unroll_separated(self): dummy_input = [ @@ -3003,6 +3051,34 @@ def forward(self, x, hx): dummy_output = model(*dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) + + def test_gru_with_state_tensor_unroll_unseparated(self): + dummy_input = [ + torch.randn(9, 1, 10, dtype=torch.float32), + torch.randn(1, 1, 20, dtype=torch.float32), + ] + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.gru = nn.GRU(10, 20) + + def forward(self, x, hx): + gru, hx = self.gru(x, hx) + return gru, hx + + model = Model() + model.eval() + + model_path = get_model_path() + converter = TFLiteConverter( + model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False + ) + converter.convert() + + dummy_output = model(*dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output) def test_gru_multi_layer_unroll_separated(self): dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) @@ -3027,6 +3103,30 @@ def forward(self, x): dummy_output = model(dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) + + def test_gru_multi_layer_unroll_unseparated(self): + dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.gru = nn.GRU(10, 20, 2) + + def forward(self, x): + return self.gru(x)[0] + + model = Model() + model.eval() + + model_path = get_model_path() + converter = TFLiteConverter( + model, 
dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False + ) + converter.convert() + + dummy_output = model(dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output) def test_gru_multi_layer_with_state_tensor_unroll_separated(self): dummy_input = [ @@ -3055,6 +3155,34 @@ def forward(self, x, hx): dummy_output = model(*dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) + + def test_gru_multi_layer_with_state_tensor_unroll_separated(self): + dummy_input = [ + torch.randn(9, 1, 10, dtype=torch.float32), + torch.randn(2, 1, 20, dtype=torch.float32), + ] + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.gru = nn.GRU(10, 20, 2) + + def forward(self, x, hx): + gru, hx = self.gru(x, hx) + return gru, hx + + model = Model() + model.eval() + + model_path = get_model_path() + converter = TFLiteConverter( + model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False + ) + converter.convert() + + dummy_output = model(*dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output) def test_bigru(self): dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) @@ -3277,6 +3405,30 @@ def forward(self, x): dummy_output = model(dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) + + def test_bigru_unroll_unseparated(self): + dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.gru = nn.GRU(10, 20, bidirectional=True) + + def forward(self, x): + return self.gru(x)[0] + + model = Model() + model.eval() + + model_path = get_model_path() + converter = TFLiteConverter( + model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False + ) + converter.convert() + + dummy_output = model(dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output) def test_bigru_multi_layer_unroll_separated(self): dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) @@ -3301,6 +3453,30 @@ def forward(self, x): dummy_output = model(dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) + + def test_bigru_multi_layer_unroll_separated(self): + dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.gru = nn.GRU(10, 20, 2, bidirectional=True) + + def forward(self, x): + return self.gru(x)[0] + + model = Model() + model.eval() + + model_path = get_model_path() + converter = TFLiteConverter( + model, dummy_input, model_path, nchw_transpose=False, unroll_rnn=True, separated_rnn_gate_calc=False + ) + converter.convert() + + dummy_output = model(dummy_input) + tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) + assert_close(dummy_output, tfl_output) def test_lstm(self): dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) @@ -4343,11 +4519,6 @@ def model(x): LooseVersion(torch.__version__) < LooseVersion('1.7.0'), "torch.Tensor.scatter_ cannot take scalar inputs", ) - @unittest.skipIf( - LooseVersion(torch.__version__) >= LooseVersion('1.7.0') - and LooseVersion(torch.__version__) < LooseVersion('1.8.0'), - 
"torch.Tensor.scatter_ with scalar inputs fails", - ) @unittest.skipIf( LooseVersion(torch.__version__) >= LooseVersion('1.12.0') and LooseVersion(torch.__version__) < LooseVersion('1.13.0'), @@ -4373,11 +4544,6 @@ def model(x): LooseVersion(torch.__version__) < LooseVersion('1.7.0'), "torch.Tensor.scatter_ cannot take scalar inputs", ) - @unittest.skipIf( - LooseVersion(torch.__version__) >= LooseVersion('1.7.0') - and LooseVersion(torch.__version__) < LooseVersion('1.8.0'), - "torch.Tensor.scatter_ with scalar inputs fails", - ) @unittest.skipIf( LooseVersion(torch.__version__) >= LooseVersion('1.12.0') and LooseVersion(torch.__version__) < LooseVersion('1.13.0'), From 9aed76d716aa9efbe1b361e1fc059886c32ce29e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pastel=EF=BC=81?= <1627301104@qq.com> Date: Mon, 3 Jun 2024 11:12:05 +0800 Subject: [PATCH 4/7] Update aten.py --- tinynn/converter/operators/torch/aten.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py index 58da8a30..41caf654 100644 --- a/tinynn/converter/operators/torch/aten.py +++ b/tinynn/converter/operators/torch/aten.py @@ -554,8 +554,7 @@ def parse_common( ): assert is_train in (False, 0) self.unroll_rnn = True - self.separated_rnn_gate_calc = False - + expected_num_params = 2 * num_layers params_step = 2 if has_biases: From 7c3af032f76f303d5d837bddebfebbc996c11fcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pastel=EF=BC=81?= <1627301104@qq.com> Date: Mon, 3 Jun 2024 04:51:15 +0000 Subject: [PATCH 5/7] minor fix --- docs/FAQ.md | 2 +- docs/FAQ_zh-CN.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/FAQ.md b/docs/FAQ.md index c82693bf..40a71bd0 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -225,7 +225,7 @@ You may also try out static quantization for LSTMs when you have PyTorch 1.13+. #### What if my model runs slower when dynamic quantization is enabled? Please refer to [dynamic_with_selection.py](../examples/converter/dynamic_with_selection.py) for selective dynamic quantization. -#### I need LSTMs with separated gate calculation when `unroll_rnn=True`. +#### I need LSTM/GRUs with separated gate calculation when `unroll_rnn=True`. Please set `separated_rnn_gate_calc=True`. #### How to add state inputs and outputs for LSTMs/GRUs/RNNs with `unroll_rnn=True`? diff --git a/docs/FAQ_zh-CN.md b/docs/FAQ_zh-CN.md index 9d1462f7..2c8d5259 100644 --- a/docs/FAQ_zh-CN.md +++ b/docs/FAQ_zh-CN.md @@ -225,7 +225,7 @@ Note: 这些状态变量都是二维的,维度为`[batch_size, hidden_size或 #### 我的模型开了动态量化变得更慢了? 请参考 [dynamic_with_selection.py](../examples/converter/dynamic_with_selection.py) 选择性的开启动态量化。 -#### 在设置了`unroll_rnn=True`后,LSTM中多个门的计算被融合了。有没有办法分开? +#### 在设置了`unroll_rnn=True`后,LSTM/GRU中多个门的计算被融合了。有没有办法分开? 尝试设置`separated_rnn_gate_calc=True`。 #### 在`unroll_rnn=True`的情况下,怎么为包含LSTM、RNN和GRU的网络添加状态输入输出? 
From cb1a5fbcdd2808af7b3b032261c9ad7dcffb38d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pastel=EF=BC=81?= <1627301104@qq.com> Date: Mon, 3 Jun 2024 05:01:04 +0000 Subject: [PATCH 6/7] minor fix --- tests/converter_op_test.py | 18 +++++------ tinynn/converter/operators/torch/aten.py | 38 +++++++++++++----------- 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py index af374ca4..a27cde55 100644 --- a/tests/converter_op_test.py +++ b/tests/converter_op_test.py @@ -2951,7 +2951,7 @@ def forward(self, x): dummy_output = model(dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) - + def test_gru_unroll_unseparated(self): dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) @@ -2999,7 +2999,7 @@ def forward(self, x): dummy_output = model(dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output, check_stride=False) - + def test_gru_batch_first_unroll_unseparated(self): dummy_input = torch.randn(1, 9, 10, dtype=torch.float32) @@ -3051,7 +3051,7 @@ def forward(self, x, hx): dummy_output = model(*dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) - + def test_gru_with_state_tensor_unroll_unseparated(self): dummy_input = [ torch.randn(9, 1, 10, dtype=torch.float32), @@ -3103,7 +3103,7 @@ def forward(self, x): dummy_output = model(dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) - + def test_gru_multi_layer_unroll_unseparated(self): dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) @@ -3155,8 +3155,8 @@ def forward(self, x, hx): dummy_output = model(*dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) - - def test_gru_multi_layer_with_state_tensor_unroll_separated(self): + + def test_gru_multi_layer_with_state_tensor_unroll_unseparated(self): dummy_input = [ torch.randn(9, 1, 10, dtype=torch.float32), torch.randn(2, 1, 20, dtype=torch.float32), @@ -3405,7 +3405,7 @@ def forward(self, x): dummy_output = model(dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) - + def test_bigru_unroll_unseparated(self): dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) @@ -3453,8 +3453,8 @@ def forward(self, x): dummy_output = model(dummy_input) tfl_output = tfl_run_model(model_path, dummy_input, dummy_output) assert_close(dummy_output, tfl_output) - - def test_bigru_multi_layer_unroll_separated(self): + + def test_bigru_multi_layer_unroll_unseparated(self): dummy_input = torch.randn(9, 1, 10, dtype=torch.float32) class Model(nn.Module): diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py index 41caf654..d2277ee2 100644 --- a/tinynn/converter/operators/torch/aten.py +++ b/tinynn/converter/operators/torch/aten.py @@ -554,7 +554,7 @@ def parse_common( ): assert is_train in (False, 0) self.unroll_rnn = True - + expected_num_params = 2 * num_layers params_step = 2 if has_biases: @@ -684,9 +684,9 @@ def parse_common( for i, t in enumerate(input_ts[::stride]): input_mm_list = [] - + if not self.separated_rnn_gate_calc: - + wir, wiz, win = w_i_list whr, whz, whn = w_r_list bir, biz, bin = b_i_list @@ -702,36 +702,36 @@ def parse_common( ) hidden_mm = self.create_transform_tensor( np.matmul(h.tensor, 
np.transpose(w_h.tensor, [1, 0])) + b_h.tensor - ) - - ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm])) + ) + + ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm])) ops.append(tfl.FullyConnectedOperator([h, w_h, b_h], [hidden_mm])) left_in = np.split(input_mm.tensor, 3, axis=1) dim_tensor = self.create_attr_tensor(np.array([1], dtype='int32')) splited_left_in = [self.create_transform_tensor(t) for t in left_in] - + ops.append(tfl.SplitOperator([dim_tensor, input_mm], splited_left_in, 3)) right_in = np.split(hidden_mm.tensor, 3, axis=-1) splited_right_in = [self.create_transform_tensor(t) for t in right_in] - + ops.append(tfl.SplitOperator([dim_tensor, hidden_mm], splited_right_in, 3)) - - rgate_left_in, zgate_left_in, ngate_left_in = splited_left_in - rgate_right_in, zgate_right_in, ngate_right_in_b = splited_right_in + + rgate_left_in, zgate_left_in, ngate_left_in = splited_left_in + rgate_right_in, zgate_right_in, ngate_right_in_b = splited_right_in rgate_in = self.create_transform_tensor(rgate_left_in.tensor + rgate_right_in.tensor) ops.append(tfl.AddOperator([rgate_left_in, rgate_right_in], [rgate_in])) - + rgate_out = self.create_transform_tensor( torch.sigmoid(torch.from_numpy(rgate_in.tensor)).numpy() ) ops.append(tfl.LogisticOperator([rgate_in], [rgate_out])) - + zgate_in = self.create_transform_tensor(zgate_left_in.tensor + zgate_right_in.tensor) ops.append(tfl.AddOperator([zgate_left_in, zgate_right_in], [zgate_in])) - + zgate_out = self.create_transform_tensor( torch.sigmoid(torch.from_numpy(zgate_in.tensor)).numpy() ) @@ -743,7 +743,9 @@ def parse_common( ngate_in = self.create_transform_tensor(ngate_left_in.tensor + ngate_right_in.tensor) ops.append(tfl.AddOperator([ngate_left_in, ngate_right_in], [ngate_in])) - ngate_out = self.create_transform_tensor(torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy()) + ngate_out = self.create_transform_tensor( + torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy() + ) ops.append(tfl.TanhOperator([ngate_in], [ngate_out])) constant_tensor = self.create_attr_tensor(torch.tensor(1, dtype=torch.float32)) @@ -765,7 +767,7 @@ def parse_common( h = h_left stacked_hs.append(h) - + else: for j, (w_i, b_i) in enumerate(zip(w_i_list, b_i_list)): @@ -811,7 +813,9 @@ def parse_common( ngate_in = self.create_transform_tensor(input_mm_list[2].tensor + ngate_in_hside.tensor) ops.append(tfl.AddOperator([input_mm_list[2], ngate_in_hside], [ngate_in])) - ngate_out = self.create_transform_tensor(torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy()) + ngate_out = self.create_transform_tensor( + torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy() + ) ops.append(tfl.TanhOperator([ngate_in], [ngate_out])) constant_tensor = self.create_attr_tensor(torch.tensor(1, dtype=torch.float32)) From a151d61b92aceb00c211f56fea4ba9117fed3ee3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pastel=EF=BC=81?= <1627301104@qq.com> Date: Mon, 3 Jun 2024 14:26:41 +0800 Subject: [PATCH 7/7] Update converter_op_test.py --- tests/converter_op_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py index a27cde55..d8b59061 100644 --- a/tests/converter_op_test.py +++ b/tests/converter_op_test.py @@ -4519,6 +4519,11 @@ def model(x): LooseVersion(torch.__version__) < LooseVersion('1.7.0'), "torch.Tensor.scatter_ cannot take scalar inputs", ) + @unittest.skipIf( + LooseVersion(torch.__version__) >= LooseVersion('1.7.0') + and LooseVersion(torch.__version__) < LooseVersion('1.8.0'), + 
"torch.Tensor.scatter_ with scalar inputs fails", + ) @unittest.skipIf( LooseVersion(torch.__version__) >= LooseVersion('1.12.0') and LooseVersion(torch.__version__) < LooseVersion('1.13.0'), @@ -4544,6 +4549,11 @@ def model(x): LooseVersion(torch.__version__) < LooseVersion('1.7.0'), "torch.Tensor.scatter_ cannot take scalar inputs", ) + @unittest.skipIf( + LooseVersion(torch.__version__) >= LooseVersion('1.7.0') + and LooseVersion(torch.__version__) < LooseVersion('1.8.0'), + "torch.Tensor.scatter_ with scalar inputs fails", + ) @unittest.skipIf( LooseVersion(torch.__version__) >= LooseVersion('1.12.0') and LooseVersion(torch.__version__) < LooseVersion('1.13.0'),