Skip to content

Commit

Permalink
[converter] add fuse_gather_conv2d_pass
Browse files Browse the repository at this point in the history
  • Loading branch information
peterjc123 committed Aug 28, 2024
1 parent e9ea15f commit 2c633d1
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/converter_optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2516,6 +2516,36 @@ def forward(self, x):
self.assertEqual(tfl_model.Subgraphs(0).InputsLength(), 1)
self.assertEqual(tfl_model.Subgraphs(0).OutputsLength(), 1)

def test_gather_conv2d(self):
class TestModel(nn.Module):
def __init__(self, with_bias=False):
super(TestModel, self).__init__()
self.block = nn.Sequential(
nn.PixelUnshuffle(2),
nn.Conv2d(8, 4, 3, 1, 1, bias=with_bias),
)

def forward(self, x):
return self.block(x)

model = TestModel()
model.eval()

dummy_input = torch.randn(1, 2, 512, 512)
model_path = get_model_path()

nchw_transpose = False
converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=nchw_transpose)
converter.convert()

tfl_model = parse_model(model_path)
self.assertEqual(tfl_model.OperatorCodesLength(), 4)
self.assertEqual(tfl_model.OperatorCodes(1).DeprecatedBuiltinCode(), tflite.BuiltinOperator.SPACE_TO_DEPTH)
self.assertEqual(tfl_model.OperatorCodes(2).DeprecatedBuiltinCode(), tflite.BuiltinOperator.CONV_2D)
self.assertEqual(tfl_model.SubgraphsLength(), 1)
self.assertEqual(tfl_model.Subgraphs(0).InputsLength(), 1)
self.assertEqual(tfl_model.Subgraphs(0).OutputsLength(), 1)

def test_lower_transpose_dim_pass(self):
class TestModel(nn.Module):
def forward(self, x):
Expand Down
56 changes: 56 additions & 0 deletions tinynn/converter/operators/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,46 @@ def fuse_conv2d_gather(self):
# Delete activation nodes
self.graph.graph.delete_vertices(remove_ids)

@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_gather_conv2d(self):
# Find fusable ops
edges = self.graph.graph.es.select(functools.partial(is_gather_conv2d_edge, graph_converter=self.graph.graph))
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]) for x in edges)

def _remove_last_pred(seq):
gather = seq[0]
conv = seq[1]
return False, (gather, conv)

def _remove_last_action(first_node, last_node, custom_data):
gather, conv = custom_data

actions = []

indx = np.argsort(gather['op'].inputs[1].tensor)
w = conv['op'].inputs[1].tensor.copy()
w_quant_param = conv['op'].inputs[1].quantization
new_w = np.take(w, indx, axis=3)
if w_quant_param is not None and isinstance(w_quant_param.scale, list) and w_quant_param.dim == 3:
new_w_scale = np.take(w_quant_param.scale, indx, axis=0)
new_w_zeros = np.take(w_quant_param.zero_point, indx, axis=0)
w_quant_param.scale = new_w_scale
w_quant_param.zero_point = new_w_zeros
new_w = self.create_attr_tensor(new_w, quantization=w_quant_param)
actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
return actions

elinimate_sequences(
self.graph,
filtered_pairs,
True,
None,
_remove_last_pred,
_remove_last_action,
False,
force_forward_input=True,
)

@class_conditional(lambda self: self.tflite_micro_rewrite)
def split_requantize(self):
vertices = self.graph.graph.vs.select(functools.partial(is_requantize_node, graph_converter=self.graph.graph))
Expand Down Expand Up @@ -3576,6 +3616,8 @@ def optimize(self):
self.fuse_same_padding()
self.fuse_same_padding_slicing()

self.fuse_gather_conv2d()

# Group conv & deconv
self.group_conv_rewrite_pass()
self.group_deconv_rewrite_pass()
Expand Down Expand Up @@ -4228,6 +4270,20 @@ def is_conv2d_gather_edge(edge: ig.Edge, graph_converter: ig.Graph):
)


def is_gather_conv2d_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]

return (
source_vertex['node_type'] == ExtendedOperator.GATHER
and target_vertex['node_type'] == ExtendedOperator.CONV_2D
and source_vertex.outdegree() == 1
and source_vertex['op'].inputs[1].buffer is not None
and source_vertex['op'].axis == 3
and source_vertex['op'].inputs[1].tensor.shape[0] == target_vertex['op'].inputs[1].tensor.shape[3]
)


def is_reciprocal_sqrt_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
Expand Down

0 comments on commit 2c633d1

Please sign in to comment.