-
Notifications
You must be signed in to change notification settings - Fork 10
/
quant_utils.py
95 lines (84 loc) · 3.05 KB
/
quant_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# *
# @file Different utility functions
# Copyright (c) Cong Guo, Yuxian Qiu, Jingwen Leng, Xiaotian Gao,
# Chen Zhang, Yunxin Liu, Fan Yang, Yuhao Zhu, Minyi Guo
# All rights reserved.
# This file is part of SQuant repository.
#
# SQuant is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# SQuant is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SQuant repository. If not, see <http://www.gnu.org/licenses/>.
# *
import os
import torch
import logging
import uuid
from quant_modules import Quantizer
# Module-level quantization settings, filled in by ``set_quantizer``.
quant_args = {}


def set_quantizer(args):
    """Record the quantization configuration taken from *args*.

    Copies ``args.mode``, ``args.wbit`` and ``args.abit`` into the
    module-level ``quant_args`` dict, and stores the whole *args* object
    under the ``'args'`` key. The dict is updated in place (never rebound),
    so any code already holding a reference to ``quant_args`` sees the
    new values.

    Args:
        args: Object exposing ``mode``, ``wbit`` and ``abit`` attributes
            (presumably an ``argparse.Namespace`` -- confirm with caller).
    """
    # Read every attribute first so a missing one leaves quant_args untouched.
    settings = dict(mode=args.mode, wbit=args.wbit, abit=args.abit, args=args)
    quant_args.update(settings)
logger = logging.getLogger(__name__)


def set_util_logging(filename):
    """Configure root logging at INFO level to a file and a stream handler.

    Thin wrapper around ``logging.basicConfig`` that installs a
    ``FileHandler`` for *filename* plus a default ``StreamHandler``, using a
    timestamped ``time - level - logger name - message`` format. Per the
    ``logging`` docs, ``basicConfig`` does nothing if the root logger
    already has handlers.

    Args:
        filename: Path of the log file handed to ``logging.FileHandler``.
    """
    file_and_console = [logging.FileHandler(filename), logging.StreamHandler()]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        handlers=file_and_console,
    )
def tag_info(args):
    """Return the optional run-tag suffix used in file names.

    Args:
        args: Object with a string ``tag`` attribute.

    Returns:
        ``""`` when ``args.tag`` is the empty string, otherwise
        ``"_" + args.tag``.
    """
    tag = args.tag
    return "" if tag == "" else "_" + tag
def get_ckpt_path(args):
    """Create and return a new run directory under ``squant_log``.

    Layout::

        squant_log/<model>_<dataset>/<mode>_W<wbit>A<abit>_<num>

    ``<num>`` is the integer value of the first 4 hex digits of a random
    UUID4, i.e. in ``[0, 65535]``, so repeated runs with identical settings
    usually get distinct directories. All missing directories are created;
    if the random suffix collides with an existing directory, that
    directory is reused.

    Args:
        args: Object with ``model``, ``dataset``, ``mode``, ``wbit`` and
            ``abit`` attributes.

    Returns:
        Path of the (now existing) run directory, relative to the current
        working directory.
    """
    base = os.path.join('squant_log', args.model + "_" + args.dataset)
    # The 'W' field takes args.wbit and the 'A' field takes args.abit.
    pathname = args.mode + '_W' + str(args.wbit) + 'A' + str(args.abit)
    num = int(uuid.uuid4().hex[0:4], 16)  # 4 hex digits -> 0..0xFFFF
    pathname += '_' + str(num)
    path = os.path.join(base, pathname)
    # makedirs(exist_ok=True) creates every missing level at once and avoids
    # the race between a separate isdir() check and mkdir().
    os.makedirs(path, exist_ok=True)
    return path
def get_ckpt_filename(path, epoch):
    """Return the checkpoint file path ``<path>/ckpt_<epoch>.pth``.

    Args:
        path: Directory that holds the checkpoints.
        epoch: Epoch identifier; it is formatted with ``str()``.
    """
    basename = 'ckpt_%s.pth' % (epoch,)
    return os.path.join(path, basename)
def get_log_filename(args):
    """Return the log-file path for a run, creating its directory if needed.

    Ensures ``checkpoint/<dataset>/<model>`` exists (all levels are created
    as needed), then builds the file name
    ``ckpt_<mode>_<bit0>_<bit1>..._<tag suffix>.txt``. The tag suffix comes
    from ``tag_info(args)``, which is empty when ``args.tag`` is ``""``.

    Args:
        args: Object with ``dataset``, ``model``, ``mode``, ``bit`` (an
            iterable of values joined with ``_``) and ``tag`` attributes.
            NOTE(review): this reads ``args.bit`` while ``set_quantizer`` and
            ``get_ckpt_path`` read ``wbit``/``abit`` -- confirm ``bit`` is
            still populated by the caller.

    Returns:
        Path of the log file (the file itself is not created here).
    """
    path = os.path.join('checkpoint', args.dataset, args.model)
    # One makedirs call replaces the per-level isdir()/mkdir() loop and is
    # safe if another process creates the same directory concurrently.
    os.makedirs(path, exist_ok=True)
    bits = '_'.join(map(str, args.bit))
    return os.path.join(path, 'ckpt_' + args.mode + '_' + bits + tag_info(args) + '.txt')
def disable_input_quantization(model):
    """Call ``disable_input_quantization()`` on every ``Quantizer`` in *model*.

    Args:
        model: Object providing ``named_modules()`` (presumably a
            ``torch.nn.Module`` -- confirm with caller).
    """
    # Lazy generator: each quantizer is handled as named_modules() yields it.
    quantizers = (
        submodule
        for _, submodule in model.named_modules()
        if isinstance(submodule, Quantizer)
    )
    for quantizer in quantizers:
        quantizer.disable_input_quantization()
def enable_quantization(model):
    """Call ``enable_quantization(name)`` on every ``Quantizer`` in *model*.

    Each matching submodule receives its qualified name as reported by
    ``model.named_modules()``.

    Args:
        model: Object providing ``named_modules()`` (presumably a
            ``torch.nn.Module`` -- confirm with caller).
    """
    for qualified_name, submodule in model.named_modules():
        if not isinstance(submodule, Quantizer):
            continue
        submodule.enable_quantization(qualified_name)
def disable_quantization(model):
    """Call ``disable_quantization(name)`` on every ``Quantizer`` in *model*.

    Each matching submodule receives its qualified name as reported by
    ``model.named_modules()``.

    Args:
        model: Object providing ``named_modules()`` (presumably a
            ``torch.nn.Module`` -- confirm with caller).
    """
    for qualified_name, submodule in model.named_modules():
        if not isinstance(submodule, Quantizer):
            continue
        submodule.disable_quantization(qualified_name)