[resubmit] Gemlite fix
Summary:
Resubmitting pytorch#1432 since it has some rebase issues and we want
to merge the fix ASAP.

Test Plan:
see pytorch#1432

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Dec 18, 2024
1 parent ec64182 commit 967e35a
Showing 2 changed files with 23 additions and 1 deletion.
11 changes: 10 additions & 1 deletion test/integration/test_integration.py
@@ -958,11 +958,20 @@ def test_gemlite_layout(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
    api,
    device,
    15,
    15,
    test_shape=test_shape,
    test_dtype=dtype,
)

# test that shapes not divisible by 128 aren't causing errors
self._test_lin_weight_subclass_api_impl(
    lambda mod: quantize_(mod, gemlite_uintx_weight_only(None, 4, 32)),
    device,
    15,
    test_shape=[1, 1025, 513],
    test_dtype=dtype,
)


@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
13 changes: 13 additions & 0 deletions torchao/dtypes/uintx/gemlite_layout.py
@@ -15,6 +15,7 @@
from torchao.dtypes.utils import Layout, is_device
from torchao.quantization.quant_primitives import quantize_affine
from torchao.utils import fill_defaults
import warnings

aten = torch.ops.aten

@@ -76,6 +77,14 @@ def apply_gemlite_quant(
out_features, in_features = weight.shape
group_size = in_features if group_size is None else group_size

if in_features % 128 != 0 and out_features % 128 != 0:
    warnings.simplefilter("once", UserWarning)
    warnings.warn(
        "Gemlite only works for layers with in_features or out_features divisible by 128, "
        + "some layers have been skipped", UserWarning
    )
    return weight

quant_kwargs = get_gemlite_quant_kwargs(bit_width, group_size)

layout = GemlitePackedLayout(
@@ -173,6 +182,10 @@ def from_plain(
    exhaustive=False,
    use_cuda_graph=False,
)
if _layout.group_size == None and _layout.bit_width == 4:
    from gemlite.core import GEMLITE_ACC_DTYPE
    from gemlite.dtypes import DType
    GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP32

out_features, in_features = int_data.shape
input_dtype, output_dtype = DType.FP16, DType.FP16
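
For context, a minimal sketch (not part of this commit) of how the new skip path behaves: it applies the same gemlite_uintx_weight_only(None, 4, 32) config used in the new test case to a linear layer whose dimensions are not divisible by 128, and expects the layer to be left unquantized with a warning rather than an error. The import paths, and the availability of the gemlite package plus a CUDA device, are assumptions.

import warnings

import torch
# Import paths are an assumption; gemlite_uintx_weight_only may need to be
# imported from torchao.quantization.quant_api depending on the torchao version.
from torchao.quantization import gemlite_uintx_weight_only, quantize_

# in_features=513, out_features=1025: neither is divisible by 128,
# matching the shapes exercised by the new test case above.
lin = torch.nn.Linear(513, 1025, bias=False, dtype=torch.float16, device="cuda")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # With this fix, apply_gemlite_quant warns and returns the weight
    # unchanged for this layer instead of raising an error.
    quantize_(lin, gemlite_uintx_weight_only(None, 4, 32))

# The layer should be skipped: the weight stays a plain parameter and the
# recorded warnings should include the "divisible by 128" message.
print(type(lin.weight))
print([str(w.message) for w in caught])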
