Skip to content

Commit

Permalink
[converter] refine reshape transpose pass (#255)
Browse files Browse the repository at this point in the history
* [converter] refine reshape transpose pass

* Cleanup
  • Loading branch information
peterjc123 authored Sep 21, 2023
1 parent 1e093bd commit f19e3db
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions tinynn/converter/operators/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,8 +1115,11 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int:
actions = []
remove_edges = []
remove_vertices = []
processed_nodes = set()
num_actions = 0
for node in unique_nodes:
pending_processed_nodes = set()

op = node['op']
input_indices = op_input_indices(op)
l_shape = op.inputs[0].shape
Expand Down Expand Up @@ -1195,13 +1198,18 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int:
prev_output_indices = []
num_constant_nodes = 0
prev_hints = set()
skip = False
for i in input_indices:
prev_node_name = op.inputs[i].name
prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
prev_nodes.append(prev_node)
prev_output_indices.append(prev_node['outputs'].index(prev_node_name))

if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
if prev_node['name'] in processed_nodes:
skip = True
break
pending_processed_nodes.add(prev_node['name'])
if mode == 'down':
perm = tuple(prev_node['op'].inputs[1].tensor.tolist())
cand_perms.setdefault(perm, 0)
Expand All @@ -1216,7 +1224,7 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int:
if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
num_constant_nodes += 1

if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints:
if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints):
continue

next_nodes = []
Expand All @@ -1231,6 +1239,10 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int:
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
out_nodes.append(next_node)
else:
if next_node['name'] in processed_nodes:
skip = True
break
pending_processed_nodes.add(next_node['name'])
next_nodes.append(next_node)
next_edges.append(edge)

Expand All @@ -1246,7 +1258,7 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int:
if 'direction' in next_node['op'].extra_hints:
next_hints.add(next_node['op'].extra_hints['direction'])

if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints:
if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints):
continue

cur_transpose_size = sum(cand_perms.values()) + sum(cand_rev_perms.values())
Expand All @@ -1261,6 +1273,9 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int:
if 'down' in prev_hints or 'up' in next_hints:
skip = False

if skip:
continue

perm = max(cand_perms.items(), key=lambda x: x[1])[0]
perm_arr = np.array(perm, dtype='int32')

Expand All @@ -1285,6 +1300,9 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int:
remove_edges.extend([x.index for x in next_edges])
remove_vertices.extend([x.index for x in out_nodes])

for pending_processed_node in pending_processed_nodes:
processed_nodes.add(pending_processed_node)

for n in out_nodes:
del self.graph.tensor_map[n['outputs'][0]]
del self.graph.tensor_node_map[n['outputs'][0]]
Expand Down

0 comments on commit f19e3db

Please sign in to comment.