Skip to content

Commit

Permalink
[converter] add fuse_simple_gather_pass (#351)
Browse files Browse the repository at this point in the history
* [converter] add fuse_simple_gather_pass

* more checks
  • Loading branch information
peterjc123 authored Aug 28, 2024
1 parent 6db6ba3 commit e9ea15f
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
26 changes: 26 additions & 0 deletions tests/converter_optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2594,6 +2594,32 @@ def forward(self, x):
self.assertEqual(tfl_model.Subgraphs(0).OperatorsLength(), 3)
self.assertEqual(tfl_model.Subgraphs(0).Operators(0).OutputsLength(), 1)

def test_consecutive_gather(self):
    # Two gathers that each reverse the last axis cancel each other out, so
    # after optimization the converted model is expected to hold only RELU.
    class TestModel(nn.Module):
        def forward(self, x):
            activated = x.relu()
            reversed_once = activated[..., [2, 1, 0]]
            reversed_twice = reversed_once[..., [2, 1, 0]]
            return reversed_twice

    model = TestModel()
    model.eval()

    dummy_input = torch.randn(1, 2, 3)
    model_path = get_model_path()

    TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False).convert()

    tfl_model = parse_model(model_path)

    # Operator table: exactly one entry, and it is RELU.
    self.assertEqual(tfl_model.OperatorCodesLength(), 1)
    self.assertEqual(tfl_model.OperatorCodes(0).DeprecatedBuiltinCode(), tflite.BuiltinOperator.RELU)

    # Graph layout: one subgraph, one input, one output, one single-output op.
    self.assertEqual(tfl_model.SubgraphsLength(), 1)
    subgraph = tfl_model.Subgraphs(0)
    self.assertEqual(subgraph.InputsLength(), 1)
    self.assertEqual(subgraph.OutputsLength(), 1)
    self.assertEqual(subgraph.OperatorsLength(), 1)
    self.assertEqual(subgraph.Operators(0).OutputsLength(), 1)


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
65 changes: 64 additions & 1 deletion tinynn/converter/operators/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,54 @@ def _remove_first_action(first_node, last_node, custom_data):

elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action)

@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_simple_gather_pass(self):
    """Fuse chains of consecutive GATHER ops into a single GATHER (or drop them).

    Candidate edges are picked by ``is_gather_fusable_edge`` and joined into
    sequences by ``fuse_connected_edges``. For each sequence the index tensors
    are composed with ``fuse_transpose_perms``; the composed indices are then
    written to input 1 of the first gather. When the composed indices are an
    exact identity over the gathered axis the whole sequence is a no-op and
    ``remove_first`` is reported as true.

    NOTE(review): what ``elinimate_sequences`` does with ``remove_first`` is
    not visible here; presumably it drops the first node as well - confirm.
    """
    edges = self.graph.graph.es.select(functools.partial(is_gather_fusable_edge, graph_converter=self.graph.graph))
    filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]

    # Try to fuse the edges
    filtered_pairs = fuse_connected_edges(filtered_pairs)

    def _remove_first_pred(seq):
        # Compose the index tensors of every gather in the sequence.
        new_perm = fuse_transpose_perms(seq)

        # Keep the 'direction' hint only if every hinted node agrees on it.
        hints = {
            node['op'].extra_hints['direction'] for node in seq if 'direction' in node['op'].extra_hints
        }
        hint = next(iter(hints)) if len(hints) == 1 else None

        # Unlike a transpose perm, gather indices need not be a permutation:
        # they may select a subset ([0, 1] out of 3) or repeat entries
        # ([0, 0, 1]). Such indices are sorted yet still change the tensor, so
        # "sorted" is not a valid no-op test. The chain is a no-op only when
        # the fused indices are exactly 0..dim-1 of the gathered axis.
        first_op = next(iter(seq))['op']
        input_dim = first_op.inputs[0].shape[first_op.axis]
        remove_first = np.array_equal(new_perm, np.arange(input_dim))
        return remove_first, (new_perm, hint)

    def _remove_first_action(first_node, last_node, custom_data):
        # Set fused indices to the first gather node
        new_perm, hint = custom_data
        if hint is None:
            # Conflicting or missing hints: make sure no stale hint survives.
            if 'direction' in first_node['op'].extra_hints:
                del first_node['op'].extra_hints['direction']
        else:
            first_node['op'].extra_hints['direction'] = hint
        new_perm_tensor = self.create_attr_tensor(new_perm)
        action = (self.graph.replace_operator_input, (first_node, 1, new_perm_tensor))
        return [action]

    def _skip_pred(seq):
        # Skip sequences whose indices are not constant (no buffer) or have
        # more than one dimension; those cannot be composed by plain indexing.
        for node in seq:
            op = node['op']
            idx_tensor = op.inputs[1]
            if idx_tensor.buffer is None:
                return True
            if len(idx_tensor.shape) > 1:
                return True
        return False

    elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action, skip_pred=_skip_pred)

@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_dequant_quant_pass(self, q_first):
edges = self.graph.graph.es.select(
Expand Down Expand Up @@ -3422,6 +3470,7 @@ def optimize(self):
self.fuse_simple_reshape_pass()
self.branch_transpose_expand_pass()
self.fuse_simple_transpose_pass()
self.fuse_simple_gather_pass()
for branch in (False, True):
self.remove_noop_pass(branch)
self.fuse_wrapped_reshape_within_transpose_pass()
Expand Down Expand Up @@ -4042,6 +4091,20 @@ def is_transpose_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
)


def is_gather_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    """Tell whether ``edge`` links two GATHER nodes that may be fused.

    The edge qualifies when the source is a GATHER with exactly one outgoing
    edge, the target is a GATHER with at least one outgoing edge, the
    source's first output is the target's data input (input 0), and both ops
    use the same ``axis`` and ``batchDims``.
    """
    src = graph_converter.vs[edge.source]
    dst = graph_converter.vs[edge.target]

    # Source side: a GATHER feeding exactly one consumer.
    if not (src['node_type'] == ExtendedOperator.GATHER and src.outdegree() == 1):
        return False

    # Target side: a GATHER whose result is consumed by something.
    if not (dst['node_type'] == ExtendedOperator.GATHER and dst.outdegree() >= 1):
        return False

    # The connection must go through the data input, not the index input.
    if not (src['outputs'][0] == dst['op'].inputs[0].name):
        return False

    src_op = src['op']
    dst_op = dst['op']
    return src_op.axis == dst_op.axis and src_op.batchDims == dst_op.batchDims


def is_reshape_branch_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
Expand Down Expand Up @@ -4342,7 +4405,7 @@ def fuse_slices(seq: typing.Iterable[ig.Vertex]):
def fuse_transpose_perms(seq: typing.Iterable[ig.Vertex]):
cur_perm = None
for node in seq:
assert node['node_type'] == ExtendedOperator.TRANSPOSE
assert node['node_type'] in (ExtendedOperator.TRANSPOSE, ExtendedOperator.GATHER)
next_perm = node['op'].inputs[1].tensor
if cur_perm is None:
cur_perm = next_perm
Expand Down

0 comments on commit e9ea15f

Please sign in to comment.