optimized: reduce time complexity of node replacement
marty1885 committed Mar 25, 2024
1 parent dfda382 commit bfb723a
Showing 1 changed file with 53 additions and 45 deletions.
98 changes: 53 additions & 45 deletions python/tvm/relay/frontend/pytorch.py
@@ -293,7 +293,7 @@ def make_elemwise(self, name):
def elemwise(inputs, input_types):
if name == "divide":
# https://pytorch.org/docs/stable/generated/torch.div.html#torch.div
# None - default behavior. Performs no rounding and, if both input and
# other are integer types, promotes the inputs to the default scalar type.
if all(["int" in input_type for input_type in input_types[:2]]):
input_types[:2] = ["float32"] * 2
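
The promotion rule described in the comment above can be sanity-checked directly in PyTorch; a minimal illustration, assuming a stock PyTorch install:

import torch

# rounding_mode=None (the default): no rounding; integer inputs are promoted
# to the default scalar type before true division
print(torch.div(torch.tensor(3), torch.tensor(2)))  # tensor(1.5000)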
@@ -744,7 +744,7 @@ def tensordot(self, input, input_types):
y = input[1]
xshape = self.infer_shape(x)
yshape = self.infer_shape(y)

# handle all types of inputs for `dims`
if isinstance(dims, int):
pairs = []
@@ -773,7 +773,7 @@ def tensordot(self, input, input_types):
for j in range(len(dims[0])):
if dims[0][j] < 0:
dims[0][j] += len(xshape)

dims[1] = list(dims[1])
for j in range(len(dims[1])):
if dims[1][j] < 0:
@@ -793,15 +793,15 @@ def tensordot(self, input, input_types):
dim_to_char = OrderedDict()
dim_to_char[0] = OrderedDict()
dim_to_char[1] = OrderedDict()

x_str = ""
for i, j in enumerate(xshape):
if i not in dim_to_char[0]:
dim_to_char[0][i] = alphabet[l]
l += 1
x_str = x_str + dim_to_char[0][i]

y_str = ""
for i, j in enumerate(yshape):
@@ -990,7 +990,7 @@ def fill(self, inputs, input_type):
def full(self, inputs, input_types):
data = inputs[0]
fill_value = inputs[1]

# Convert to scalar if the provided value is a TVM call expression and does not depend on any inputs (i.e. it is constant)
fill_value = _infer_value(fill_value, {}).numpy().item() if type(fill_value) == _expr.Call and len(_analysis.free_vars(fill_value)) == 0 else fill_value
if isinstance(fill_value, tvm.relay.expr.TupleGetItem):
@@ -1389,7 +1389,7 @@ def conv2d(self, inputs, input_types):
# Add no output padding and move groups from inputs[6] to inputs[8]
inputs.append([0, 0])
inputs.append(inputs[6])
inputs[6] = 0
return self.convolution(inputs, input_types)

def convolution(self, inputs, input_types):
@@ -1812,7 +1812,7 @@ def view(self, inputs, input_types):
if isinstance(data, _expr.Constant):
old_shape = data.data.shape
num_new_dims = len(new_shape) - len(old_shape)

if num_new_dims > 1:
data = _op.transform.expand_dims(data, -1, num_new_dims)
return _op.transform.reshape(data, new_shape)
@@ -2543,7 +2543,7 @@ def broadcast_tensors(self, inputs, input_types):

def broadcast_to(self, inputs, input_types):
tensor_list = inputs[1]

if type(tensor_list) is list:
res_shape = tensor_list
else:
@@ -2689,7 +2689,7 @@ def embedding_bag(self, inputs, input_types):

take = []
take.append(_op.embedding(weight, indices.astype("int32"), axis=0))

if mode == "sum":
out = _op.sum(take[0], axis=0, keepdims=True)
elif mode == "mean":
@@ -2721,7 +2721,7 @@ def index(self, inputs, input_types):
if indices[0] == None:
# Remove first None argument (represents ':')
indices.pop(0)

assert len(_infer_shape(data)) == 2 and len(indices) == 1, "Currently supports only 2D tensors with a single mask"

indices = indices[0]
@@ -2796,7 +2796,7 @@ def index(self, inputs, input_types):
# Extract indices from boolean mask
indices = _op.transform.argwhere(indices)
# Doing this reshape to remove dynamic shapes caused by argwhere op (e.g. '?' shapes). This
# reshape will ensure that the output of argwhere (and following ops) is "predictable" in a
# manner to support further TVM compilation. However, this is only valid if this op falls back
# on CPU. Otherwise, this reshape will cause incorrect results.
# indices = _op.reshape(indices, newshape=_infer_shape(data)[0])
@@ -3040,19 +3040,19 @@ def index_put(self, inputs, input_types):
mode = "add"
# Combine array of index tensors into one index tensor with shape (N,_)
index_tensor = _op.stack(indices, axis=0)

# Narrow index tensor to match input tensor
shape_diff = len(_infer_shape(index_tensor)) - len(_infer_shape(in_tensor))
if shape_diff > 0:
for shape in range(shape_diff):
index_tensor = _op.squeeze(index_tensor, axis=[0])

# If indices are in the form of a boolean mask instead of index values, use where op
# instead of scatter_nd
if _infer_type(index_tensor).checked_type.dtype == "bool":
if isinstance(values, float):
values = _expr.const(values, dtype=_infer_type(in_tensor).checked_type.dtype)

# Make sure that dynamic output will be 1D vector
# index_tensor = _op.reshape(index_tensor, newshape=(-1,))
@@ -3065,13 +3065,13 @@ def index_put(self, inputs, input_types):
indices = _op.transform.argwhere(index_tensor)
indices = _op.transpose(indices, (1, 0))
indices = _op.squeeze(indices, _expr.const([0])) if len(_infer_shape(indices)) == 2 and _infer_shape(indices)[0] == 1 else indices

# Make sure that dynamic output will be 1D vector
values = _op.reshape(values, newshape=(-1,))

# Reduce data to 1D vector if possible
in_tensor = _op.reshape(in_tensor, newshape=(-1,))

res = _op.scatter_elements(in_tensor, indices, values, 0, "add")

return res
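
A numpy sketch of this boolean-mask path, flattening the data, scatter-adding at the mask's flat indices, and restoring the shape at the end; the names and the trailing reshape are illustrative, not part of the frontend:

import numpy as np

data = np.zeros((2, 3), dtype="float32")
mask = np.array([[True, False, False], [False, True, False]])
values = np.array([5.0, 7.0], dtype="float32")

flat_idx = np.flatnonzero(mask)   # argwhere + flatten: 1-D indices into data
out = data.reshape(-1).copy()     # reduce data to a 1-D vector
np.add.at(out, flat_idx, values)  # scatter_elements(..., 0, "add") analogue
print(out.reshape(data.shape))    # [[5. 0. 0.], [0. 7. 0.]]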
@@ -3322,11 +3322,11 @@ def replace_inf(inp, replacement_val=1e4):
value = _op.broadcast_to_like(value, mask)

one_const = _expr.const(1, dtype="float32")

# Original implementation
# return _op.where(mask, value, inputs[0])

# Implementation without using where operator in order to avoid numerical instability
# for certain models caused by the future matmul (once where is decomposed)
return _op.add(_op.multiply(inputs[0], _op.subtract(one_const, mask)), _op.multiply(value, mask))
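
The rewrite relies on an identity that holds for a {0,1}-valued float mask; a quick numpy check, illustrative only:

import numpy as np

# where(mask, value, x) == x * (1 - mask) + value * mask  when mask is 0/1-valued
x = np.array([1.0, 2.0, 3.0])
value = np.array([9.0, 9.0, 9.0])
mask = np.array([0.0, 1.0, 0.0])

assert np.allclose(np.where(mask.astype(bool), value, x),
                   x * (1 - mask) + value * mask)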

@@ -4054,7 +4054,7 @@ def all_any_common(self, op, inputs, input_types):
dim = inputs[1]
else:
dim = 0

if len(inputs) > 2:
keepdim = inputs[2]
else:
@@ -4273,7 +4273,7 @@ def tril(self, inputs, input_types):
y = np.tril(np.ones(x_shape)).astype(_convert_tvm_to_np_dtype(input_types[0]))
y = tvm.nd.array(y)
y = tvm.relay.Constant(y)

return _op.multiply(x, y)


@@ -4288,7 +4288,7 @@ def triu(self, inputs, input_types):
zeros = np.zeros(x_shape).astype(_convert_tvm_to_np_dtype(input_types[0]))
zeros = tvm.nd.array(zeros)
zeros = tvm.relay.Constant(zeros)

return _op.where(mask, x, zeros)
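
Both lowerings use the same masking idea: tril as an elementwise multiply with a lower-triangular ones mask, triu as a select against zeros. A numpy sketch, illustrative only:

import numpy as np

x = np.arange(9, dtype="float32").reshape(3, 3)

# tril: multiply by a lower-triangular ones mask
assert np.allclose(np.tril(x), x * np.tril(np.ones_like(x)))

# triu: where(mask, x, zeros) with an upper-triangular boolean mask
mask = np.triu(np.ones_like(x)).astype(bool)
assert np.allclose(np.triu(x), np.where(mask, x, np.zeros_like(x)))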


@@ -4330,15 +4330,15 @@ def as_strided(self, inputs, input_types):

rc_begin += stride_col
rc_end = rc_begin + (n_out_col * stride_col)

rc_rows = _op.concatenate(rc_rows, axis=0)
rc_rows = _op.expand_dims(rc_rows, axis=0)
time_rows = np.append(time_rows, rc_rows)

time_rows = _op.concatenate(time_rows, axis=0)
time_rows = _op.expand_dims(time_rows, axis=0)
batch_rows = np.append(batch_rows, time_rows)

return _op.concatenate(batch_rows, axis=0)


@@ -4388,7 +4388,7 @@ def alias(self, inputs, inputs_types):

# Get constant dtype
dtype = _convert_data_type(shape.data.dtype, default_dtype="float32")

# Convert to numpy array
shape = shape.data.numpy()
if len(shape.shape) == 0:
@@ -4464,7 +4464,7 @@ def scaled_dot_product_attention(self, inputs, input_types):

scale_factor = _expr.const(1 / math.sqrt(query_shape[-1]), dtype=dtype)
scale_factor = _op.broadcast_to(scale_factor, shape=tuple(1 for _ in range(len(query_shape))))

# Early out if not decomposing
return _op.nn.scaled_dot_product_attention(
query,
@@ -4505,7 +4505,7 @@ def scaled_dot_product_attention(self, inputs, input_types):
batch_size = key_shape[0]
else:
batch_size = query_shape[0]

if len(query_shape) == 4 and len(key_shape) == 4:
query = _op.reshape(query, newshape=[-3, -2])
key = _op.reshape(key, newshape=[-3, -2])
@@ -5650,11 +5650,11 @@ def get_relay_ty(ishape, itype, pt_type):

input_vars = {}


def get_new_input_infos(input_infos):
new_input_infos = []
for num, inp in enumerate(input_infos):

if not isinstance(inp, tuple):
msg = "Graph input {} is not a tuple".format(num)
raise RuntimeError(msg)
@@ -5663,7 +5663,7 @@ def get_new_input_infos(input_infos):
"Graph input {} is not valid,"
" expected ('name', shape) or ('name', (shape, dtype))".format(inp)
)

raise RuntimeError(msg)
if isinstance(inp[1], (list, tuple)) and isinstance(inp[1][0], (list, tuple)) and isinstance(inp[1][0][0], str):
new_input_infos.append((inp[0], get_new_input_infos(inp[1])))
@@ -5672,9 +5672,9 @@ def get_new_input_infos(input_infos):
else:
new_input_infos.append(inp)
return new_input_infos

new_input_infos = get_new_input_infos(input_infos)

def get_input_types(input_infos, graph_input_types):
input_types = []
for (name, info), gi_type in zip(input_infos, graph_input_types):
@@ -5685,11 +5685,11 @@ def get_input_types(input_infos, graph_input_types):
input_types.append((name, get_relay_ty(info[0], info[1], gi_type), info[1])) # info[1] is the framework datatype, which may differ after being converted to relay
return input_types


graph_input_types = [gi.type() for gi in graph_inputs]
input_types = get_input_types(new_input_infos, graph_input_types)

def get_input_vars(input_types, graph_input_names, use_tuple_type=False, tuple_name=""):
input_vars = {} if not use_tuple_type else []
for gi_name, gi_type in zip(graph_input_names, input_types):
name, itype = gi_type[0], gi_type[1]
@@ -5828,7 +5828,7 @@ def convert_params(graph, state_dict, source_map, use_parser_friendly_name=False
elif full_attr in state_dict:
if var_name in vars_by_name:
var = vars_by_name[var_name]
# we need to remap inputs that pointed to the old
input_remap[full_attr_node_name] = outputs_by_var_name[var_name]
else:
torch_tensor = state_dict[full_attr]
@@ -5871,7 +5871,7 @@ def export_c_graph(location, graph):
fname = os.path.join(location, f"tvm_exported_c_graph_{time_stamp}.txt")
with open(f"{fname}", "w") as f:
f.write(str(graph))

def outplace_inplace_ops(opnodes):

replace_map = []
@@ -5880,18 +5880,26 @@ def outplace_inplace_ops(opnodes):
     for i, (node_name, op_node) in enumerate(opnodes):
         operator = op_node.kind()
         # Check if op is in-place (avoid '__not__', etc.)
         if operator[-1] == '_' and operator[-2:] != "__":
             input_node = op_node.inputsAt(0)
             replace_map.append((i, input_node, op_node.outputsAt(0)))

     # Replace future uses of a node that an in-place op was applied to with the output of that op
-    for node_idx, orig_node, replacement_node in replace_map:
-        relevant_ops = opnodes[node_idx+1:]
+    node_inputs_map = {}
+    for idx, (node_name, op_node) in enumerate(opnodes):
+        for inp in op_node.inputs():
+            if inp not in node_inputs_map:
+                node_inputs_map[inp] = []
+            node_inputs_map[inp].append((idx, op_node))

-        for _, node in relevant_ops:
-            if orig_node in node.inputs():
-                node.replaceInputWith(orig_node, replacement_node)
+    for node_idx, orig_node, replacement_node in replace_map:
+        if orig_node not in node_inputs_map:
+            continue
+
+        relevant_ops = node_inputs_map[orig_node]
+        begin_idx = _binray_search(relevant_ops, lambda x: x[0] > node_idx)
+        for idx, node in relevant_ops[begin_idx:]:
+            node.replaceInputWith(orig_node, replacement_node)
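
The `_binray_search` helper (spelling as at the call site) is defined outside the hunks shown here. A minimal sketch consistent with that call, not the repository's actual definition: it assumes a predicate-based lower-bound search over a list partitioned by the predicate, which holds here because node_inputs_map entries are appended in ascending node order, so `x[0] > node_idx` flips from False to True exactly once.

def _binray_search(arr, pred):
    # First index i with pred(arr[i]) True, or len(arr) if pred never holds.
    # Requires arr to be partitioned: False...False True...True under pred.
    lo, hi = 0, len(arr)
    while lo < hi:
        mid = (lo + hi) // 2
        if pred(arr[mid]):
            hi = mid  # answer is at mid or to its left
        else:
            lo = mid + 1  # answer is strictly right of mid
    return lo

With this, the old pass that rescanned every later op for every replacement (O(R*N) membership checks) becomes one O(total-inputs) map build plus, per replacement, an O(log k) search over only the k recorded users of orig_node.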

def from_pytorch(
script_module,
@@ -5901,7 +5909,7 @@ def from_pytorch(
use_parser_friendly_name=False,
keep_quantized_weight=False,
export_renamed_c_graph_path=None,
do_convert_params=True,
):
"""Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
