diff --git a/tinynn/converter/operators/optimize.py b/tinynn/converter/operators/optimize.py
index 5ac963e1..a8283e21 100644
--- a/tinynn/converter/operators/optimize.py
+++ b/tinynn/converter/operators/optimize.py
@@ -167,6 +167,68 @@ def fuse_conv_fc_bn(self):
             assert vertex['node_type'] == ExtendedOperator.BATCH_NORM
         self.graph.graph.delete_vertices(remove_ids)
 
+    @class_conditional(lambda self: self.level >= GraphOptimizer.FUSE_BN)
+    def fuse_bn_conv(self):
+        edges = self.graph.graph.es.select(functools.partial(is_rev_bn_fusable_edge, graph_converter=self.graph.graph))
+        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]) for x in edges)
+
+        def _remove_last_pred(seq):
+            bn = seq[0]
+            conv = seq[1]
+
+            # Collect the arguments of the conv and batch-norm nodes
+            weight = conv['op'].inputs[1]
+            bias = conv['op'].inputs[2] if len(conv['op'].inputs) > 2 else None
+            bn_w, bn_b, bn_mean, bn_var = bn['op'].inputs[1:]
+            bn_w, bn_b, bn_mean, bn_var = (
+                bn_w.tensor.copy(),
+                bn_b.tensor.copy(),
+                bn_mean.tensor.copy(),
+                bn_var.tensor.copy(),
+            )
+            activ_w = weight.tensor.copy()
+            activ_b = bias.tensor.copy() if bias is not None else None
+            eps = bn['op'].eps
+
+            new_weight = fuse_rev_bn_weight(eps, bn_w, bn_var, activ_w)
+            new_bias = fuse_rev_bn_bias(eps, bn_w, bn_var, bn_mean, bn_b, activ_b, activ_w)
+
+            return False, (conv, bias, new_weight, new_bias)
+
+        def _remove_last_action(first_node, last_node, custom_data):
+            conv, bias, new_weight, new_bias = custom_data
+
+            new_w = self.create_attr_tensor(new_weight)
+            new_b = self.create_attr_tensor(new_bias)
+
+            actions = []
+            actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
+            if bias is not None:
+                actions.append((self.graph.replace_operator_input, (conv, 2, new_b)))
+            else:
+                actions.append((self.graph.append_operator_input, (conv, new_b)))
+            return actions
+
+        def _skip_pred(seq):
+            bn = seq[0]['op']
+            conv = seq[1]['op']
+
+            skip = bn.inputs[0].quantization is not None or (
+                conv.inputs[1].shape[1] == 1 and conv.inputs[1].shape[0] == conv.groups and conv.groups > 1
+            )
+            return skip
+
+        elinimate_sequences(
+            self.graph,
+            filtered_pairs,
+            True,
+            None,
+            _remove_last_pred,
+            _remove_last_action,
+            _skip_pred,
+            force_forward_input=True,
+        )
+
     @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
     def fuse_activation(self):
         # Find fusable ops
@@ -3325,6 +3387,7 @@ def optimize(self):
         self.fuse_conv_fc_bn()
         self.fuse_activation()
         self.fuse_requantize()
+        self.fuse_bn_conv()
 
         # Convert TinyNeuralNetwork ops to TFLite ops
         self.transform_graph()
@@ -3433,6 +3496,20 @@ def is_bn_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
     )
 
 
+def is_rev_bn_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
+    source_vertex = graph_converter.vs[edge.source]
+    target_vertex = graph_converter.vs[edge.target]
+    return (
+        target_vertex['node_type'] == ExtendedOperator.GENERIC_CONV
+        and source_vertex['node_type'] == ExtendedOperator.BATCH_NORM
+        and source_vertex.outdegree() == 1
+        and source_vertex['op'].inputs[1].buffer is not None
+        and source_vertex['op'].inputs[2].buffer is not None
+        and target_vertex['op'].inputs[1].buffer is not None
+        and source_vertex['op'].fusedActivationFunction == ActivationFunctionType.NONE
+    )
+
+
 def is_padding_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
     source_vertex = graph_converter.vs[edge.source]
     target_vertex = graph_converter.vs[edge.target]
@@ -4123,6 +4200,33 @@ def fuse_bn_bias(eps, scale, var, mean, bn_b, activ_b):
         return (-mean) * inv * scale + bn_b
 
 
+def fuse_rev_bn_weight(eps, scale, var, weight):
+    shape = [1, -1] + [1] * (len(weight.shape) - 2)
+
+    inv = 1 / np.sqrt(var + eps)
+
+    return weight * (scale * inv).reshape(shape)
+
+
+def fuse_rev_bn_bias(eps, scale, var, mean, bn_b, activ_b, weight):
+    reduced_dims = tuple([i for i in range(len(weight.shape)) if i > 1])
+
+    inv = 1 / np.sqrt(var + eps)
+    fused_b = bn_b - mean * inv * scale
+
+    if weight.shape[1] == 1 and mean.shape[0] > 1:
+        offset_b = (weight.sum(reduced_dims) * fused_b.reshape(-1, 1)).reshape(-1)
+    else:
+        offset_b = np.matmul(weight.sum(reduced_dims), fused_b.reshape(-1, 1)).reshape(-1)
+
+    if activ_b is not None:
+        if activ_b.shape != mean.shape and activ_b.ndim == 1 and activ_b.size == 1:
+            activ_b = activ_b.repeat(mean.size)
+        return activ_b + offset_b
+    else:
+        return offset_b
+
+
 def fuse_slices(seq: typing.Iterable[ig.Vertex]):
     cur_start = None
     cur_end = None
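
The folding math used by fuse_rev_bn_weight and fuse_rev_bn_bias can be checked in isolation: an inference-mode BatchNorm placed before a convolution scales each input channel by gamma / sqrt(var + eps) and shifts it by beta - mean * gamma / sqrt(var + eps), so the scale folds into the conv weight along its input-channel axis and the shift turns into a per-output-channel bias offset through the spatially-summed weight. The snippet below is a standalone NumPy sanity check of that identity; it is not part of the patch, and the naive conv2d helper is only an illustrative stand-in for the real conv op.

# Standalone check (not part of the patch): folding a preceding BatchNorm into a
# conv via the fuse_rev_bn_weight / fuse_rev_bn_bias formulas reproduces conv(bn(x)).
import numpy as np

def conv2d(x, w, b):
    # x: (C_in, H, W), w: (C_out, C_in, KH, KW), b: (C_out,); stride 1, no padding
    patches = np.lib.stride_tricks.sliding_window_view(x, w.shape[2:], axis=(1, 2))
    return np.einsum('oikl,ihwkl->ohw', w, patches) + b[:, None, None]

rng = np.random.default_rng(0)
c_in, c_out, eps = 3, 4, 1e-5
x = rng.standard_normal((c_in, 8, 8))
w = rng.standard_normal((c_out, c_in, 3, 3))
b = rng.standard_normal(c_out)
gamma, beta = rng.standard_normal(c_in), rng.standard_normal(c_in)
mean, var = rng.standard_normal(c_in), rng.random(c_in) + 0.5

# Reference: inference-mode BatchNorm followed by the conv
inv = 1 / np.sqrt(var + eps)
bn_x = (x - mean[:, None, None]) * (gamma * inv)[:, None, None] + beta[:, None, None]
ref = conv2d(bn_x, w, b)

# Folded: same math as fuse_rev_bn_weight / fuse_rev_bn_bias above
new_w = w * (gamma * inv).reshape(1, -1, 1, 1)  # scale along the input-channel axis
fused_b = beta - mean * inv * gamma             # per-input-channel shift
new_b = b + np.matmul(w.sum((2, 3)), fused_b.reshape(-1, 1)).reshape(-1)

assert np.allclose(ref, conv2d(x, new_w, new_b))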