-
Notifications
You must be signed in to change notification settings - Fork 10
/
squant_function.py
93 lines (83 loc) · 2.26 KB
/
squant_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def SQuant_func(
flip_number,
flip_up_number,
flip_down_number,
rounding_error_sum,
rounding_number,
rounding_error,
up_number,
up_error,
up_priority,
up_order,
down_number,
down_error,
down_priority,
down_order,
):
conver_shape = rounding_number.shape
for oc in range(conver_shape[0]):
for ic in range(conver_shape[1]):
round_func(
flip_number,
flip_up_number,
flip_down_number,
rounding_error_sum[oc][ic],
rounding_number[oc][ic],
rounding_error[oc][ic],
up_number[oc][ic],
up_error[oc][ic],
up_priority[oc][ic],
up_order[oc][ic],
down_number[oc][ic],
down_error[oc][ic],
down_priority[oc][ic],
down_order[oc][ic],
)
def round_func(
flip_number,
flip_up_number,
flip_down_number,
rounding_error_sum,
rounding_number_,
rounding_error_,
up_number_,
up_error_,
up_priority_,
up_order_,
down_number_,
down_error_,
down_priority_,
down_order_,
):
if rounding_error_sum < 0:
# print("UP")
number_ = up_number_
error_ = up_error_
priority_ = up_priority_
order_ = up_order_
error_1 = down_error_
priority_1 = down_priority_
flip_number_ = flip_up_number
is_up = True
else:
# print("Down")
number_ = down_number_
error_ = down_error_
priority_ = down_priority_
order_ = down_order_
error_1 = up_error_
priority_1 = up_priority_
flip_number_ = flip_down_number
is_up = False
rounding_error_sum = rounding_error_sum.abs()
topk = int(rounding_error_sum.round().item())
over_squant = (topk >= rounding_error_sum)
idx_ = order_[0:topk]
rounding_error_[idx_] = error_[idx_]
rounding_number_[idx_] = number_[idx_]
if over_squant:
idx_c = order_[topk - 1]
priority_1[idx_c] = rounding_error_[idx_c].abs()
else:
idx_c = order_[topk]
priority_[idx_c] = rounding_error_[idx_c].abs()