From bfb723aabba0db08890fa17f543ecc38bd6e365b Mon Sep 17 00:00:00 2001
From: marty1885
Date: Mon, 25 Mar 2024 14:54:54 +0000
Subject: [PATCH] optimized: reduce time complexity of node replacement

---
 python/tvm/relay/frontend/pytorch.py | 98 +++++++++++++++-------------
 1 file changed, 53 insertions(+), 45 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 5e60faf16..4499884a3 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -293,7 +293,7 @@ def make_elemwise(self, name):
         def elemwise(inputs, input_types):
             if name == "divide":
                 # https://pytorch.org/docs/stable/generated/torch.div.html#torch.div
-                # None - default behavior. Performs no rounding and, if both input and 
+                # None - default behavior. Performs no rounding and, if both input and
                 # other are integer types, promotes the inputs to the default scalar type.
                 if all(["int" in input_type for input_type in input_types[:2]]):
                     input_types[:2] = ["float32"] * 2
@@ -744,7 +744,7 @@ def tensordot(self, input, input_types):
         y = input[1]
         xshape = self.infer_shape(x)
         yshape = self.infer_shape(y)
-        
+
         # handle all types of inputs for `dims`
         if isinstance(dims, int):
             pairs = []
@@ -773,7 +773,7 @@ def tensordot(self, input, input_types):
             for j in range(len(dims[0])):
                 if dims[0][j] < 0:
                     dims[0][j] += len(xshape)
-            
+
             dims[1] = list(dims[1])
             for j in range(len(dims[1])):
                 if dims[1][j] < 0:
@@ -793,7 +793,7 @@ def tensordot(self, input, input_types):
         dim_to_char = OrderedDict()
         dim_to_char[0] = OrderedDict()
         dim_to_char[1] = OrderedDict()
-        
+
         x_str = ""

         for i, j in enumerate(xshape):
@@ -801,7 +801,7 @@ def tensordot(self, input, input_types):
                 dim_to_char[0][i] = alphabet[l]
                 l += 1
             x_str = x_str + dim_to_char[0][i]
-        
+
         y_str = ""

         for i, j in enumerate(yshape):
@@ -990,7 +990,7 @@ def fill(self, inputs, input_type):
     def full(self, inputs, input_types):
         data = inputs[0]
         fill_value = inputs[1]
-        
+
         # Convert to scaler if provided values is TVM call expression and is not dependent on any inputs (i.e. constant)
         fill_value = _infer_value(fill_value, {}).numpy().item() if type(fill_value) == _expr.Call and len(_analysis.free_vars(fill_value)) == 0 else fill_value
         if isinstance(fill_value, tvm.relay.expr.TupleGetItem):
@@ -1389,7 +1389,7 @@ def conv2d(self, inputs, input_types):
         # Add no output padding and move groups from inputs[6] to inputs[8]
         inputs.append([0, 0])
         inputs.append(inputs[6])
-        inputs[6] = 0 
+        inputs[6] = 0
         return self.convolution(inputs, input_types)

     def convolution(self, inputs, input_types):
@@ -1812,7 +1812,7 @@ def view(self, inputs, input_types):
         if isinstance(data, _expr.Constant):
             old_shape = data.data.shape
             num_new_dims = len(new_shape) - len(old_shape)
-            
+
             if num_new_dims > 1:
                 data = _op.transform.expand_dims(data, -1, num_new_dims)
         return _op.transform.reshape(data, new_shape)
@@ -2543,7 +2543,7 @@ def broadcast_tensors(self, inputs, input_types):

     def broadcast_to(self, inputs, input_types):
         tensor_list = inputs[1]
-        
+
         if type(tensor_list) is list:
             res_shape = tensor_list
         else:
@@ -2689,7 +2689,7 @@ def embedding_bag(self, inputs, input_types):

         take = []
         take.append(_op.embedding(weight, indices.astype("int32"), axis=0))
-        
+
         if mode == "sum":
             out = _op.sum(take[0], axis=0, keepdims=True)
         elif mode == "mean":
@@ -2721,7 +2721,7 @@ def index(self, inputs, input_types):
            if indices[0] == None:
                # Remove first None argument (represents ':')
                indices.pop(0)
-                
+
                assert len(_infer_shape(data)) == 2 and len(indices) == 1, "Currently supportes only 2D tensors with single mask"
                indices = indices[0]

@@ -2796,7 +2796,7 @@ def index(self, inputs, input_types):
                # Extract indices from boolean mask
                indices = _op.transform.argwhere(indices)
                # Doing this reshape to remove dynamic shapes caused by argwhere op (e.g. '?' shapes). This
-               # reshape will ensure that the output of argwhere (and following ops) is "predictable" in a 
+               # reshape will ensure that the output of argwhere (and following ops) is "predictable" in a
                # manner to suport further TVM compilation. However, this is only valid if this op is fallback
                # on CPU. Otherwise, this reshape will cause incorrect results.
                # indices = _op.reshape(indices, newshape=_infer_shape(data)[0])
@@ -3040,19 +3040,19 @@ def index_put(self, inputs, input_types):
            mode = "add"
        # Combine array of index tensors into one index tensor with shape (N,_)
        index_tensor = _op.stack(indices, axis=0)
-        
+
        # Narrow index tensor to match input tensor
        shape_diff = len(_infer_shape(index_tensor)) - len(_infer_shape(in_tensor))
        if shape_diff > 0:
            for shape in range(shape_diff):
                index_tensor = _op.squeeze(index_tensor, axis=[0])
-        
+
        # If indexes are in form of boolean mask instead of indices, use where op
        # instead of scatter_nd
        if _infer_type(index_tensor).checked_type.dtype == "bool":
            if isinstance(values, float):
                values = _expr.const(values, dtype=_infer_type(in_tensor).checked_type.dtype)
-            
+
            # Make sure that dynamic output will be 1D vector
            # index_tensor = _op.reshape(index_tensor, newshape=(-1,))
            # Make sure that dynamic output will be 1D vector
@@ -3065,13 +3065,13 @@ def index_put(self, inputs, input_types):
            indices = _op.transform.argwhere(index_tensor)
            indices = _op.transpose(indices, (1, 0))
            indices = _op.squeeze(indices, _expr.const([0])) if len(_infer_shape(indices)) == 2 and _infer_shape(indices)[0] == 1 else indices
-            
+
            # Make sure that dynamic output will be 1D vector
            values = _op.reshape(values, newshape=(-1,))
-            
+
            # Reduce data to 1D vector if possible
            in_tensor = _op.reshape(in_tensor, newshape=(-1,))
-            
+
            res = _op.scatter_elements(in_tensor, indices, values, 0, "add")
            return res

@@ -3322,11 +3322,11 @@ def replace_inf(inp, replacement_val=1e4):
            value = _op.broadcast_to_like(value, mask)

        one_const = _expr.const(1, dtype="float32")
-        
+
        # Original implementation
        # return _op.where(mask, value, inputs[0])
-        
-       # Implementaiton without using where operator in order to avoide numerical instability 
+
+       # Implementaiton without using where operator in order to avoide numerical instability
        # for certain models caused by the future matmul (once where is decomposed)
        return _op.add(_op.multiply(inputs[0], _op.subtract(one_const, mask)), _op.multiply(value, mask))

@@ -4054,7 +4054,7 @@ def all_any_common(self, op, inputs, input_types):
            dim = inputs[1]
        else:
            dim = 0
-        
+
        if len(inputs) > 2:
            keepdim = inputs[2]
        else:
@@ -4273,7 +4273,7 @@ def tril(self, inputs, input_types):
        y = np.tril(np.ones(x_shape)).astype(_convert_tvm_to_np_dtype(input_types[0]))
        y = tvm.nd.array(y)
        y = tvm.relay.Constant(y)
-        
+
        return _op.multiply(x, y)


@@ -4288,7 +4288,7 @@ def triu(self, inputs, input_types):
        zeros = np.zeros(x_shape).astype(_convert_tvm_to_np_dtype(input_types[0]))
        zeros = tvm.nd.array(zeros)
        zeros = tvm.relay.Constant(zeros)
-        
+
        return _op.where(mask, x, zeros)


@@ -4330,7 +4330,7 @@ def as_strided(self, inputs, input_types):

                rc_begin += stride_col
            rc_end = rc_begin + (n_out_col * stride_col)
-            
+
            rc_rows = _op.concatenate(rc_rows, axis=0)
            rc_rows = _op.expand_dims(rc_rows, axis=0)
            time_rows = np.append(time_rows, rc_rows)
@@ -4338,7 +4338,7 @@ def as_strided(self, inputs, input_types):
        time_rows = _op.concatenate(time_rows, axis=0)
        time_rows = _op.expand_dims(time_rows, axis=0)
        batch_rows = np.append(batch_rows, time_rows)
-        
+
        return _op.concatenate(batch_rows, axis=0)


@@ -4388,7 +4388,7 @@ def alias(self, inputs, inputs_types):

        # Get constant dtype
        dtype = _convert_data_type(shape.data.dtype, default_dtype="float32")
-        
+
        # Convert to numpy array
        shape = shape.data.numpy()
        if len(shape.shape) == 0:
@@ -4464,7 +4464,7 @@ def scaled_dot_product_attention(self, inputs, input_types):
            scale_factor = _expr.const(1 / math.sqrt(query_shape[-1]),
                                       dtype=dtype)
            scale_factor = _op.broadcast_to(scale_factor, shape=tuple(1 for _ in range(len(query_shape))))
-            
+
            # Early out if not decomposing
            return _op.nn.scaled_dot_product_attention(
                query,
@@ -4505,7 +4505,7 @@ def scaled_dot_product_attention(self, inputs, input_types):
            batch_size = key_shape[0]
        else:
            batch_size = query_shape[0]
-        
+
        if len(query_shape) == 4 and len(key_shape) == 4:
            query = _op.reshape(query, newshape=[-3, -2])
            key = _op.reshape(key, newshape=[-3, -2])
@@ -5650,11 +5650,11 @@ def get_relay_ty(ishape, itype, pt_type):

    input_vars = {}
-    
+
    def get_new_input_infos(input_infos):
        new_input_infos = []
        for num, inp in enumerate(input_infos):
-            
+
            if not isinstance(inp, tuple):
                msg = "Graph input {} is not a tuple".format(num)
                raise RuntimeError(msg)
@@ -5663,7 +5663,7 @@ def get_new_input_infos(input_infos):
                    "Graph input {} is not valid,"
                    " expected ('name', shape) or ('name', (shape, dtype))".format(inp)
                )
-                
+
                raise RuntimeError(msg)
            if isinstance(inp[1], (list, tuple)) and isinstance(inp[1][0], (list, tuple)) and isinstance(inp[1][0][0], str):
                new_input_infos.append((inp[0], get_new_input_infos(inp[1])))
@@ -5672,9 +5672,9 @@ def get_new_input_infos(input_infos):
            else:
                new_input_infos.append(inp)
        return new_input_infos
-    
+
    new_input_infos = get_new_input_infos(input_infos)
-    
+
    def get_input_types(input_infos, graph_input_types):
        input_types = []
        for (name, info), gi_type in zip(input_infos, graph_input_types):
@@ -5685,11 +5685,11 @@ def get_input_types(input_infos, graph_input_types):
                input_types.append((name, get_relay_ty(info[0], info[1], gi_type), info[1])) # info[1] is the framework datatype, which may differ after being converted to relay

        return input_types
-    
+
    graph_input_types = [gi.type() for gi in graph_inputs]
    input_types = get_input_types(new_input_infos, graph_input_types)

-    def get_input_vars(input_types, graph_input_names, use_tuple_type=False, tuple_name=""): 
+    def get_input_vars(input_types, graph_input_names, use_tuple_type=False, tuple_name=""):
        input_vars = {} if not use_tuple_type else []
        for gi_name, gi_type in zip(graph_input_names, input_types):
            name, itype = gi_type[0], gi_type[1]
@@ -5828,7 +5828,7 @@ def convert_params(graph, state_dict, source_map, use_parser_friendly_name=False
        elif full_attr in state_dict:
            if var_name in vars_by_name:
                var = vars_by_name[var_name]
-                # we need to remap inputs that pointed to the old 
+                # we need to remap inputs that pointed to the old
                input_remap[full_attr_node_name] = outputs_by_var_name[var_name]
            else:
                torch_tensor = state_dict[full_attr]
@@ -5871,7 +5871,7 @@ def export_c_graph(location, graph):
    fname = os.path.join(location, f"tvm_exported_c_graph_{time_stamp}.txt")
    with open(f"{fname}", "w") as f:
        f.write(str(graph))
-    

def outplace_inplace_ops(opnodes):
    replace_map = []
@@ -5880,18 +5880,26 @@ def outplace_inplace_ops(opnodes):
    for i, (node_name, op_node) in enumerate(opnodes):
        operator = op_node.kind()
        # Check if op is in-place (avoid '__not__', etc.)
-        if operator[-1] == '_' and operator[-2:] != "__": 
+        if operator[-1] == '_' and operator[-2:] != "__":
            input_node = op_node.inputsAt(0)
            replace_map.append((i, input_node, op_node.outputsAt(0)))

    # Replace future uses of node with an in-place op applied to it with the output of the op
-    for node_idx, orig_node, replacement_node in replace_map:
-        relevant_ops = opnodes[node_idx+1:]
+    node_inputs_map = {}
+    for idx, (node_name, op_node) in enumerate(opnodes):
+        for inp in op_node.inputs():
+            if inp not in node_inputs_map:
+                node_inputs_map[inp] = []
+            node_inputs_map[inp].append((idx, op_node))

-        for _, node in relevant_ops:
-            if orig_node in node.inputs():
-                node.replaceInputWith(orig_node, replacement_node)
+    for node_idx, orig_node, replacement_node in replace_map:
+        if orig_node not in node_inputs_map:
+            continue
+        relevant_ops = node_inputs_map[orig_node]
+        begin_idx = _binray_search(relevant_ops, lambda x: x[0] > node_idx)
+        for idx, node in relevant_ops[begin_idx:]:
+            node.replaceInputWith(orig_node, replacement_node)


def from_pytorch(
    script_module,
@@ -5901,7 +5909,7 @@
    use_parser_friendly_name=False,
    keep_quantized_weight=False,
    export_renamed_c_graph_path=None,
-    do_convert_params=True, 
+    do_convert_params=True,
):
    """Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
    The companion parameters will be handled automatically.
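
Note on the rewritten replacement loop: `outplace_inplace_ops` now builds `node_inputs_map` once (each graph value mapped to the list of ops that consume it, in graph order) and then, for each recorded in-place op, only rewrites the consumers that actually use its input, instead of rescanning every later node for every in-place op. The `_binray_search` helper it calls is not defined in the hunks shown here; judging by the name and the `lambda x: x[0] > node_idx` predicate, it is presumably a predicate-style bisection over the already index-sorted `(index, node)` pairs. A minimal sketch of such a helper, under that assumption, could be:

def _binray_search(items, pred):
    # Assumed helper, not part of the shown hunks: `items` is sorted so that
    # `pred` is False for some prefix and True afterwards; return the index of
    # the first element where `pred` holds, or len(items) if none does.
    lo, hi = 0, len(items)
    while lo < hi:
        mid = (lo + hi) // 2
        if pred(items[mid]):
            hi = mid
        else:
            lo = mid + 1
    return lo

Because `node_inputs_map[inp]` is appended to while walking `opnodes` in order, each per-value list is already sorted by node index, so the bisection locates the first consumer that appears after the in-place op, and the subsequent loop touches only genuine uses of that value rather than the whole tail of the graph.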