Commit cb1a5fb

minor fix
Juelianqvq committed Jun 3, 2024
1 parent 7c3af03 commit cb1a5fb
Showing 2 changed files with 30 additions and 26 deletions.
18 changes: 9 additions & 9 deletions tests/converter_op_test.py
@@ -2951,7 +2951,7 @@ def forward(self, x):
dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_unroll_unseparated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

@@ -2999,7 +2999,7 @@ def forward(self, x):
dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output, check_stride=False)

def test_gru_batch_first_unroll_unseparated(self):
dummy_input = torch.randn(1, 9, 10, dtype=torch.float32)

@@ -3051,7 +3051,7 @@ def forward(self, x, hx):
dummy_output = model(*dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_with_state_tensor_unroll_unseparated(self):
dummy_input = [
torch.randn(9, 1, 10, dtype=torch.float32),
@@ -3103,7 +3103,7 @@ def forward(self, x):
dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_gru_multi_layer_unroll_unseparated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

@@ -3155,8 +3155,8 @@ def forward(self, x, hx):
dummy_output = model(*dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)
-def test_gru_multi_layer_with_state_tensor_unroll_separated(self):
+
+def test_gru_multi_layer_with_state_tensor_unroll_unseparated(self):
dummy_input = [
torch.randn(9, 1, 10, dtype=torch.float32),
torch.randn(2, 1, 20, dtype=torch.float32),
@@ -3405,7 +3405,7 @@ def forward(self, x):
dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_bigru_unroll_unseparated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

@@ -3453,8 +3453,8 @@ def forward(self, x):
dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)
-def test_bigru_multi_layer_unroll_separated(self):
+
+def test_bigru_multi_layer_unroll_unseparated(self):
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)

class Model(nn.Module):
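Note on the renamed tests: both now target the unseparated gate-calculation path. For orientation, here is a condensed sketch of the pattern each GRU test in this file follows, reconstructed only from lines visible in this diff — tfl_run_model and assert_close are the file's own helpers, while the conversion step that produces model_path lives in the folded context and is elided:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # two stacked GRU layers, matching the "multi_layer" test names
        self.gru = nn.GRU(10, 20, num_layers=2)

    def forward(self, x):
        return self.gru(x)[0]

# seq_len=9, batch=1, input_size=10 -- the shape used throughout these tests
dummy_input = torch.randn(9, 1, 10, dtype=torch.float32)
model = Model()
model.eval()

dummy_output = model(dummy_input)
# model_path is produced by the (folded) conversion step
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)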
38 changes: 21 additions & 17 deletions tinynn/converter/operators/torch/aten.py
@@ -554,7 +554,7 @@ def parse_common(
):
assert is_train in (False, 0)
self.unroll_rnn = True

expected_num_params = 2 * num_layers
params_step = 2
if has_biases:
@@ -684,9 +684,9 @@ def parse_common(
for i, t in enumerate(input_ts[::stride]):

input_mm_list = []

if not self.separated_rnn_gate_calc:

wir, wiz, win = w_i_list
whr, whz, whn = w_r_list
bir, biz, bin = b_i_list
@@ -702,36 +702,36 @@ def parse_common(
)
hidden_mm = self.create_transform_tensor(
np.matmul(h.tensor, np.transpose(w_h.tensor, [1, 0])) + b_h.tensor
-)
-ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm]))
+)
+
+ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm]))
ops.append(tfl.FullyConnectedOperator([h, w_h, b_h], [hidden_mm]))

left_in = np.split(input_mm.tensor, 3, axis=1)
dim_tensor = self.create_attr_tensor(np.array([1], dtype='int32'))
splited_left_in = [self.create_transform_tensor(t) for t in left_in]

ops.append(tfl.SplitOperator([dim_tensor, input_mm], splited_left_in, 3))

right_in = np.split(hidden_mm.tensor, 3, axis=-1)
splited_right_in = [self.create_transform_tensor(t) for t in right_in]

ops.append(tfl.SplitOperator([dim_tensor, hidden_mm], splited_right_in, 3))
-rgate_left_in, zgate_left_in, ngate_left_in = splited_left_in
-rgate_right_in, zgate_right_in, ngate_right_in_b = splited_right_in
+
+rgate_left_in, zgate_left_in, ngate_left_in = splited_left_in
+rgate_right_in, zgate_right_in, ngate_right_in_b = splited_right_in

rgate_in = self.create_transform_tensor(rgate_left_in.tensor + rgate_right_in.tensor)
ops.append(tfl.AddOperator([rgate_left_in, rgate_right_in], [rgate_in]))

rgate_out = self.create_transform_tensor(
torch.sigmoid(torch.from_numpy(rgate_in.tensor)).numpy()
)
ops.append(tfl.LogisticOperator([rgate_in], [rgate_out]))

zgate_in = self.create_transform_tensor(zgate_left_in.tensor + zgate_right_in.tensor)
ops.append(tfl.AddOperator([zgate_left_in, zgate_right_in], [zgate_in]))

zgate_out = self.create_transform_tensor(
torch.sigmoid(torch.from_numpy(zgate_in.tensor)).numpy()
)
@@ -743,7 +743,9 @@ def parse_common(
ngate_in = self.create_transform_tensor(ngate_left_in.tensor + ngate_right_in.tensor)
ops.append(tfl.AddOperator([ngate_left_in, ngate_right_in], [ngate_in]))

-ngate_out = self.create_transform_tensor(torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy())
+ngate_out = self.create_transform_tensor(
+    torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy()
+)
ops.append(tfl.TanhOperator([ngate_in], [ngate_out]))

constant_tensor = self.create_attr_tensor(torch.tensor(1, dtype=torch.float32))
@@ -765,7 +767,7 @@ def parse_common(
h = h_left

stacked_hs.append(h)

else:
for j, (w_i, b_i) in enumerate(zip(w_i_list, b_i_list)):

@@ -811,7 +813,9 @@ def parse_common(
ngate_in = self.create_transform_tensor(input_mm_list[2].tensor + ngate_in_hside.tensor)
ops.append(tfl.AddOperator([input_mm_list[2], ngate_in_hside], [ngate_in]))

-ngate_out = self.create_transform_tensor(torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy())
+ngate_out = self.create_transform_tensor(
+    torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy()
+)
ops.append(tfl.TanhOperator([ngate_in], [ngate_out]))

constant_tensor = self.create_attr_tensor(torch.tensor(1, dtype=torch.float32))
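For orientation: the not self.separated_rnn_gate_calc branch touched above lowers each unrolled GRU step with two fused FullyConnected ops (one over the input, one over the hidden state) followed by two three-way Splits, rather than one FullyConnected per gate as in the separated branch below it. A minimal NumPy sketch of the recurrence being lowered, assuming torch.nn.GRU's stacked [r; z; n] weight layout; the r-gating of the hidden n-branch sits in folded lines here but follows the standard GRU definition:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step_unseparated(x, h, w_i, b_i, w_h, b_h):
    # w_i: (3*hidden, input), w_h: (3*hidden, hidden), rows stacked as [r; z; n]
    input_mm = x @ w_i.T + b_i   # FullyConnectedOperator([t, w_i, b_i], [input_mm])
    hidden_mm = h @ w_h.T + b_h  # FullyConnectedOperator([h, w_h, b_h], [hidden_mm])

    # one SplitOperator per fused result, three gate slices each
    xr, xz, xn = np.split(input_mm, 3, axis=1)
    hr, hz, hn = np.split(hidden_mm, 3, axis=1)

    r = sigmoid(xr + hr)          # AddOperator + LogisticOperator
    z = sigmoid(xz + hz)          # AddOperator + LogisticOperator
    n = np.tanh(xn + r * hn)      # r * hn is in the folded part of this hunk
    return (1.0 - z) * n + z * h  # the constant_tensor(1) arithmetic closing each step

Per timestep this trades the six gate-wise matmuls of the separated path for two fused matmuls plus two splits, which appears to be the motivation for the unseparated mode.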
