From 89dc3b254cb85d879559007a4a5d1a5cce62822f Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 26 Jun 2024 15:31:48 -0400 Subject: [PATCH] ggml-quants : use ceiling division when quantizing q1_3 --- convert-hf-to-gguf.py | 2 +- ggml/src/ggml-quants.c | 8 ++++---- gguf-py/gguf/quants.py | 3 +-- tests/test-quantize-fns.cpp | 6 ++++++ 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a8aef09b93369..ec66316ee7804 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -336,7 +336,7 @@ def write_tensors(self): shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape # reverse shape to make it similar to the internal ggml dimension order - shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" + shape_str = f"{{{', '.join(str(n) for n in reversed(shape)) or '1'}}}" # n_dims is implicit in the shape logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 14a1ee4e97e8e..5dd682b602d56 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3389,8 +3389,8 @@ void quantize_row_q1_3_reference(const float * restrict x, block_q1_3 * restrict int xi = nearest_int(x[j]); uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2; q[j] += xt * pow3[4]; - q[j] = ((uint16_t)q[j] * 256) / pow3[5]; - q[j] += (uint8_t)(q[j] != 0); + // ceiling division + q[j] = ((uint16_t)q[j] * 256 + (pow3[5] - 1)) / pow3[5]; y[i].q[j] = q[j]; } x += sizeof(y->q); @@ -3403,8 +3403,8 @@ void quantize_row_q1_3_reference(const float * restrict x, block_q1_3 * restrict qb += xt * pow3[m]; } x += 4; - qb = ((uint16_t)qb * 256) / pow3[5]; - qb += (uint8_t)(qb != 0); + // ceiling division + qb = ((uint16_t)qb * 256 + (pow3[5] - 1)) / pow3[5]; y[i].qs[j] = qb; } } diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index 46820dce3b288..c66b83b3f8283 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -149,8 +149,7 @@ def __quantize_q1_3_rows(n: np.ndarray) -> np.ndarray: q4 = np.sum(q4 * pow3.reshape((1, 4)), axis=1, keepdims=True) q48 = q48 + (q12 * 81) q = np.concatenate([q48, q4], axis=1) - q = ((q.astype(np.uint16) * 256) // 243).astype(np.uint8) - q = np.where(q != 0, q + 1, 0) + q = (((q.astype(np.uint16) * 256) + (243 - 1)) // 243).astype(np.uint8) return q.reshape(__quantize_q1_3_shape_change(shape)) diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index e690ac6c85a71..d977aa26bc00e 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -15,11 +15,13 @@ constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f; +constexpr float MAX_QUANTIZATION_TOTAL_ERROR_BITNET = 0.015625f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f; constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f; constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f; +constexpr float MAX_DOT_PRODUCT_ERROR_BITNET = 0.5f; static const char* RESULT_STR[] = {"ok", "FAILED"}; @@ -144,6 +146,8 @@ int main(int argc, char * argv[]) { if (qfns.from_float && qfns.to_float) { const float total_error = total_quantization_error(qfns, test_size, test_data.data()); const float max_quantization_error = + type == GGML_TYPE_Q1_3 ? MAX_QUANTIZATION_TOTAL_ERROR_BITNET : + type == GGML_TYPE_Q2_2 ? MAX_QUANTIZATION_TOTAL_ERROR_BITNET : type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : type == GGML_TYPE_IQ2_S ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : @@ -166,6 +170,8 @@ int main(int argc, char * argv[]) { const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S ? MAX_DOT_PRODUCT_ERROR_LOWBIT + : type == GGML_TYPE_Q2_2 || type == GGML_TYPE_Q1_3 + ? MAX_DOT_PRODUCT_ERROR_BITNET : MAX_DOT_PRODUCT_ERROR; failed = !(vec_dot_error < max_allowed_error); num_failed += failed;