diff --git a/tinynn/graph/quantization/quantizer.py b/tinynn/graph/quantization/quantizer.py index 460f9b1a..03b12d32 100644 --- a/tinynn/graph/quantization/quantizer.py +++ b/tinynn/graph/quantization/quantizer.py @@ -1788,7 +1788,7 @@ def _is_sub_node(node: TraceNode, custom_data): with override_current_trace_graph(graph): trace_func = TraceFunction(new_fullname, True, prefix='fuse_').parse_args(input_tensor, -1) - graph.insert_between(input_node, node, trace_func, [output_tensor]) + graph.insert_between(input_node, node, trace_func, [output_tensor], True) node.module.func_type = '__add__' node.module.kind = 'add' @@ -1814,7 +1814,7 @@ def _is_sub_node(node: TraceNode, custom_data): with override_current_trace_graph(graph): trace_func = TraceFunction(new_fullname, True, prefix='fuse_').parse_args(input_tensor, -1) - graph.insert_between(input_node, node, trace_func, [output_tensor]) + graph.insert_between(input_node, node, trace_func, [output_tensor], True) node.module.func_type = '__radd__' node.module.kind = 'add' @@ -2333,7 +2333,7 @@ def _is_partially_quantizable(node, custom_data): shared_tensors[0] ) next_tensors = [x.contiguous() for x in shared_tensors] - graph.insert_between(n, node, trace_func, next_tensors) + graph.insert_between(n, node, trace_func, next_tensors, True) # Remove non-leaf `.data` nodes def _is_non_leaf_data_nodes(node, custom_data): @@ -2474,7 +2474,7 @@ def _avgpool_kernel_size_and_stride(kernel_size, stride=None, *args, **kwargs): node_bn1d.next_tensors[0].unsqueeze_(2).unsqueeze_(2) prev_out = node_fc.prev_tensors[0][..., None, None] - graph.insert_between(node_fc.prev_nodes[0], node_fc, prev_func, [prev_out]) + graph.insert_between(node_fc.prev_nodes[0], node_fc, prev_func, [prev_out], True) next_out = torch.flatten(node_bn1d.next_tensors[0], 1) graph.insert_after(node_bn1d, next_func, [next_out]) @@ -2524,7 +2524,7 @@ def _is_batch_norm_1d(node, custom_data): node.next_tensors[0].unsqueeze_(2) prev_out = torch.unsqueeze(node.prev_tensors[0], 2) - graph.insert_between(node.prev_nodes[0], node, prev_func, [prev_out]) + graph.insert_between(node.prev_nodes[0], node, prev_func, [prev_out], True) next_out = torch.squeeze(node.next_tensors[0], 2) graph.insert_after(node, next_func, [next_out]) @@ -2936,7 +2936,7 @@ def _is_broadcastable_binary_quantized_op_node(node: TraceNode, custom_data) -> node.prev_tensors[src_index], node.prev_tensors[ref_index] ) next_tensors = [node.prev_tensors[src_index].expand_as(node.prev_tensors[ref_index])] - graph.insert_between(node.prev_nodes[src_index], node, trace_func, next_tensors) + graph.insert_between(node.prev_nodes[src_index], node, trace_func, next_tensors, True) new_node = graph.nodes_map[trace_func.unique_name] new_node.prev_nodes.append(node.prev_nodes[ref_index])