[converter] fuse bn & conv #319

Merged 2 commits on May 29, 2024
104 changes: 104 additions & 0 deletions tinynn/converter/operators/optimize.py
@@ -167,6 +167,68 @@ def fuse_conv_fc_bn(self):
            assert vertex['node_type'] == ExtendedOperator.BATCH_NORM
        self.graph.graph.delete_vertices(remove_ids)

    @class_conditional(lambda self: self.level >= GraphOptimizer.FUSE_BN)
    def fuse_bn_conv(self):
        # Fuse a preceding batch-norm into the conv that consumes it (BN -> Conv),
        # the reverse direction of `fuse_conv_fc_bn` above
        edges = self.graph.graph.es.select(functools.partial(is_rev_bn_fusable_edge, graph_converter=self.graph.graph))
        filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]) for x in edges)

        def _remove_last_pred(seq):
            # Compute the fused conv parameters for a matched (bn, conv) pair
            bn = seq[0]
            conv = seq[1]

            # Collect the arguments of the conv and batch-norm nodes
            weight = conv['op'].inputs[1]
            bias = conv['op'].inputs[2] if len(conv['op'].inputs) > 2 else None
            bn_w, bn_b, bn_mean, bn_var = bn['op'].inputs[1:]
            bn_w, bn_b, bn_mean, bn_var = (
                bn_w.tensor.copy(),
                bn_b.tensor.copy(),
                bn_mean.tensor.copy(),
                bn_var.tensor.copy(),
            )
            activ_w = weight.tensor.copy()
            activ_b = bias.tensor.copy() if bias is not None else None
            eps = bn['op'].eps

            new_weight = fuse_rev_bn_weight(eps, bn_w, bn_var, activ_w)
            new_bias = fuse_rev_bn_bias(eps, bn_w, bn_var, bn_mean, bn_b, activ_b, activ_w)

            return False, (conv, bias, new_weight, new_bias)

        def _remove_last_action(first_node, last_node, custom_data):
            # Rewire the conv to use the fused weight and bias
            conv, bias, new_weight, new_bias = custom_data

            new_w = self.create_attr_tensor(new_weight)
            new_b = self.create_attr_tensor(new_bias)

            actions = []
            actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
            if bias is not None:
                actions.append((self.graph.replace_operator_input, (conv, 2, new_b)))
            else:
                # The conv had no bias input, so append the new one
                actions.append((self.graph.append_operator_input, (conv, new_b)))
            return actions

        def _skip_pred(seq):
            bn = seq[0]['op']
            conv = seq[1]['op']

            # Skip quantized graphs and depthwise convs
            skip = bn.inputs[0].quantization is not None or (
                conv.inputs[1].shape[1] == 1 and conv.inputs[1].shape[0] == conv.groups and conv.groups > 1
            )
            return skip

        # Remove the matched BN nodes, rewriting each conv with the fused parameters
        elinimate_sequences(
            self.graph,
            filtered_pairs,
            True,
            None,
            _remove_last_pred,
            _remove_last_action,
            _skip_pred,
            force_forward_input=True,
        )

    @class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
    def fuse_activation(self):
        # Find fusable ops
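Note: the pass above relies on a simple identity. If a batch-norm rescales input channel c by a_c = gamma_c / sqrt(var_c + eps) and adds s_c = beta_c - mean_c * a_c, then conv(W, BN(x)) equals conv(W', x) + b', where W' multiplies each input channel of W by a_c, and b' adds, on top of the original bias, each s_c weighted by W summed over the spatial dims. A minimal numpy sketch of that algebra, independent of the converter's types (the naive conv2d below is a reference implementation written for this note, stride 1 and no padding, with made-up shapes):

import numpy as np

def conv2d(x, w, b):
    # Naive reference conv; x: (in_ch, H, W), w: (out_ch, in_ch, kh, kw)
    out_ch, in_ch, kh, kw = w.shape
    H, W = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    y = np.zeros((out_ch, H, W))
    for o in range(out_ch):
        for i in range(H):
            for j in range(W):
                y[o, i, j] = np.sum(w[o] * x[:, i : i + kh, j : j + kw]) + b[o]
    return y

rng = np.random.default_rng(0)
C, O, eps = 3, 4, 1e-5
x = rng.normal(size=(C, 8, 8))
w = rng.normal(size=(O, C, 3, 3))
b = rng.normal(size=O)
gamma, beta, mean = (rng.normal(size=C) for _ in range(3))
var = rng.uniform(0.5, 2.0, size=C)

# Batch-norm applied ahead of the conv
inv = 1 / np.sqrt(var + eps)
bn_x = (x - mean[:, None, None]) * (gamma * inv)[:, None, None] + beta[:, None, None]

# Fused parameters, mirroring fuse_rev_bn_weight / fuse_rev_bn_bias below
fused_w = w * (gamma * inv).reshape(1, -1, 1, 1)
shift = beta - mean * inv * gamma
fused_b = b + np.matmul(w.sum((2, 3)), shift.reshape(-1, 1)).reshape(-1)

assert np.allclose(conv2d(bn_x, w, b), conv2d(x, fused_w, fused_b))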
@@ -3325,6 +3387,7 @@ def optimize(self):
        self.fuse_conv_fc_bn()
        self.fuse_activation()
        self.fuse_requantize()
        self.fuse_bn_conv()

        # Convert TinyNeuralNetwork ops to TFLite ops
        self.transform_graph()
@@ -3433,6 +3496,20 @@ def is_bn_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    )


def is_rev_bn_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    # A BN -> Conv edge is fusable only if the conv is the sole consumer of the BN
    # output, the BN parameters and the conv weight are constant buffers, and the
    # BN carries no fused activation
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        target_vertex['node_type'] == ExtendedOperator.GENERIC_CONV
        and source_vertex['node_type'] == ExtendedOperator.BATCH_NORM
        and source_vertex.outdegree() == 1
        and source_vertex['op'].inputs[1].buffer is not None
        and source_vertex['op'].inputs[2].buffer is not None
        and target_vertex['op'].inputs[1].buffer is not None
        and source_vertex['op'].fusedActivationFunction == ActivationFunctionType.NONE
    )


def is_padding_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
@@ -4123,6 +4200,33 @@ def fuse_bn_bias(eps, scale, var, mean, bn_b, activ_b):
    return (-mean) * inv * scale + bn_b


def fuse_rev_bn_weight(eps, scale, var, weight):
    # Rescale each input channel (axis 1) of the conv weight by gamma / sqrt(var + eps)
    shape = [1, -1] + [1] * (len(weight.shape) - 2)

    inv = 1 / np.sqrt(var + eps)

    return weight * (scale * inv).reshape(shape)


def fuse_rev_bn_bias(eps, scale, var, mean, bn_b, activ_b, weight):
    reduced_dims = tuple([i for i in range(len(weight.shape)) if i > 1])

    inv = 1 / np.sqrt(var + eps)
    # The per-input-channel constant the BN adds: beta - mean * gamma / sqrt(var + eps)
    fused_b = bn_b - mean * inv * scale

    # Push that constant through the conv by summing the weight over the spatial dims
    if weight.shape[1] == 1 and mean.shape[0] > 1:
        # One input channel per group: pair each filter with its own BN channel
        offset_b = (weight.sum(reduced_dims) * fused_b.reshape(-1, 1)).reshape(-1)
    else:
        offset_b = np.matmul(weight.sum(reduced_dims), fused_b.reshape(-1, 1)).reshape(-1)

    if activ_b is not None:
        # Broadcast a scalar conv bias across channels before adding the offset
        if activ_b.shape != mean.shape and activ_b.ndim == 1 and activ_b.size == 1:
            activ_b = activ_b.repeat(mean.size)
        return activ_b + offset_b
    else:
        return offset_b


def fuse_slices(seq: typing.Iterable[ig.Vertex]):
    cur_start = None
    cur_end = None
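A quick way to sanity-check fuse_rev_bn_weight and fuse_rev_bn_bias outside the converter is to compare batch-norm followed by conv against a single conv with the fused parameters. A sketch assuming PyTorch is available and the two helpers above are importable (the import path and shapes here are illustrative, and the conv is a plain non-grouped one):

import numpy as np
import torch
import torch.nn.functional as F

from tinynn.converter.operators.optimize import fuse_rev_bn_bias, fuse_rev_bn_weight

eps, C, O = 1e-5, 3, 4
rng = np.random.default_rng(1)
w = rng.normal(size=(O, C, 3, 3)).astype(np.float32)
b = rng.normal(size=O).astype(np.float32)
gamma, beta, mean = (rng.normal(size=C).astype(np.float32) for _ in range(3))
var = rng.uniform(0.5, 2.0, size=C).astype(np.float32)

# Reference: inference-mode batch-norm, then conv
x = torch.randn(1, C, 8, 8)
bn_out = F.batch_norm(
    x, torch.from_numpy(mean), torch.from_numpy(var),
    torch.from_numpy(gamma), torch.from_numpy(beta), eps=eps
)
ref = F.conv2d(bn_out, torch.from_numpy(w), torch.from_numpy(b))

# Fused: a single conv with the rewritten weight and bias
fused_w = fuse_rev_bn_weight(eps, gamma, var, w)
fused_b = fuse_rev_bn_bias(eps, gamma, var, mean, beta, b, w)
fused = F.conv2d(x, torch.from_numpy(fused_w), torch.from_numpy(fused_b))

assert torch.allclose(ref, fused, atol=1e-4)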