Merge branch 'arui/quantized_2' into 'main'
Initial commit for VIT quantized

See merge request tenstorrent/tvm!48
arui-yyz committed Feb 7, 2024
2 parents fea9b37 + dc13b67 commit 2055098
Showing 3 changed files with 252 additions and 40 deletions.
154 changes: 115 additions & 39 deletions python/tvm/relay/frontend/onnx.py
@@ -5259,19 +5259,25 @@ def try_resolve_to_const(x, dtype_override=None):
)
use_bias = len(inputs) == 9

x_zp_bcast = _op.broadcast_to(x_zero_point, infer_shape(weight))
x_cast = _op.cast(x_zp_bcast, "float32")
w_cast = _op.cast(weight, "float32")
zp_requant = _op.sum(_op.multiply(x_cast, w_cast), axis=[1, 2, 3], keepdims=False)

zp_requant = try_resolve_to_const(zp_requant, "int32")
if x_zp_value != 0:
raise RuntimeError("Non-zero input zp is not supported yet")
# x_zp_bcast = _op.broadcast_to(x_zero_point, infer_shape(weight))
# x_cast = _op.cast(x_zp_bcast, "float32")
# w_cast = _op.cast(weight, "float32")
# zp_requant = _op.sum(_op.multiply(x_cast, w_cast), axis=[1, 2, 3], keepdims=False)

# zp_requant = try_resolve_to_const(zp_requant, "int32")

# if use_bias:
# bias = _op.subtract(inputs[8], zp_requant)
# out = _op.nn.bias_add(out, bias)
# else:
# bias = _op.broadcast_to(_op.negative(zp_requant), infer_shape(out)[1:2])
# out = _op.nn.bias_add(out, bias)

if use_bias:
bias = _op.subtract(inputs[8], zp_requant)
out = _op.nn.bias_add(out, bias)
else:
bias = _op.broadcast_to(_op.negative(zp_requant), infer_shape(out)[1:2])
out = _op.nn.bias_add(out, bias)
if use_bias:
out = _op.nn.bias_add(out, inputs[8])

infer_type(out)
out_dtype = infer_type(inputs[7]).checked_type.dtype
@@ -5452,10 +5458,14 @@ def try_resolve_to_const(x, dtype_override=None):
trans_a = attr.get("transA", False)
trans_b = attr.get("transB", False)

if trans_a:
a = _op.transpose(a, axes=[1,0])
if not trans_b:
b = _op.transpose(b, axes=[1,0])
# if trans_a:
# axes = np.arange(len(infer_shape(a)))
# axes[-2:] = [axes[-1],] + [axes[-2],]
# a = _op.transpose(a, axes=axes)
# if not trans_b:
# axes = np.arange(len(infer_shape(b)))
# axes[-2:] = [axes[-1],] + [axes[-2],]
# b = _op.transpose(b, axes=axes)

a_type = infer_type(a).checked_type # 'T1' in ONNX doc for this op
a_scale_type = infer_type(a_scale).checked_type
@@ -5519,35 +5529,86 @@ def try_resolve_to_const(x, dtype_override=None):
# if a_type.dtype == "uint8" and b_type.dtype == "uint8":
# matmul_result_dtype = "uint32"

matmul_result = _op.nn.dense(
a,
b,
out_dtype=matmul_result_dtype,
# a_zp_scalar,
# b_zp_scalar,
# a_scale_scalar,
# b_scale_scalar,
# num_hidden_units,
# matmul_result_dtype,
)

matmul_result_scale_scalar = fold_constant(_op.multiply(a_scale_scalar, b_scale_scalar))
if len(a_shape) == 2:
a_ = _op.reshape(a, [1,] + list(a_shape))
b_ = _op.reshape(b, [1,] + list(b_shape))
matmul_result = _op.nn.batch_matmul(
a_,
b_,
out_dtype=matmul_result_dtype,
transpose_a=trans_a,
transpose_b=trans_b,
)

a_bcast = _op.broadcast_to(a_zp, a_shape)
a_cast = _op.cast(a_bcast, "float32")
b_cast = _op.cast(b, "float32")
elif len(a_shape) == 3:
a_ = a
if len(b_shape) == 2:
b_ = _op.reshape(b, [1,] + list(b_shape))
b_ = _op.broadcast_to(b_, [a_shape[0],] + list(b_shape))
elif len(b_shape) == 4:
b_ = _op.reshape(b, [b_shape[0] * b_shape[1],] + list(b_shape[2:]))
else:
b_ = b

matmul_result = _op.nn.batch_matmul(
a_,
b_,
out_dtype=matmul_result_dtype,
transpose_a=trans_a,
transpose_b=trans_b,
)

matmul_result_zp_scalar = _op.nn.dense(a_cast, b_cast, out_dtype="float32")
else:
assert len(a_shape) == 4
a_ = _op.reshape(a, [a_shape[0] * a_shape[1],] + list(a_shape[2:]))
if len(b_shape) == 2:
b_ = _op.reshape(b, [1,] + list(b_shape))
b_ = _op.broadcast_to(b_, [a_shape[0] * a_shape[1],] + list(b_shape))
elif len(b_shape) == 4:
b_ = _op.reshape(b, [b_shape[0] * b_shape[1],] + list(b_shape[2:]))
else:
b_ = b

matmul_result = _op.nn.batch_matmul(
a_,
b_,
out_dtype=matmul_result_dtype,
transpose_a=trans_a,
transpose_b=trans_b,
)

matmul_result_zp_scalar = try_resolve_to_const(matmul_result_zp_scalar, "int32")
matmul_result_scale_scalar = fold_constant(_op.multiply(a_scale_scalar, b_scale_scalar))

if bias is not None:
matmul_result_zp_scalar = _op.reshape(matmul_result_zp_scalar, infer_shape(bias))
bias = _op.subtract(bias, matmul_result_zp_scalar)
matmul_result = _op.nn.bias_add(matmul_result, bias)
if a_zp.data.numpy().item() != 0:
raise RuntimeError("Do not support non-zero zero point yet")
# a_bcast = _op.broadcast_to(a_zp, infer_shape(a_))
# if len(infer_shape(a_bcast)) == 2:
# matmul_result_zp_scalar = _op.nn.dense(a_bcast, b_, out_dtype="int32")
# elif len(infer_shape(a_bcast)) == 3:
# assert len(infer_shape(b_)) == 3
# matmul_result_zp_scalar = _op.nn.batch_matmul(
# a_bcast,
# b_,
# out_dtype="int32",
# transpose_a=trans_a,
# transpose_b=trans_b,
# )
# else:
# raise RuntimeError("Do not support 4D matmul yet")


# matmul_result_zp_scalar = try_resolve_to_const(matmul_result_zp_scalar, "int32")

# if bias is not None:
# matmul_result_zp_scalar = _op.reshape(matmul_result_zp_scalar, infer_shape(bias))
# bias = _op.subtract(bias, matmul_result_zp_scalar)
# matmul_result = _op.add(matmul_result, bias)
# else:
# bias = _op.broadcast_to(_op.negative(matmul_result_zp_scalar), infer_shape(matmul_result))
# matmul_result = _op.add(matmul_result, bias)
else:
bias = _op.broadcast_to(_op.negative(matmul_result_zp_scalar), infer_shape(matmul_result)[1:])
matmul_result = _op.nn.bias_add(matmul_result, bias)
matmul_result_zp_scalar = _op.const(0, dtype="int8")


# This information might only be found in the C++ code-comments for the
# dense.matmul op, but the quantized tensor returned by _qnn.op.dense
@@ -5556,6 +5617,15 @@ def try_resolve_to_const(x, dtype_override=None):
# 'matmul_result_zp_scalar' has type 'int32' to satisfy input requirements
# of the [de/re]quantize ops below.

# Change back to 4D if needed
if len(a_shape) == 4:
result_shape = infer_shape(matmul_result)
matmul_result = _op.reshape(
matmul_result, [a_shape[0], a_shape[1], result_shape[-2], result_shape[-1]]
)
elif len(a_shape) == 2:
result_shape = infer_shape(matmul_result)
matmul_result = _op.reshape(matmul_result, result_shape[1:])

if "int32" in expected_out_dtypes:
# This is the adaptation of the QLinearMatMul converter for MatMulInteger,
Expand All @@ -5565,6 +5635,10 @@ def try_resolve_to_const(x, dtype_override=None):
# requantize requires y_scale to be constant,
# if y_scale is not constant, doing dequantize -> quantize
if isinstance(y_scale_scalar, _expr.Constant):
# y_scale_scalar = _op.broadcast_to(y_scale_scalar, infer_shape(matmul_result))
# matmul_result_scale_scalar = _op.broadcast_to(
# matmul_result_scale_scalar, infer_shape(matmul_result)
# )
y = _qnn.op.requantize(
matmul_result,
matmul_result_scale_scalar,
@@ -7366,6 +7440,8 @@ def _convert_operator(self, op_name, inputs, attrs, opset):
sym = get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs, self._params)
# print("Converted {} to ".format(op_name))
# print(infer_shape(sym))
else:
raise NotImplementedError(f"Operator {op_name} not implemented.")
return sym
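The core of the QLinearMatMul change above is swapping the old _op.nn.dense call for relay.nn.batch_matmul, which only accepts rank-3 operands: 2D inputs are lifted to [1, M, K], 4D inputs have their two leading batch dims folded into one, and the result is reshaped back at the end. Below is a minimal sketch of just that rank handling, assuming static shapes, equal-rank operands, and no transposes; the helper name lift_to_batch_matmul is illustrative and not part of the patch.

```python
# Minimal sketch (assumptions: static shapes, a and b share a rank, no transposes)
# of the rank handling used by the new QLinearMatMul lowering: lift both operands
# to rank 3, run batch_matmul, then restore the original rank of the result.
from tvm import relay


def lift_to_batch_matmul(a, b, a_shape, b_shape, out_dtype="int32"):
    if len(a_shape) == 2:
        a_ = relay.reshape(a, [1] + list(a_shape))
        b_ = relay.reshape(b, [1] + list(b_shape))
    elif len(a_shape) == 4:
        # Fold the two leading batch dims into a single batch dim.
        a_ = relay.reshape(a, [a_shape[0] * a_shape[1]] + list(a_shape[2:]))
        b_ = relay.reshape(b, [b_shape[0] * b_shape[1]] + list(b_shape[2:]))
    else:  # already rank 3
        a_, b_ = a, b

    out = relay.nn.batch_matmul(
        a_, b_, out_dtype=out_dtype, transpose_a=False, transpose_b=False
    )

    # Change back to the caller's rank, mirroring the reshape at the end of the
    # converter above.
    if len(a_shape) == 2:
        out = relay.reshape(out, [a_shape[0], b_shape[-1]])
    elif len(a_shape) == 4:
        out = relay.reshape(out, [a_shape[0], a_shape[1], a_shape[2], b_shape[-1]])
    return out
```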
134 changes: 133 additions & 1 deletion python/tvm/relay/op/contrib/buda/buda_passes.py
@@ -1708,7 +1708,7 @@ def callback(self, pre, post, node_map):
if transpose_b:
b = tvm.relay.transpose(b, axes=axes)

return tvm.relay.nn.batch_matmul(a, b, transpose_a=False, transpose_b=False)
return tvm.relay.nn.batch_matmul(a, b, transpose_a=False, transpose_b=False, out_dtype=post.checked_type.dtype)

class LowerAdaptiveAvgPool(DFPatternCallback):
def __init__(self):
@@ -2201,6 +2201,135 @@ def callback(self, pre, post, node_map):
return tvm.relay.gelu(node_map[self.act][0])


class RemoveQuantDequantSequence(DFPatternCallback):
def __init__(self):
super().__init__(rewrite_once=True, require_type=True)
self.act = wildcard()
self.quant = is_op("qnn.quantize")(self.act, wildcard(), wildcard(),)
self.pattern = is_op("qnn.dequantize")(self.quant, wildcard(), wildcard(),)

def callback(self, pre, post, node_map):
act = node_map[self.act][0]
quant = node_map[self.quant][0]
return node_map[self.act][0]


class ReconstructOnnxQuantizedGelu(DFPatternCallback):
def __init__(self):
super().__init__(rewrite_once=True, require_type=True)
self.act = wildcard()
self.one_over_root_two = wildcard()
self.one = wildcard()
self.half_added = wildcard()

self.act_0 = is_op("qnn.dequantize")(self.act, wildcard(), wildcard(),)
self.act_1 = is_op("qnn.dequantize")(self.act, wildcard(), wildcard(),)

times_root_two = is_op("multiply")(self.act_0, self.one_over_root_two)
erf = is_op("erf")(times_root_two)

# CHECK IF WE NEED IT
# quantize_erf = is_op("qnn.quantize")(erf, wildcard(), wildcard(),)
# dequantize_erf = is_op("qnn.dequantize")(quantize_erf, wildcard(), wildcard(),)


add = is_op("add")(erf, self.one)

# CHECK
# quantize_add = is_op("qnn.quantize")(add, wildcard(), wildcard(),)
# dequantize_add = is_op("qnn.dequantize")(quantize_add, wildcard(), wildcard(),)

mult2 = is_op("multiply")(self.act_1, add)

# Check
# quantize_mult2 = is_op("qnn.quantize")(mult2, wildcard(), wildcard(),)
# dequantize_mult2 = is_op("qnn.dequantize")(quantize_mult2, wildcard(), wildcard(),)

gelu = is_op("multiply")(mult2, self.half_added)

self.pattern = gelu

def callback(self, pre, post, node_map):
if isinstance(node_map[self.half_added][0], tvm.relay.expr.Constant):
half_added = math.isclose(node_map[self.half_added][0].data.numpy(), 0.5, rel_tol=1e-6, abs_tol=1e-6)
elif isinstance(node_map[self.half_added][0].args[0], tvm.relay.expr.Constant):
# Compute dequant
op = node_map[self.half_added][0]
assert (op.op.name == "qnn.dequantize")
input_int = op.args[0].data.numpy()
input_scale = op.args[1].data.numpy()
input_zp = op.args[2].data.numpy()
float_ = (float)(input_int - input_zp) * input_scale
half_added = math.isclose(float_, 0.5, rel_tol=1e-6, abs_tol=1e-6)
else:
return post

if isinstance(node_map[self.one][0], tvm.relay.expr.Constant):
one_added = math.isclose(node_map[self.one][0].data.numpy(), 1.0, rel_tol=1e-6, abs_tol=1e-6)
elif isinstance(node_map[self.one][0].args[0], tvm.relay.expr.Constant):
# Compute dequant
op = node_map[self.one][0]
assert (op.op.name == "qnn.dequantize")
input_int = op.args[0].data.numpy()
input_scale = op.args[1].data.numpy()
input_zp = op.args[2].data.numpy()
float_ = (float)(input_int - input_zp) * input_scale
one_added = math.isclose(float_, 1.0, rel_tol=1e-6, abs_tol=1e-6)
else:
return post

sqrt_half = node_map[self.one_over_root_two][0]
# Relay graph may use sqrt(1/2) outright, or take the reciprocal of sqrt(2)
if isinstance(sqrt_half, tvm.relay.expr.Constant):
sqrt_half = sqrt_half.data.numpy()
root_two_multiplied = math.isclose(sqrt_half, 0.70710677, rel_tol=1e-6, abs_tol=1e-6)
else:
sqrt_half = sqrt_half.args[0].data.numpy()
root_two_multiplied = math.isclose(sqrt_half, 1.4142135, rel_tol=1e-6, abs_tol=1e-6)

if not (half_added and one_added and root_two_multiplied):
return post

quantize_act = node_map[self.act][0]
original_dequant = node_map[self.act_0][0]
dequant_act = tvm.relay.qnn.op.dequantize(quantize_act, original_dequant.args[1], original_dequant.args[2])
return tvm.relay.gelu(dequant_act)
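The rewrite in ReconstructOnnxQuantizedGelu is value-preserving because the matched multiply/erf/add chain is exactly the erf form of GELU. A hedged numeric check of that identity (editorial sketch, not part of the commit):

```python
# Hedged check of the identity ReconstructOnnxQuantizedGelu relies on:
# x * (erf(x / sqrt(2)) + 1) * 0.5 == GELU(x), so the matched subgraph can be
# collapsed into a single gelu over the dequantized activation.
import math


def gelu_reference(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def matched_subgraph(x, one_over_root_two=0.70710677, one=1.0, half=0.5):
    # Mirrors the matched ops: multiply -> erf -> add -> multiply -> multiply.
    return x * (math.erf(x * one_over_root_two) + one) * half


for x in (-2.0, -0.5, 0.25, 3.0):
    assert math.isclose(gelu_reference(x), matched_subgraph(x), rel_tol=1e-6)
```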


class DecomposeQnnConcat(DFPatternCallback):
def __init__(self):
super().__init__(rewrite_once=True, require_type=True)
self.pattern = is_op("qnn.concatenate")(wildcard(), wildcard(), wildcard(),wildcard(),wildcard(),)


def callback(self, pre, post, node_map):
data = post.args[0]
input_scales = post.args[1]
input_zps = post.args[2]
output_scale = post.args[3]
output_zp = post.args[4]

assert len(input_scales) == len(input_zps) == len(data)
new_concat_inputs = []
for i in range(len(data)):
if input_scales[i].data.numpy() == output_scale.data.numpy() and input_zps[i].data.numpy() == output_zp.data.numpy():
new_concat_inputs.append(data[i])
else:
# Insert requant
inp = tvm.relay.qnn.op.requantize(
data[i],
input_scale=input_scales[i],
input_zero_point=input_zps[i],
output_scale=output_scale,
output_zero_point=output_zp,
out_dtype=post.checked_type.dtype,
)
new_concat_inputs.append(inp)

return tvm.relay.concatenate(new_concat_inputs, axis=post.attrs.axis)



class ReconstructPyTorchGeluNew(DFPatternCallback):
def __init__(self):
super().__init__(rewrite_once=True, require_type=True)
@@ -3636,6 +3765,9 @@ def run_buda_compile_passes(relay_module, params=None, inputs=None, target=None,
ConvertGlobalAvgPool2dtoAvgPool2d(),
ConvertUpsampleToResize2d(),
DecomposeMultiIndexAdvIndex(),
RemoveQuantDequantSequence(),
ReconstructOnnxQuantizedGelu(),
DecomposeQnnConcat(),
# DecomposeErf(),
ReconstructTFGelu(),
ReconstructOnnxGelu(),
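The new DecomposeQnnConcat pass turns qnn.concatenate into a plain concatenate by requantizing every input whose (scale, zero_point) differs from the output quantization and forwarding the matching ones unchanged. A hedged NumPy sketch of that requantization step (reference arithmetic only, not the relay op):

```python
# Hedged NumPy reference for the requantize that DecomposeQnnConcat inserts
# before concatenation when an input's quantization params differ from the
# output's; matching inputs are passed through as-is.
import numpy as np


def requantize_ref(q, in_scale, in_zp, out_scale, out_zp, dtype=np.int8):
    real = (q.astype(np.float32) - in_zp) * in_scale
    info = np.iinfo(dtype)
    return np.clip(np.round(real / out_scale) + out_zp, info.min, info.max).astype(dtype)


x0 = np.array([10, 20, 30], dtype=np.int8)  # scale 0.1, zp 0 (matches the output)
x1 = np.array([5, 10, 15], dtype=np.int8)   # scale 0.2, zp 0 (needs a requantize)
out_scale, out_zp = 0.1, 0

result = np.concatenate([x0, requantize_ref(x1, 0.2, 0, out_scale, out_zp)])
print(result)  # [10 20 30 10 20 30]
```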
4 changes: 4 additions & 0 deletions python/tvm/relay/op/contrib/buda/relay_passes.py
@@ -118,6 +118,10 @@ def run_relay_compile_passes(relay_module, print_all=False):
logger.trace("After CanonicalizeOps")
logger.trace(relay_module.functions)

# relay_module = tvm.transform.Sequential([relay.qnn.transform.CanonicalizeOps()])(relay_module)
# logger.trace("After QNN CanonicalizeOps")
# logger.trace(relay_module.functions)

relay_module = tvm.transform.Sequential([transform.InferType()])(relay_module)
logger.trace("After InferType")
logger.trace(relay_module.functions)
