Commit

[converter] fuse bn & conv (#319)
* [converter] fuse bn & conv

* fix lint
peterjc123 authored May 29, 2024
1 parent dfb8c5f commit 8c1f2ce
Showing 1 changed file with 104 additions and 0 deletions.
104 changes: 104 additions & 0 deletions tinynn/converter/operators/optimize.py
@@ -167,6 +167,68 @@ def fuse_conv_fc_bn(self):
            assert vertex['node_type'] == ExtendedOperator.BATCH_NORM
        self.graph.graph.delete_vertices(remove_ids)

    @class_conditional(lambda self: self.level >= GraphOptimizer.FUSE_BN)
    def fuse_bn_conv(self):
        edges = self.graph.graph.es.select(functools.partial(is_rev_bn_fusable_edge, graph_converter=self.graph.graph))
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]) for x in edges)

        def _remove_last_pred(seq):
            bn = seq[0]
            conv = seq[1]

            # Collect the arguments of the conv and batch-norm nodes
            weight = conv['op'].inputs[1]
            bias = conv['op'].inputs[2] if len(conv['op'].inputs) > 2 else None
            bn_w, bn_b, bn_mean, bn_var = bn['op'].inputs[1:]
            bn_w, bn_b, bn_mean, bn_var = (
                bn_w.tensor.copy(),
                bn_b.tensor.copy(),
                bn_mean.tensor.copy(),
                bn_var.tensor.copy(),
            )
            activ_w = weight.tensor.copy()
            activ_b = bias.tensor.copy() if bias is not None else None
            eps = bn['op'].eps

            new_weight = fuse_rev_bn_weight(eps, bn_w, bn_var, activ_w)
            new_bias = fuse_rev_bn_bias(eps, bn_w, bn_var, bn_mean, bn_b, activ_b, activ_w)

            return False, (conv, bias, new_weight, new_bias)

        def _remove_last_action(first_node, last_node, custom_data):
            conv, bias, new_weight, new_bias = custom_data

            new_w = self.create_attr_tensor(new_weight)
            new_b = self.create_attr_tensor(new_bias)

            actions = []
            actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
            if bias is not None:
                actions.append((self.graph.replace_operator_input, (conv, 2, new_b)))
            else:
                actions.append((self.graph.append_operator_input, (conv, new_b)))
            return actions

        def _skip_pred(seq):
            bn = seq[0]['op']
            conv = seq[1]['op']

            skip = bn.inputs[0].quantization is not None or (
                conv.inputs[1].shape[1] == 1 and conv.inputs[1].shape[0] == conv.groups and conv.groups > 1
            )
            return skip

        elinimate_sequences(
            self.graph,
            filtered_pairs,
            True,
            None,
            _remove_last_pred,
            _remove_last_action,
            _skip_pred,
            force_forward_input=True,
        )

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_activation(self):
        # Find fusable ops
@@ -3325,6 +3387,7 @@ def optimize(self):
        self.fuse_conv_fc_bn()
        self.fuse_activation()
        self.fuse_requantize()
        self.fuse_bn_conv()

        # Convert TinyNeuralNetwork ops to TFLite ops
        self.transform_graph()
@@ -3433,6 +3496,20 @@ def is_bn_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    )


def is_rev_bn_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        target_vertex['node_type'] == ExtendedOperator.GENERIC_CONV
        and source_vertex['node_type'] == ExtendedOperator.BATCH_NORM
        and source_vertex.outdegree() == 1
        and source_vertex['op'].inputs[1].buffer is not None
        and source_vertex['op'].inputs[2].buffer is not None
        and target_vertex['op'].inputs[1].buffer is not None
        and source_vertex['op'].fusedActivationFunction == ActivationFunctionType.NONE
    )


def is_padding_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
@@ -4123,6 +4200,33 @@ def fuse_bn_bias(eps, scale, var, mean, bn_b, activ_b):
    return (-mean) * inv * scale + bn_b


def fuse_rev_bn_weight(eps, scale, var, weight):
    shape = [1, -1] + [1] * (len(weight.shape) - 2)

    inv = 1 / np.sqrt(var + eps)

    return weight * (scale * inv).reshape(shape)


def fuse_rev_bn_bias(eps, scale, var, mean, bn_b, activ_b, weight):
    reduced_dims = tuple([i for i in range(len(weight.shape)) if i > 1])

    inv = 1 / np.sqrt(var + eps)
    fused_b = bn_b - mean * inv * scale

    if weight.shape[1] == 1 and mean.shape[0] > 1:
        offset_b = (weight.sum(reduced_dims) * fused_b.reshape(-1, 1)).reshape(-1)
    else:
        offset_b = np.matmul(weight.sum(reduced_dims), fused_b.reshape(-1, 1)).reshape(-1)

    if activ_b is not None:
        if activ_b.shape != mean.shape and activ_b.ndim == 1 and activ_b.size == 1:
            activ_b = activ_b.repeat(mean.size)
        return activ_b + offset_b
    else:
        return offset_b


def fuse_slices(seq: typing.Iterable[ig.Vertex]):
    cur_start = None
    cur_end = None
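
For reference, a minimal numpy sketch (not part of the commit) of the BN-into-Conv folding math used by fuse_rev_bn_weight and fuse_rev_bn_bias above: it checks that folding a preceding BatchNorm into a 1x1 convolution's weight and bias leaves the output unchanged. The shapes and data are made up for illustration, and the depthwise branch of fuse_rev_bn_bias is omitted.

# Standalone sanity check, illustrative only: made-up shapes/data, non-depthwise path.
import numpy as np

def fuse_rev_bn_weight(eps, scale, var, weight):
    # Scale each input channel (dim 1 of the OIHW weight) by gamma / sqrt(var + eps)
    shape = [1, -1] + [1] * (len(weight.shape) - 2)
    inv = 1 / np.sqrt(var + eps)
    return weight * (scale * inv).reshape(shape)

def fuse_rev_bn_bias(eps, scale, var, mean, bn_b, activ_b, weight):
    # Push the per-input-channel BN offset (beta - mean * gamma / sqrt(var + eps))
    # through the conv weight to get a per-output-channel bias correction
    reduced_dims = tuple(i for i in range(len(weight.shape)) if i > 1)
    inv = 1 / np.sqrt(var + eps)
    fused_b = bn_b - mean * inv * scale
    offset_b = np.matmul(weight.sum(reduced_dims), fused_b.reshape(-1, 1)).reshape(-1)
    return offset_b if activ_b is None else activ_b + offset_b

rng = np.random.default_rng(0)
cin, cout, eps = 4, 3, 1e-5
x = rng.standard_normal(cin)                  # one spatial position of an NCHW input
w = rng.standard_normal((cout, cin, 1, 1))    # 1x1 conv weight in OIHW layout
b = rng.standard_normal(cout)
scale, bn_b, mean = rng.standard_normal((3, cin))
var = rng.random(cin) + 0.1

# Reference: BatchNorm followed by the 1x1 convolution
y_bn = (x - mean) / np.sqrt(var + eps) * scale + bn_b
ref = w.reshape(cout, cin) @ y_bn + b

# Fused: the BN parameters are folded into the conv weight and bias
w_f = fuse_rev_bn_weight(eps, scale, var, w)
b_f = fuse_rev_bn_bias(eps, scale, var, mean, bn_b, b, w)
out = w_f.reshape(cout, cin) @ x + b_f

assert np.allclose(ref, out)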
