forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
/
subclass.py
599 lines (501 loc) · 19.8 KB
/
subclass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
MappingType,
)
from .utils import (
find_multiple,
dequantize_per_channel,
dynamically_quantize_per_channel,
groupwise_affine_quantize_tensor,
quant_int8_dynamic_per_token_linear,
unpack_tinygemm_scales_and_zeros,
groupwise_affine_quantize_tensor_from_qparams,
)
from torchao.utils import find_multiple
from typing import Tuple, Optional, Callable, Dict, Any
# Public API of this module: the three concrete quantized-weight subclasses.
__all__ = [
    "Int8DynamicallyQuantizedLinearWeight",
    "Int8WeightOnlyQuantizedLinearWeight",
    "Int4WeightOnlyQuantizedLinearWeight",
]

# Short alias for the aten operator namespace used by the dispatch logic.
aten = torch.ops.aten
class QuantizedLinearWeightBase(torch.Tensor):
"""
Base quantized tensor subclass for quantized linear weights. When the from_float method is used,
to create an instance of any QuantizedLinearWeightBase, we assume the input
weight is oriented the way it is in a normal linear op, i.e. out-channels x in-channels.
The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
regardless of the internal representation's type or orientation.
"""
@staticmethod
def __new__(cls, int_data, transposed, shape, *args, **kwargs):
kwargs["device"] = int_data.device
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
)
assert "dtype" in kwargs
assert not kwargs.get("requires_grad", False)
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
def __init__(self, int_data, transposed, *args, **kwargs):
self.int_data = int_data
self.transposed = transposed
@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
pass
def __repr__(self):
return (
f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
)
def dequantize(self):
pass
def int_repr(self):
pass
def q_params(self):
pass
def half(self):
return self.to(torch.float16)
def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
dtype = self.dtype if dtype is None else dtype
memory_format = (
memory_format if memory_format is not None else torch.preserve_format
)
kwargs = {
"device": device,
"dtype": dtype,
"memory_format": memory_format,
}
return kwargs
def _apply_fn_to_data(self, fn):
pass
def _change_shape(self):
pass
def __tensor_flatten__(self):
pass
@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
pass
@classmethod
def from_float(cls, input_float):
pass
# __torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if func is torch.nn.functional.linear:
mat1, w_qtensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
assert w_qtensor.transposed == False
return cls._quantized_op(mat1, w_qtensor, bias)
try:
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
except:
print(f"ERR: subclass doesn't implement {func}")
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# two scenarios where we currently fall back to vanilla mm:
# 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
# for consistency and to allow people to test
# 2 - we're given non-floats - quantizing long to int8 is crazy
if (
func in [aten.mm.default, aten.addmm.default]
and args[0].is_floating_point()
and args[0].is_cuda
):
if func == aten.addmm.default:
assert args[1].shape[-1] == args[2].shape[0], (
f"need mat1 shape: {args[1].shape} final"
f"dim to match mat2 shape: {args[2].shape} first dim "
)
mat1, w_qtensor, bias = (
args[1],
args[2],
args[0],
)
else:
assert args[0].shape[-1] == args[1].shape[0], (
f"need mat1 shape: {args[0].shape} final dim"
f"to match mat2 shape: {args[1].shape} first dim"
)
mat1, w_qtensor, bias = (
args[0],
args[1],
None if len(args) == 2 else args[2],
)
# call the quantized op for the specific type
# of quantized tensor subclass
return cls._quantized_op(mat1, w_qtensor, bias)
if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)
if func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)
if func is aten.t.default:
args[0].transposed = not args[0].transposed
new = args[0]._change_shape(args[0].shape[::-1])
return return_and_correct_aliasing(func, args, kwargs, new)
if func is aten._to_copy.default:
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)
class ConstructTensorSubclass(torch.nn.Module):
    """
    Module that rebuilds a quantized tensor subclass from its inner tensors.

    It defines ``right_inverse``, which suggests use as a
    ``torch.nn.utils.parametrize`` parametrization (TODO confirm at call
    sites). Positional/keyword arguments given at construction are stored
    so concrete subclasses can forward them in ``forward``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # extra constructor arguments, forwarded after the inner tensors
        self.args = args
        self.kwargs = kwargs

    def forward(self, x):
        # Overridden by concrete subclasses; the base version produces nothing.
        return None

    def right_inverse(self, tensor_subclass_instance):
        """Return the inner tensors named by ``__tensor_flatten__``, in order."""
        attr_names = tensor_subclass_instance.__tensor_flatten__()[0]
        return [getattr(tensor_subclass_instance, attr) for attr in attr_names]
@torch._dynamo.allow_in_graph
def from_qtensor_components_int8dyn(*components, **options):
    """
    Construct an Int8DynamicallyQuantizedLinearWeight from its components.

    Decorated with ``torch._dynamo.allow_in_graph`` so torch.compile puts
    this call into the graph as-is rather than tracing into it.
    """
    return Int8DynamicallyQuantizedLinearWeight(*components, **options)
class ConstructTensorSubclassInt8Dyn(ConstructTensorSubclass):
    """Rebuild an Int8DynamicallyQuantizedLinearWeight from ``int_data`` and ``q_scales``."""

    def forward(self, int_data, q_scales):
        # stored constructor arguments follow the two inner tensors
        extra_args, extra_kwargs = self.args, self.kwargs
        return from_qtensor_components_int8dyn(
            int_data, q_scales, *extra_args, **extra_kwargs
        )
class Int8DynamicallyQuantizedLinearWeight(QuantizedLinearWeightBase):
    """
    A Tensor subclass that when applied to a weight used in a linear op/module, changes the
    linear op to a dynamically quantized linear op with symmetric per-token and per-channel
    quantization on the activation and weight respectively.

    Inner tensors:
        int_data: int8 weight kept as [in_channels, out_channels] (see ``from_float``).
        q_scales: per-channel weight scales along the out-channel axis.
    """

    subclass_constructor = ConstructTensorSubclassInt8Dyn

    @staticmethod
    def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
        # the external dtype defaults to the dtype of the scales
        if dtype is None:
            dtype = q_scales.dtype
        kwargs["dtype"] = dtype
        return super().__new__(cls, int_data, transposed, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
        self.q_scales = q_scales
        super().__init__(int_data, transposed)

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        return quant_int8_dynamic_per_token_linear(
            act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype
        )

    def dequantize(self, dtype=None):
        """
        Obtain the dequantized version of the quantized tensor subclass.

        Args:
            dtype: output dtype; defaults to ``self.dtype``.
        """
        out_dtype = self.dtype if dtype is None else dtype
        # symmetric quantization: all zero points are 0
        zero_points = torch.zeros(
            self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype
        )
        # int_data is stored [in, out]; transpose to [out, in] so the quantized
        # (out-channel) axis comes first (presumably what dequantize_per_channel
        # expects — TODO confirm)
        dq_t = dequantize_per_channel(
            self.int_data.t(), self.q_scales, zero_points, out_dtype
        ).to(out_dtype)
        # data was transposed to dequantize so make sure shape is correct
        return dq_t.t() if self.transposed else dq_t

    def int_repr(self):
        """
        Get the internal integer representation of the quantized tensor
        """
        return self.int_data if self.transposed else self.int_data.t()

    def q_params(self):
        """
        Get the quantization scales for the quantized tensor
        """
        return {"q_scales": self.q_scales}

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
            self.int_data.to(kwargs["device"]),
            self.q_scales.to(kwargs["device"]),
            self.transposed,
            self.shape,
            **kwargs,
        )

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.int_data),
            fn(self.q_scales),
            self.transposed,
            self.shape,
            dtype=self.dtype,
        )

    def _change_shape(self, shape):
        # new wrapper over the same inner tensors with a different external shape
        return self.__class__(
            self.int_data, self.q_scales, self.transposed, shape, dtype=self.dtype
        )

    def __tensor_flatten__(self):
        # note: the order of args must match the order of args in __init__
        return ["int_data", "q_scales"], [self.transposed, self.shape, self.dtype]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
    ):
        int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"]
        transposed, shape, dtype = tensor_attributes
        return cls(
            int_data,
            q_scales,
            transposed,
            shape if outer_size is None else outer_size,
            dtype=dtype,
            strides=outer_stride,
        )

    @classmethod
    def from_float(cls, input_float, qmin=-128, qmax=127, dtype=None):
        """
        Method used to convert a linear weight tensor to an instance of the
        Int8DynamicallyQuantizedLinearWeight subclass.

        Example usage::

            model.lin_mod.weight = (
                Int8DynamicallyQuantizedLinearWeight.from_float(model.lin_mod.weight)
            )
        """
        if dtype is None:
            dtype = input_float.dtype

        # per-channel int8 quantization of the [out, in] weight; the returned
        # zero points are unused (symmetric quantization)
        w_int_repr, w_scales, _ = dynamically_quantize_per_channel(
            input_float, qmin, qmax, torch.int8
        )
        # the desired representation shape for fast quantized matmul is
        # transposed compared to how it's stored as a linear weight,
        # i.e. we want in_channels as dim=0 and out_channels (and quantized axis) as dim=1
        # however the external representation of our tensor will maintain the correct
        # shape attribute which needs to be tracked directly.
        int_data = w_int_repr.contiguous().t()
        # Only this exact class keeps the transposed (column-major) view,
        # presumably for its int8 matmul kernel (TODO confirm); subclasses such as
        # the weight-only variant get a contiguous copy. (An `issubclass` test
        # here would always be False and never take this branch.)
        if cls is not Int8DynamicallyQuantizedLinearWeight:
            int_data = int_data.contiguous()
        return cls(
            int_data, w_scales, False, input_float.shape, dtype=dtype,
        )
@torch._dynamo.allow_in_graph
def from_qtensor_components_int8wo(*components, **options):
    """
    Construct an Int8WeightOnlyQuantizedLinearWeight from its components.

    Decorated with ``torch._dynamo.allow_in_graph`` so torch.compile puts
    this call into the graph as-is rather than tracing into it.
    """
    return Int8WeightOnlyQuantizedLinearWeight(*components, **options)
class ConstructTensorSubclassInt8wo(ConstructTensorSubclass):
    """Rebuild an Int8WeightOnlyQuantizedLinearWeight from ``int_data`` and ``q_scales``."""

    def forward(self, int_data, q_scales):
        # stored constructor arguments follow the two inner tensors
        extra_args, extra_kwargs = self.args, self.kwargs
        return from_qtensor_components_int8wo(
            int_data, q_scales, *extra_args, **extra_kwargs
        )
class Int8WeightOnlyQuantizedLinearWeight(Int8DynamicallyQuantizedLinearWeight):
    """
    A Tensor subclass that when applied to a weight used in a linear op/module,
    changes the linear op to a weight-only quantized linear op with symmetric
    per-channel quantization on the weight.

    The activation stays in floating point; the int8 weight is cast to the
    activation dtype, multiplied, and the result is rescaled by ``q_scales``.
    """

    subclass_constructor = ConstructTensorSubclassInt8wo

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        out_dtype = act_mat.dtype
        lead_dims = act_mat.shape[:-1]
        # collapse all leading dims so a single 2-D matmul suffices
        flat_act = act_mat.reshape(-1, act_mat.shape[-1])
        weight = w_qtensor.int_data.to(out_dtype)
        y = torch.mm(flat_act, weight) * w_qtensor.q_scales
        # restore the leading dims of the activation
        y = y.reshape(*lead_dims, y.shape[-1])
        if bias is not None:
            y += bias
        return y.to(out_dtype)
@torch._dynamo.allow_in_graph
def from_qtensor_components_int4wo(*components, **options):
    """
    Construct an Int4WeightOnlyQuantizedLinearWeight from its components.

    Decorated with ``torch._dynamo.allow_in_graph`` so torch.compile puts
    this call into the graph as-is rather than tracing into it.
    """
    return Int4WeightOnlyQuantizedLinearWeight(*components, **options)
class ConstructTensorSubclassInt4wo(ConstructTensorSubclass):
    """Rebuild an Int4WeightOnlyQuantizedLinearWeight from ``int_data`` and ``scales_and_zeros``."""

    def forward(self, int_data, scales_and_zeros):
        # stored constructor arguments follow the two inner tensors
        extra_args, extra_kwargs = self.args, self.kwargs
        return from_qtensor_components_int4wo(
            int_data, scales_and_zeros, *extra_args, **extra_kwargs
        )
class Int4WeightOnlyQuantizedLinearWeight(QuantizedLinearWeightBase):
    """
    A Tensor subclass that when applied to a weight used in a linear op/module,
    changes that linear op to a weight-only int4 quantized linear op with groupwise
    affine quantization on the weight.

    Inner tensors:
        int_data: packed int4 weight produced by ``aten._convert_weight_to_int4pack``.
        scales_and_zeros: packed groupwise scales and zero points.
    """

    subclass_constructor = ConstructTensorSubclassInt4wo

    @staticmethod
    def __new__(
        cls,
        int_data,
        scales_and_zeros,
        transposed,
        shape,
        groupsize=128,
        inner_k_tiles=8,
        dtype=None,
        **kwargs,
    ):
        # the external dtype defaults to the dtype of the scales/zeros
        if dtype is None:
            dtype = scales_and_zeros.dtype
        kwargs["dtype"] = dtype
        return super().__new__(cls, int_data, transposed, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        int_data,
        scales_and_zeros,
        transposed,
        shape,
        groupsize=128,
        inner_k_tiles=8,
        dtype=None,
        **kwargs,
    ):
        # Defaults mirror __new__ so constructing without groupsize /
        # inner_k_tiles / dtype does not fail in __init__.
        # The transposed flag tracks whether the tensor subclass has been transposed relative
        # to how a weight is normally stored in a linear i.e. [out_features, in_features].
        # Tracking both transposed and shape is slightly redundant but corner cases like
        # square matrices can cause issues otherwise.
        self.scales_and_zeros = scales_and_zeros
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        super().__init__(int_data, transposed)

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        orig_act_size = act_mat.size()
        orig_dtype = act_mat.dtype

        # flatten leading dims and cast to bfloat16 for the int4 kernel
        act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16)
        # pad in_features to a multiple of 1024, matching the weight padding
        # applied in to_qtensor_components
        pad_size = find_multiple(act_mat.shape[-1], 1024)
        act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1]))

        # matmul
        y = aten._weight_int4pack_mm(
            act_mat.contiguous(),
            w_qtensor.int_data,
            w_qtensor.groupsize,
            w_qtensor.scales_and_zeros,
        )

        # remove out_feature padding (to_qtensor_components pads to a multiple of 8)
        orig_out_features = (
            w_qtensor.shape[-1] if w_qtensor.transposed else w_qtensor.shape[-2]
        )
        y = y[:, :orig_out_features]
        y = y.reshape(*orig_act_size[:-1], orig_out_features)

        if bias is not None:
            y += bias
        return y.to(orig_dtype)

    def dequantize(self, dtype=None):
        """
        Dequantize by running the quantized matmul against an identity matrix.

        Args:
            dtype: output dtype; defaults to ``self.dtype``.
        """
        eye_shape = self.shape[0] if self.transposed else self.shape[1]
        w_dq = self._quantized_op(
            torch.eye(eye_shape, device=self.device, dtype=self.dtype), self, None
        )
        # we dequantized using linear with the identity matrix, output has shape [in_channels, out_channels]
        # so we need to transpose back to get the original shape unless self.transposed is set.
        w_dq = w_dq if self.transposed else w_dq.t()
        return w_dq.to(self.dtype if dtype is None else dtype)

    def int_repr(self):
        return self.int_data

    def q_params(self):
        scales, zero_points = unpack_tinygemm_scales_and_zeros(
            self.scales_and_zeros,
        )
        return {"q_scales": scales, "q_zero_points": zero_points}

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
            self.int_data.to(kwargs["device"]),
            self.scales_and_zeros.to(kwargs["device"]),
            self.transposed,
            self.shape,
            self.groupsize,
            self.inner_k_tiles,
            **kwargs,
        )

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.int_data),
            fn(self.scales_and_zeros),
            self.transposed,
            self.shape,
            self.groupsize,
            self.inner_k_tiles,
            dtype=self.dtype,
        )

    def _change_shape(self, shape):
        # new wrapper over the same inner tensors with a different external shape
        return self.__class__(
            self.int_data,
            self.scales_and_zeros,
            self.transposed,
            shape,
            self.groupsize,
            self.inner_k_tiles,
            dtype=self.dtype,
        )

    def __tensor_flatten__(self):
        return ["int_data", "scales_and_zeros"], (
            self.transposed,
            self.shape,
            self.groupsize,
            self.inner_k_tiles,
            self.dtype,
        )

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None
    ):
        int_data, scales_and_zeros = (
            tensor_data_dict["int_data"],
            tensor_data_dict["scales_and_zeros"],
        )
        transposed, shape, groupsize, inner_k_tiles, dtype = attributes
        return cls(
            int_data,
            scales_and_zeros,
            transposed,
            shape if outer_size is None else outer_size,
            groupsize,
            inner_k_tiles,
            dtype=dtype,
            strides=outer_stride,
        )

    @classmethod
    def from_float(cls, input_float, groupsize=128, inner_k_tiles=8, dtype=None):
        """
        Method used to convert a linear weight tensor to an instance of the
        Int4WeightOnlyQuantizedLinearWeight subclass.

        Example usage::

            model.lin_mod.weight = (
                Int4WeightOnlyQuantizedLinearWeight.from_float(model.lin_mod.weight)
            )
        """
        if dtype is None:
            dtype = input_float.dtype
        (
            int_data,
            scales_and_zeros,
            transposed,
            groupsize,
            inner_k_tiles,
        ) = cls.to_qtensor_components(input_float, groupsize, inner_k_tiles)
        return cls(
            int_data,
            scales_and_zeros,
            transposed,
            input_float.shape,
            groupsize,
            inner_k_tiles,
            dtype=dtype,
        )

    @classmethod
    def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
        """
        Pad, quantize and pack a 2-D [out_features, in_features] weight.

        in_features is padded to a multiple of 1024 and out_features to a
        multiple of 8 before groupwise 4-bit quantization and int4 packing.

        Returns:
            (int_data, scales_and_zeros, transposed=False, groupsize, inner_k_tiles)
        """
        assert groupsize in [256, 128, 64, 32]
        assert inner_k_tiles in [8, 4, 2]
        orig_out_features, orig_in_features = input_float.shape

        # padding
        in_features = find_multiple(orig_in_features, 1024)
        out_features = find_multiple(orig_out_features, 8)
        input_float = torch.nn.functional.pad(
            input_float,
            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
        )

        # quantization and packing
        input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor(
            input_float, 4, groupsize, dtype=input_float.dtype
        )
        int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
        return int_data, scales_and_zeros, False, groupsize, inner_k_tiles