From 73a0d4e1e78a1e49c5c538ec7c3301b07161584c Mon Sep 17 00:00:00 2001
From: gushiqiao
Date: Mon, 18 Sep 2023 19:25:23 +0800
Subject: [PATCH 1/3] Add QOP quant format

---
 dipoorlet/__main__.py |  7 ++++++-
 dipoorlet/utils.py    | 25 +++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/dipoorlet/__main__.py b/dipoorlet/__main__.py
index 3a11509..d4878c1 100644
--- a/dipoorlet/__main__.py
+++ b/dipoorlet/__main__.py
@@ -2,6 +2,7 @@
 import os
 import sys
 import time
+import copy
 
 import onnx
 import torch
@@ -16,7 +17,7 @@
 from .tensor_cali import tensor_calibration
 from .utils import (ONNXGraph, load_clip_val, logger, reduce_clip_val,
                     reduce_profiling_res, save_clip_val, save_profiling_res,
-                    setup_logger)
+                    setup_logger, deploy_QOperator)
 from .weight_transform import weight_calibration
 
 parser = argparse.ArgumentParser()
@@ -50,6 +51,7 @@
 parser.add_argument("--pattern", help="Sparse pattern", choices=["unstruction", "nv24"], default="unstruction")
 parser.add_argument("--optim_transformer", help="Transformer model optimization", default=False, action='store_true')
 parser.add_argument("--model_type", help="Transformer model type", choices=["unet"], default=None)
+parser.add_argument("--quant_format", default="QDQ", type=str, choices=["QOP", "QDQ"])
 args = parser.parse_args()
 
 if args.slurm:
@@ -114,6 +116,7 @@
 if dist.get_rank() == 0:
     logger.info("Do tensor calibration...")
 act_clip_val, weight_clip_val = tensor_calibration(onnx_graph, args)
+tensor_range = copy.deepcopy(act_clip_val)
 save_clip_val(act_clip_val, weight_clip_val, args,
               act_fname='act_clip_val.json.rank{}'.format(args.rank),
               weight_fname='weight_clip_val.json.rank{}'.format(args.rank))
@@ -151,5 +154,7 @@
 if dist.get_rank() == 0:
     logger.info("Deploy to " + args.deploy + '...')
     to_deploy(graph, act_clip_val, weight_clip_val, args)
+    if args.quant_format == 'QOP' and args.model_type is None:
+        deploy_QOperator(graph.model, tensor_range, args)
 end = time.time()
 logger.info("Total time cost: {} seconds.".format(int(end - start)))
\ No newline at end of file
diff --git a/dipoorlet/utils.py b/dipoorlet/utils.py
index cde4883..69bd8a4 100644
--- a/dipoorlet/utils.py
+++ b/dipoorlet/utils.py
@@ -10,6 +10,8 @@
 import torch.distributed as dist
 from onnx import TensorProto, numpy_helper
 from onnx.external_data_helper import convert_model_to_external_data
+from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
+from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType
 from termcolor import colored
 
 from .platform_settings import platform_setting_table
@@ -408,3 +410,26 @@
         model_cosine_dict[k][0] += v[0] / float(rank_size)
         model_cosine_dict[k][1] = min(model_cosine_dict[k][1], v[1])
     return layer_cosine_dict, model_cosine_dict
+
+
+def deploy_QOperator(model, tensor_range, args):
+    mode = QuantizationMode.QLinearOps
+    per_channel = platform_setting_table[args.deploy]['qw_params']['per_channel']
+    op_types_to_quantize = platform_setting_table[args.deploy]['quant_nodes']
+
+    if platform_setting_table[args.deploy]['qw_params']['symmetric']:
+        weight_type = QuantType.QInt8
+    else:
+        weight_type = QuantType.QUInt8
+
+    if platform_setting_table[args.deploy]['qi_params']['symmetric']:
+        activation_type = QuantType.QInt8
+    else:
+        activation_type = QuantType.QUInt8
+
+    quantizer = ONNXQuantizer(model, per_channel, False, mode, True,
+                              weight_type, activation_type, tensor_range,
+                              None, args.skip_layers, op_types_to_quantize)
+    quantizer.quantize_model()
+    model_output = os.path.join(args.output_dir, 'qop_model.onnx')
+    quantizer.model.save_model_to_file(model_output)
\ No newline at end of file
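
Note on the QOP path: deploy_QOperator hands dipoorlet's own calibration
result (tensor_range, the copy of act_clip_val taken before the weight
transforms) to onnxruntime's internal ONNXQuantizer, so no separate
onnxruntime calibration pass is needed. That constructor is not a public,
stable API and its positional arguments shift between onnxruntime releases,
so treat the call above as pinned to the installed version. A quick smoke
test of the emitted model, as a sketch (the output path follows the code
above; the random input assumes a single float32 network input and
substitutes 1 for any dynamic dimension):

    import numpy as np
    import onnxruntime as ort

    # qop_model.onnx is written to args.output_dir by deploy_QOperator
    sess = ort.InferenceSession("output_dir/qop_model.onnx",
                                providers=["CPUExecutionProvider"])
    inp = sess.get_inputs()[0]
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    x = np.random.rand(*shape).astype(np.float32)
    print(sess.run(None, {inp.name: x})[0].shape)
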
From 07d34803a06229f3aac4fbd3206a5abc003707de Mon Sep 17 00:00:00 2001
From: gushiqiao
Date: Mon, 18 Sep 2023 20:20:51 +0800
Subject: [PATCH 2/3] Fix some bugs

---
 dipoorlet/__main__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/dipoorlet/__main__.py b/dipoorlet/__main__.py
index d4878c1..b9d9115 100644
--- a/dipoorlet/__main__.py
+++ b/dipoorlet/__main__.py
@@ -96,7 +96,8 @@
         model = onnx.load(args.optimzed_model_dir)
 else:
     model = onnx.load(args.model)
-    model = onnx.version_converter.convert_version(model, 13)
+    if model.opset_import[0].version < 13:
+        model = onnx.version_converter.convert_version(model, 13)
     model, check = simplify(model)
     assert check, "Simplified ONNX model could not be validated"
 onnx_graph = ONNXGraph(model, args.output_dir, args.deploy, args.model_type)

From 871d5dbca4868d3d0e012e768e662d37dd2034b4 Mon Sep 17 00:00:00 2001
From: gushiqiao
Date: Thu, 28 Dec 2023 16:51:03 +0800
Subject: [PATCH 3/3] Add layerwise quant error profiling

---
 dipoorlet/__main__.py    | 37 +++++++++++++++-----
 dipoorlet/forward_net.py | 15 +++++++++
 dipoorlet/profiling.py   | 73 ++++++++++++++++++++++++++++++++++++++--
 dipoorlet/quantize.py    | 22 ++++++++++++
 dipoorlet/utils.py       | 67 ++++++++++++++++++++++++------------
 requirements.txt         |  2 +-
 6 files changed, 182 insertions(+), 34 deletions(-)

diff --git a/dipoorlet/__main__.py b/dipoorlet/__main__.py
index b9d9115..dc8ec3e 100644
--- a/dipoorlet/__main__.py
+++ b/dipoorlet/__main__.py
@@ -13,11 +13,12 @@
 from .deploy import to_deploy
 from .dist_helper import init_from_mpi, init_from_slurm
 from .profiling import (quantize_profiling_multipass, quantize_profiling_transformer,
-                        show_model_profiling_res, show_model_ranges, weight_need_perchannel)
+                        quantize_profiling_layerwise, show_model_profiling_res,
+                        show_model_ranges, weight_need_perchannel)
 from .tensor_cali import tensor_calibration
 from .utils import (ONNXGraph, load_clip_val, logger, reduce_clip_val,
                     reduce_profiling_res, save_clip_val, save_profiling_res,
-                    setup_logger, deploy_QOperator)
+                    setup_logger, restore_data)
 from .weight_transform import weight_calibration
 
 parser = argparse.ArgumentParser()
@@ -51,9 +52,19 @@
 parser.add_argument("--pattern", help="Sparse pattern", choices=["unstruction", "nv24"], default="unstruction")
 parser.add_argument("--optim_transformer", help="Transformer model optimization", default=False, action='store_true')
 parser.add_argument("--model_type", help="Transformer model type", choices=["unet"], default=None)
-parser.add_argument("--quant_format", default="QDQ", type=str, choices=["QOP", "QDQ"])
+parser.add_argument("--onnx_sim", help="Whether to use onnxsim to simplify the model", action='store_true')
+parser.add_argument("--qnode_version", help="The quant node opset version", type=int, choices=[13], default=13)
+parser.add_argument("--layerwise_error_prof", help='Profile per-layer quantization error', action="store_true")
+parser.add_argument("--prof_num", type=int, default=32)
+parser.add_argument("--batch_data_dir", type=str, default="./batch_data/")
+parser.add_argument("--criterion", help='The evaluation criterion for per-layer quantization error', type=str,
+                    choices=['cosine', 'max_abs_gap'], default="cosine")
+parser.add_argument("--sensitive_layer_num", type=int, default=10)
 args = parser.parse_args()
 
+if args.layerwise_error_prof:
+    assert args.prof_num <= args.data_num
+
 if args.slurm:
     init_from_slurm()
 elif args.mpirun:
@@ -96,10 +107,11 @@
         model = onnx.load(args.optimzed_model_dir)
 else:
     model = onnx.load(args.model)
-    if model.opset_import[0].version < 13:
-        model = onnx.version_converter.convert_version(model, 13)
-    model, check = simplify(model)
-    assert check, "Simplified ONNX model could not be validated"
+    if model.opset_import[0].version < args.qnode_version:
+        model = onnx.version_converter.convert_version(model, args.qnode_version)
+    if args.onnx_sim:
+        model, check = simplify(model, mutable_initializer=True)
+        assert check, "Simplified ONNX model could not be validated"
 onnx_graph = ONNXGraph(model, args.output_dir, args.deploy, args.model_type)
 
 if dist.get_rank() == 0 and not args.optim_transformer:
@@ -151,11 +163,18 @@
     show_model_ranges(graph, act_clip_val, weight_clip_val, args)
     weight_need_perchannel(graph, args)
 
+# Layerwise quantization error profiling.
+if args.layerwise_error_prof:
+    if dist.get_rank() == 0:
+        if not os.path.exists(args.batch_data_dir):
+            restore_data(args, graph.network_inputs, args.prof_num)
+        logger.info("Profiling layerwise quantization error...")
+        quantize_profiling_layerwise(graph, graph_ori, act_clip_val, weight_clip_val, args)
+dist.barrier()
+
 # Deploy
 if dist.get_rank() == 0:
     logger.info("Deploy to " + args.deploy + '...')
     to_deploy(graph, act_clip_val, weight_clip_val, args)
-    if args.quant_format == 'QOP' and args.model_type is None:
-        deploy_QOperator(graph.model, tensor_range, args)
 end = time.time()
 logger.info("Total time cost: {} seconds.".format(int(end - start)))
\ No newline at end of file
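
A typical invocation of the new profiling mode, as a sketch. The short
-M/-I/-N/-D flags (model path, input directory, calibration sample count,
deploy target) predate this patch and are quoted from memory of the
existing argparse setup, so verify them against --help; the long flags are
the ones added here:

    python -m dipoorlet -M model.onnx -I ./input_data -N 128 -D trt \
        --layerwise_error_prof --prof_num 32 --criterion cosine \
        --sensitive_layer_num 10

--prof_num must not exceed the calibration sample count; the new assert
enforces exactly that.
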
diff --git a/dipoorlet/forward_net.py b/dipoorlet/forward_net.py
index 513fb2e..0fac166 100644
--- a/dipoorlet/forward_net.py
+++ b/dipoorlet/forward_net.py
@@ -483,3 +483,18 @@
     ort_outputs = ort_session.run(outputs, ort_inputs)
     ort_outs = OrderedDict(zip(outputs, ort_outputs))
     return copy.deepcopy(ort_outs)
+
+
+def forward_get_output(graph, net, index, args):
+    rank = dist.get_rank()
+    device = rank % torch.cuda.device_count()
+    providers = [("CUDAExecutionProvider", {'device_id': device})]
+    ort_session = ort.InferenceSession(net.SerializeToString(), providers=providers)
+    ort_inputs = {}
+    for data in input_data_generator(args.batch_data_dir, graph.network_inputs, index, index + 1):
+        for name in graph.network_inputs:
+            ort_inputs[name] = data[name][:].reshape(graph.get_tensor_shape(name))
+    outputs = [output.name for output in ort_session.get_outputs()]
+    ort_outputs = ort_session.run(outputs, ort_inputs)
+    ort_outs = OrderedDict(zip(outputs, ort_outputs))
+    return copy.deepcopy(ort_outs)
\ No newline at end of file
diff --git a/dipoorlet/profiling.py b/dipoorlet/profiling.py
index b24b894..e845f65 100644
--- a/dipoorlet/profiling.py
+++ b/dipoorlet/profiling.py
@@ -7,10 +7,10 @@
 from onnx import numpy_helper
 from tqdm import tqdm
 
-from .forward_net import forward_get_tensor, ActivationCache
+from .forward_net import forward_get_tensor, forward_get_output, ActivationCache
 from .platform_settings import platform_setting_table
-from .quantize import DQTENSORSUFFIX, QUANT_NODE_NAME_LIST, quant_graph
-from .utils import cos_similarity, logger
+from .quantize import DQTENSORSUFFIX, QUANT_NODE_NAME_LIST, quant_graph, insert_fake_quant_node, delete_fake_quant_node
+from .utils import cos_similarity, max_abs_gap, logger
 
 
 def update_node_quant_profiling(graph_q, node, fp_cache, q_cache, layer_cosine_dict, args):
@@ -262,3 +262,70 @@ def show_model_profiling_res(graph_after_wt, layer_cosine_dict, model_cosine_dic
                 model_cosine_dict[name][1]))
         else:
             logger.info("{:40} tolcos : {:<.5f}".format(name, model_cosine_dict[name][0]))
+
+
+def quantize_profiling_layerwise(graph_after_wt, graph_ori, act_clip_val, weight_clip_val, args):
+    clip_val = act_clip_val.copy()
+    clip_val.update(weight_clip_val)
+    graph_q, quant_node_list = quant_graph(graph_after_wt, clip_val, args)
+
+    node_gen = tqdm(range(0, len(quant_node_list)))
+
+    graph_q.update_model_dim(args.prof_num)
+    graph_ori.update_model_dim(args.prof_num)
+
+    fp_net = graph_ori.model
+    q_net = graph_q.model
+
+    fp_tensors = forward_get_output(graph_ori, fp_net, 0, args)
+    q_tensors = forward_get_output(graph_q, q_net, 0, args)
+
+    base_error = 0.0
+    for tensor_name in graph_q.network_outputs:
+        q_tensor_name = tensor_name
+        if tensor_name + DQTENSORSUFFIX in graph_q.network_outputs:
+            q_tensor_name = tensor_name + DQTENSORSUFFIX
+
+        if args.criterion == "cosine":
+            base_error += 1 - cos_similarity(fp_tensors[tensor_name], q_tensors[q_tensor_name])
+        elif args.criterion == "max_abs_gap":
+            base_error += max_abs_gap(fp_tensors[tensor_name], q_tensors[q_tensor_name])
+
+    base_error /= len(graph_q.network_outputs)
+    res = {}
+    act_quant = []
+    for node_id in node_gen:
+        node = quant_node_list[node_id]
+        delete_fake_quant_node(graph_q, node)
+        graph_q.update_model()
+
+        q_net = graph_q.model
+        q_tensors = forward_get_output(graph_q, q_net, 0, args)
+
+        for tensor_name in graph_q.network_outputs:
+            q_tensor_name = tensor_name
+            if tensor_name + DQTENSORSUFFIX in graph_q.network_outputs:
+                q_tensor_name = tensor_name + DQTENSORSUFFIX
+
+            if args.criterion == "cosine":
+                error = 1 - cos_similarity(fp_tensors[tensor_name], q_tensors[q_tensor_name])
+            elif args.criterion == "max_abs_gap":
+                error = max_abs_gap(fp_tensors[tensor_name], q_tensors[q_tensor_name])
+
+            if node.name not in res:
+                res[node.name] = error
+            else:
+                res[node.name] += error
+
+        res[node.name] /= len(graph_q.network_outputs)
+        res[node.name] = base_error - res[node.name]
+        logger.info(res)
+
+        insert_fake_quant_node(graph_q, node, act_quant, clip_val, args)
+        graph_q.update_model()
+
+    sorted_res = dict(sorted(res.items(), key=lambda x: x[1], reverse=True))
+    logger.info(f"All layers sorted by quantization sensitivity: {sorted_res}")
+
+    sensitive_layer = list(sorted_res.keys())
+    logger.info(f"The top-{args.sensitive_layer_num} most sensitive layers are {sensitive_layer[:args.sensitive_layer_num]}")
\ No newline at end of file
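
The score above is a leave-one-out measurement: base_error is the
end-to-end output error with every node fake-quantized; each node's
Quant/DeQuant pair is then removed, the error is re-measured, and the
sensitivity is base_error minus the re-measured error, so larger positive
scores mark layers whose quantization hurts accuracy most. A toy
illustration of the two criteria and the score arithmetic (all numbers
invented for illustration):

    import numpy as np

    fp = np.array([1.0, 2.0, 3.0, 4.0])    # float32 network output
    qt = np.array([1.1, 1.9, 3.2, 3.8])    # fully quantized output

    cos_err = 1 - fp.dot(qt) / (np.linalg.norm(fp) * np.linalg.norm(qt))
    gap = np.max(np.abs(fp - qt))          # the max_abs_gap criterion

    base_error = cos_err                   # everything quantized
    error_wo_node = 0.2 * cos_err          # after de-quantizing one node, say
    print(base_error - error_wo_node)      # > 0: that node is sensitive
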
diff --git a/dipoorlet/quantize.py b/dipoorlet/quantize.py
index 91bb4e7..0a5d079 100644
--- a/dipoorlet/quantize.py
+++ b/dipoorlet/quantize.py
@@ -108,6 +108,28 @@ def insert_fake_quant_node_output(graph, clip_val, args):
     return
 
 
+def delete_fake_quant_node(graph, node):
+
+    def get_input_idx(node, input):
+        for i, inp in enumerate(node.input):
+            if inp == input:
+                return i
+
+    inputs = copy.deepcopy(node.input)
+    for input in inputs:
+        if not input.endswith(DQTENSORSUFFIX):
+            continue
+        dequant_node = graph.get_tensor_producer(input)
+        quant_node = graph.get_tensor_producer(dequant_node.input[0])
+        quant_consumers = graph.get_tensor_consumer(dequant_node.output[0])
+        for consumer in quant_consumers:
+            consumer.input.insert(get_input_idx(consumer, quant_node.input[0] + DQTENSORSUFFIX), quant_node.input[0])
+            if dequant_node.output[0] in consumer.input:
+                consumer.input.remove(dequant_node.output[0])
+        graph.graph.node.remove(dequant_node)
+        graph.graph.node.remove(quant_node)
+
+
 def get_qnode_by_param(param, in_tensor_name, tensor_shape, range, need_transpose=False):
     bit_width = param['bit_width']
     zero_point = [0]
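
delete_fake_quant_node bypasses one node's fake-quant inputs: every
consumer reading a DequantizeLinear output (the tensor named with
DQTENSORSUFFIX) is rewired to the original float tensor at the same input
slot, then the QuantizeLinear/DequantizeLinear pair is removed from the
graph. The list surgery in isolation, as a pure-Python sketch (the suffix
value here is illustrative; the real constant lives in quantize.py):

    DQTENSORSUFFIX = "_dq"   # illustrative value

    consumer_inputs = ["conv_w", "act" + DQTENSORSUFFIX]   # reads DQ output
    quant_in = "act"                                       # feeds QuantizeLinear

    i = consumer_inputs.index(quant_in + DQTENSORSUFFIX)
    consumer_inputs.insert(i, quant_in)                    # splice fp tensor in
    consumer_inputs.remove(quant_in + DQTENSORSUFFIX)      # drop DQ output
    print(consumer_inputs)                                 # ['conv_w', 'act']
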
diff --git a/dipoorlet/utils.py b/dipoorlet/utils.py
index 69bd8a4..fafebfb 100644
--- a/dipoorlet/utils.py
+++ b/dipoorlet/utils.py
@@ -10,8 +10,6 @@
 import torch.distributed as dist
 from onnx import TensorProto, numpy_helper
 from onnx.external_data_helper import convert_model_to_external_data
-from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
-from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType
 from termcolor import colored
 
 from .platform_settings import platform_setting_table
@@ -232,6 +230,22 @@ def update_model(self):
                                        opset_imports=self.model.opset_import)
         self.prepare_initializer()
 
+    def update_model_dim(self, dim=32):
+        del self.graph.value_info[:]
+        for input in self.graph.input:
+            if input.name in self.network_inputs:
+                dim1 = input.type.tensor_type.shape.dim[0]
+                dim1.dim_value = dim
+
+        for output in self.graph.output:
+            if output.name in self.network_outputs:
+                dim1 = output.type.tensor_type.shape.dim[0]
+                dim1.dim_value = dim
+
+        self.update_model()
+        self.model = onnx.shape_inference.infer_shapes(self.model)
+        self.get_shape_type()
+
     def copy_from(self, source_graph):
         self.model = copy.deepcopy(source_graph.model)
         self.graph = copy.deepcopy(source_graph.graph)
@@ -278,6 +292,11 @@
            / np.sqrt(np.square(tb).sum())
 
 
+def max_abs_gap(ta, tb):
+    assert ta.shape == tb.shape
+    return np.max(np.abs(ta - tb))
+
+
 def dispatch_functool(func):
     registry = {}
 
@@ -412,24 +431,30 @@
     return layer_cosine_dict, model_cosine_dict
 
 
-def deploy_QOperator(model, tensor_range, args):
-    mode = QuantizationMode.QLinearOps
-    per_channel = platform_setting_table[args.deploy]['qw_params']['per_channel']
-    op_types_to_quantize = platform_setting_table[args.deploy]['quant_nodes']
+def input_data_generator(input_dir, input_name_list, data_st_idx, data_ed_idx):
+    for idx in range(data_st_idx, data_ed_idx):
+        data = {}
+        for i in input_name_list:
+            data[i] = np.fromfile(f'{input_dir}/{i}/{idx}.bin', 'float32')
+        yield data
 
-    if platform_setting_table[args.deploy]['qw_params']['symmetric']:
-        weight_type = QuantType.QInt8
-    else:
-        weight_type = QuantType.QUInt8
 
-    if platform_setting_table[args.deploy]['qi_params']['symmetric']:
-        activation_type = QuantType.QInt8
-    else:
-        activation_type = QuantType.QUInt8
-
-    quantizer = ONNXQuantizer(model, per_channel, False, mode, True,
-                              weight_type, activation_type, tensor_range,
-                              None, args.skip_layers, op_types_to_quantize)
-    quantizer.quantize_model()
-    model_output = os.path.join(args.output_dir, 'qop_model.onnx')
-    quantizer.model.save_model_to_file(model_output)
\ No newline at end of file
+def restore_data(args, input_name_list, batch_size=32):
+
+    if os.path.exists(args.batch_data_dir):
+        logger.info(f"Removing existing batch data dir {args.batch_data_dir}")
+        os.system(f"rm -rf {args.batch_data_dir}")
+
+    os.mkdir(args.batch_data_dir)
+
+    for name in input_name_list:
+        os.mkdir(args.batch_data_dir + '/' + name)
+        batch_data = []
+        for idx in range(0, args.data_num):
+            data = np.fromfile(f'{args.input_dir}/{name}/{idx}.bin', 'float32').reshape(1, -1)
+            batch_data.append(data)
+            if (idx + 1) % batch_size == 0:
+                batch_id = int(idx / batch_size)
+                batch_data = np.vstack(batch_data)
+                batch_data.tofile(f'{args.batch_data_dir}/{name}/{batch_id}.bin')
+                batch_data = []
\ No newline at end of file
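
restore_data re-packs the first data_num single-sample files into float32
blobs of prof_num samples each under batch_data_dir (one subdirectory per
network input); trailing samples that do not fill a complete batch are
silently dropped. input_data_generator and forward_get_output then read
batches back by index. Reading one batch back, as a sketch (the input name
"data" and the batch size 32 are illustrative):

    import numpy as np

    # ./batch_data/data/0.bin was written by np.vstack(...).tofile(...)
    batch = np.fromfile("./batch_data/data/0.bin", dtype="float32")
    batch = batch.reshape(32, -1)   # prof_num rows, one flattened sample each
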
diff --git a/requirements.txt b/requirements.txt
index f05f51e..22a6775 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 torch
-onnx>=1.10.0
+onnx
 onnxsim
 onnxruntime-gpu
 numpy