diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index faabf48ab..273f60655 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -958,11 +958,20 @@ def test_gemlite_layout(self, device, dtype):
                 self._test_lin_weight_subclass_api_impl(
                     api,
                     device,
-                    15,
+                    15,
                     test_shape=test_shape,
                     test_dtype=dtype,
                 )
 
+        # test that shapes with non divisible by 128 shapes aren't causing errors
+        self._test_lin_weight_subclass_api_impl(
+            lambda mod: quantize_(mod, gemlite_uintx_weight_only(None, 4, 32)),
+            device,
+            15,
+            test_shape=[1, 1025, 513],
+            test_dtype=dtype,
+        )
+
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py
index 969816727..2df322b4f 100644
--- a/torchao/dtypes/uintx/gemlite_layout.py
+++ b/torchao/dtypes/uintx/gemlite_layout.py
@@ -15,6 +15,7 @@
 from torchao.dtypes.utils import Layout, is_device
 from torchao.quantization.quant_primitives import quantize_affine
 from torchao.utils import fill_defaults
+import warnings
 
 aten = torch.ops.aten
 
@@ -76,6 +77,14 @@ def apply_gemlite_quant(
     out_features, in_features = weight.shape
     group_size = in_features if group_size is None else group_size
 
+    if in_features % 128 != 0 and out_features % 128 != 0:
+        warnings.simplefilter("once", UserWarning)
+        warnings.warn(
+            "Gemlite only works for layers with in_features or out_features divisible by 128, "
+            + "some layers have been skipped", UserWarning
+        )
+        return weight
+
     quant_kwargs = get_gemlite_quant_kwargs(bit_width, group_size)
 
     layout = GemlitePackedLayout(
@@ -173,6 +182,10 @@ def from_plain(
             exhaustive=False,
             use_cuda_graph=False,
         )
+        if _layout.group_size == None and _layout.bit_width == 4:
+            from gemlite.core import GEMLITE_ACC_DTYPE
+            from gemlite.dtypes import DType
+            GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP32
 
         out_features, in_features = int_data.shape
         input_dtype, output_dtype = DType.FP16, DType.FP16
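
Note (not part of the patch): a minimal usage sketch of the behavior the diff introduces, assuming gemlite is installed, a CUDA device is available, and that quantize_ and gemlite_uintx_weight_only are importable from torchao.quantization (the import path and model setup here are assumptions for illustration). With the new guard, a linear layer whose in_features and out_features are both not divisible by 128 is left unquantized (apply_gemlite_quant returns the weight unchanged and emits a one-time warning), while compatible layers are converted as before.

    # Hypothetical sketch, not part of the diff; import paths and setup are assumptions.
    import torch.nn as nn
    from torchao.quantization import quantize_, gemlite_uintx_weight_only

    model = nn.Sequential(
        nn.Linear(1024, 512),   # both dims divisible by 128 -> quantized
        nn.Linear(1025, 513),   # neither dim divisible by 128 -> skipped with a warning
    ).half().cuda()

    # group_size=None, bit_width=4, packing_bitwidth=32, mirroring the new test case above
    quantize_(model, gemlite_uintx_weight_only(None, 4, 32))

    # The second layer's weight is returned unchanged by apply_gemlite_quant,
    # so only the first layer ends up with a gemlite-packed weight tensor.
    print(type(model[0].weight), type(model[1].weight))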