Skip to content

Commit

Permalink
[tracer] apply move_idx=True for graph.insert_between (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterjc123 authored Oct 10, 2024
1 parent a2acdf4 commit 75fc200
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tinynn/graph/quantization/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1788,7 +1788,7 @@ def _is_sub_node(node: TraceNode, custom_data):

with override_current_trace_graph(graph):
trace_func = TraceFunction(new_fullname, True, prefix='fuse_').parse_args(input_tensor, -1)
graph.insert_between(input_node, node, trace_func, [output_tensor])
graph.insert_between(input_node, node, trace_func, [output_tensor], True)

node.module.func_type = '__add__'
node.module.kind = 'add'
Expand All @@ -1814,7 +1814,7 @@ def _is_sub_node(node: TraceNode, custom_data):

with override_current_trace_graph(graph):
trace_func = TraceFunction(new_fullname, True, prefix='fuse_').parse_args(input_tensor, -1)
graph.insert_between(input_node, node, trace_func, [output_tensor])
graph.insert_between(input_node, node, trace_func, [output_tensor], True)

node.module.func_type = '__radd__'
node.module.kind = 'add'
Expand Down Expand Up @@ -2333,7 +2333,7 @@ def _is_partially_quantizable(node, custom_data):
shared_tensors[0]
)
next_tensors = [x.contiguous() for x in shared_tensors]
graph.insert_between(n, node, trace_func, next_tensors)
graph.insert_between(n, node, trace_func, next_tensors, True)

# Remove non-leaf `.data` nodes
def _is_non_leaf_data_nodes(node, custom_data):
Expand Down Expand Up @@ -2474,7 +2474,7 @@ def _avgpool_kernel_size_and_stride(kernel_size, stride=None, *args, **kwargs):
node_bn1d.next_tensors[0].unsqueeze_(2).unsqueeze_(2)

prev_out = node_fc.prev_tensors[0][..., None, None]
graph.insert_between(node_fc.prev_nodes[0], node_fc, prev_func, [prev_out])
graph.insert_between(node_fc.prev_nodes[0], node_fc, prev_func, [prev_out], True)
next_out = torch.flatten(node_bn1d.next_tensors[0], 1)
graph.insert_after(node_bn1d, next_func, [next_out])

Expand Down Expand Up @@ -2524,7 +2524,7 @@ def _is_batch_norm_1d(node, custom_data):
node.next_tensors[0].unsqueeze_(2)

prev_out = torch.unsqueeze(node.prev_tensors[0], 2)
graph.insert_between(node.prev_nodes[0], node, prev_func, [prev_out])
graph.insert_between(node.prev_nodes[0], node, prev_func, [prev_out], True)
next_out = torch.squeeze(node.next_tensors[0], 2)
graph.insert_after(node, next_func, [next_out])

Expand Down Expand Up @@ -2936,7 +2936,7 @@ def _is_broadcastable_binary_quantized_op_node(node: TraceNode, custom_data) ->
node.prev_tensors[src_index], node.prev_tensors[ref_index]
)
next_tensors = [node.prev_tensors[src_index].expand_as(node.prev_tensors[ref_index])]
graph.insert_between(node.prev_nodes[src_index], node, trace_func, next_tensors)
graph.insert_between(node.prev_nodes[src_index], node, trace_func, next_tensors, True)

new_node = graph.nodes_map[trace_func.unique_name]
new_node.prev_nodes.append(node.prev_nodes[ref_index])
Expand Down

0 comments on commit 75fc200

Please sign in to comment.