diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py
index af374ca4..a27cde55 100644
--- a/tests/converter_op_test.py
+++ b/tests/converter_op_test.py
@@ -2951,7 +2951,7 @@ def forward(self, x):
         dummy_output = model(dummy_input)
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output)
-        
+
     def test_gru_unroll_unseparated(self):
         dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)
 
@@ -2999,7 +2999,7 @@ def forward(self, x):
         dummy_output = model(dummy_input)
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output, check_stride=False)
-        
+
     def test_gru_batch_first_unroll_unseparated(self):
         dummy_input = torch.randn(1, 9, 10, dtype=torch.float32)
 
@@ -3051,7 +3051,7 @@ def forward(self, x, hx):
         dummy_output = model(*dummy_input)
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output)
-        
+
     def test_gru_with_state_tensor_unroll_unseparated(self):
         dummy_input = [
             torch.randn(9, 1, 10, dtype=torch.float32),
@@ -3103,7 +3103,7 @@ def forward(self, x):
         dummy_output = model(dummy_input)
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output)
-        
+
     def test_gru_multi_layer_unroll_unseparated(self):
         dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)
 
@@ -3155,8 +3155,8 @@ def forward(self, x, hx):
         dummy_output = model(*dummy_input)
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output)
-        
-    def test_gru_multi_layer_with_state_tensor_unroll_separated(self):
+
+    def test_gru_multi_layer_with_state_tensor_unroll_unseparated(self):
         dummy_input = [
             torch.randn(9, 1, 10, dtype=torch.float32),
             torch.randn(2, 1, 20, dtype=torch.float32),
@@ -3405,7 +3405,7 @@ def forward(self, x):
         dummy_output = model(dummy_input)
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output)
-        
+
     def test_bigru_unroll_unseparated(self):
         dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)
 
@@ -3453,8 +3453,8 @@ def forward(self, x):
         dummy_output = model(dummy_input)
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output)
-        
-    def test_bigru_multi_layer_unroll_separated(self):
+
+    def test_bigru_multi_layer_unroll_unseparated(self):
         dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)
 
         class Model(nn.Module):
diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py
index 41caf654..d2277ee2 100644
--- a/tinynn/converter/operators/torch/aten.py
+++ b/tinynn/converter/operators/torch/aten.py
@@ -554,7 +554,7 @@ def parse_common(
     ):
         assert is_train in (False, 0)
         self.unroll_rnn = True
-        
+
         expected_num_params = 2 * num_layers
         params_step = 2
         if has_biases:
@@ -684,9 +684,9 @@ def parse_common(
 
             for i, t in enumerate(input_ts[::stride]):
                 input_mm_list = []
-                
+
                 if not self.separated_rnn_gate_calc:
-                    
+
                     wir, wiz, win = w_i_list
                     whr, whz, whn = w_r_list
                     bir, biz, bin = b_i_list
@@ -702,36 +702,36 @@ def parse_common(
                     )
                     hidden_mm = self.create_transform_tensor(
                         np.matmul(h.tensor, np.transpose(w_h.tensor, [1, 0])) + b_h.tensor
-                    )                    
-                    
-                    ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm]))                    
+                    )
+
+                    ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm]))
                     ops.append(tfl.FullyConnectedOperator([h, w_h, b_h], [hidden_mm]))
 
                     left_in = np.split(input_mm.tensor, 3, axis=1)
                     dim_tensor = self.create_attr_tensor(np.array([1], dtype='int32'))
                     splited_left_in = [self.create_transform_tensor(t) for t in left_in]
-                    
+
                     ops.append(tfl.SplitOperator([dim_tensor, input_mm], splited_left_in, 3))
 
                     right_in = np.split(hidden_mm.tensor, 3, axis=-1)
                     splited_right_in = [self.create_transform_tensor(t) for t in right_in]
-                    
+
                     ops.append(tfl.SplitOperator([dim_tensor, hidden_mm], splited_right_in, 3))
-                    
-                    rgate_left_in, zgate_left_in, ngate_left_in = splited_left_in                    
-                    rgate_right_in, zgate_right_in, ngate_right_in_b = splited_right_in                    
+
+                    rgate_left_in, zgate_left_in, ngate_left_in = splited_left_in
+                    rgate_right_in, zgate_right_in, ngate_right_in_b = splited_right_in
 
                     rgate_in = self.create_transform_tensor(rgate_left_in.tensor + rgate_right_in.tensor)
                     ops.append(tfl.AddOperator([rgate_left_in, rgate_right_in], [rgate_in]))
-                    
+
                     rgate_out = self.create_transform_tensor(
                         torch.sigmoid(torch.from_numpy(rgate_in.tensor)).numpy()
                     )
                     ops.append(tfl.LogisticOperator([rgate_in], [rgate_out]))
-                    
+
                     zgate_in = self.create_transform_tensor(zgate_left_in.tensor + zgate_right_in.tensor)
                     ops.append(tfl.AddOperator([zgate_left_in, zgate_right_in], [zgate_in]))
-                    
+
                     zgate_out = self.create_transform_tensor(
                         torch.sigmoid(torch.from_numpy(zgate_in.tensor)).numpy()
                     )
@@ -743,7 +743,9 @@ def parse_common(
                     ngate_in = self.create_transform_tensor(ngate_left_in.tensor + ngate_right_in.tensor)
                     ops.append(tfl.AddOperator([ngate_left_in, ngate_right_in], [ngate_in]))
 
-                    ngate_out = self.create_transform_tensor(torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy())
+                    ngate_out = self.create_transform_tensor(
+                        torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy()
+                    )
                     ops.append(tfl.TanhOperator([ngate_in], [ngate_out]))
 
                     constant_tensor = self.create_attr_tensor(torch.tensor(1, dtype=torch.float32))
@@ -765,7 +767,7 @@ def parse_common(
                     h = h_left
 
                     stacked_hs.append(h)
-                
+
                 else:
 
                     for j, (w_i, b_i) in enumerate(zip(w_i_list, b_i_list)):
@@ -811,7 +813,9 @@ def parse_common(
                     ngate_in = self.create_transform_tensor(input_mm_list[2].tensor + ngate_in_hside.tensor)
                     ops.append(tfl.AddOperator([input_mm_list[2], ngate_in_hside], [ngate_in]))
 
-                    ngate_out = self.create_transform_tensor(torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy())
+                    ngate_out = self.create_transform_tensor(
+                        torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy()
+                    )
                     ops.append(tfl.TanhOperator([ngate_in], [ngate_out]))
 
                     constant_tensor = self.create_attr_tensor(torch.tensor(1, dtype=torch.float32))
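
Note for reviewers: the unseparated path touched above lowers one GRU step to a single fused FULLY_CONNECTED per side, a three-way SPLIT into the r/z/n chunks, and elementwise gate ops, following the standard torch.nn.GRU equations. Below is a minimal NumPy sketch of that math for reference only; the function name and shapes are illustrative and not part of the converter's API.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step_unseparated(x, h, w_i, w_h, b_i, b_h):
    """One GRU step, unseparated gate calculation.

    x: (batch, input_size), h: (batch, hidden_size)
    w_i: (3 * hidden_size, input_size), w_h: (3 * hidden_size, hidden_size)
    b_i, b_h: (3 * hidden_size,)
    """
    input_mm = x @ w_i.T + b_i    # fused input-side projection -> one FULLY_CONNECTED
    hidden_mm = h @ w_h.T + b_h   # fused hidden-side projection -> one FULLY_CONNECTED
    ir, iz, in_ = np.split(input_mm, 3, axis=1)   # SPLIT into r/z/n chunks
    hr, hz, hn = np.split(hidden_mm, 3, axis=1)   # SPLIT into r/z/n chunks
    r = sigmoid(ir + hr)          # ADD + LOGISTIC (reset gate)
    z = sigmoid(iz + hz)          # ADD + LOGISTIC (update gate)
    n = np.tanh(in_ + r * hn)     # MUL + ADD + TANH (new gate)
    return (1.0 - z) * n + z * h  # SUB(1, z), MUL, MUL, ADD
```

Each line maps onto the TFLite ops the converter emits per timestep, which is what the renamed `*_unroll_unseparated` tests exercise end to end.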