Skip to content

Commit

Permalink
ggml-quants : use ceiling division when quantizing q1_3
Browse files Browse the repository at this point in the history
  • Loading branch information
compilade committed Jun 27, 2024
1 parent 9465ec6 commit 89dc3b2
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 7 deletions.
2 changes: 1 addition & 1 deletion convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def write_tensors(self):
shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape

# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"
shape_str = f"{{{', '.join(str(n) for n in reversed(shape)) or '1'}}}"

# n_dims is implicit in the shape
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -3389,8 +3389,8 @@ void quantize_row_q1_3_reference(const float * restrict x, block_q1_3 * restrict
int xi = nearest_int(x[j]);
uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
q[j] += xt * pow3[4];
q[j] = ((uint16_t)q[j] * 256) / pow3[5];
q[j] += (uint8_t)(q[j] != 0);
// ceiling division
q[j] = ((uint16_t)q[j] * 256 + (pow3[5] - 1)) / pow3[5];
y[i].q[j] = q[j];
}
x += sizeof(y->q);
Expand All @@ -3403,8 +3403,8 @@ void quantize_row_q1_3_reference(const float * restrict x, block_q1_3 * restrict
qb += xt * pow3[m];
}
x += 4;
qb = ((uint16_t)qb * 256) / pow3[5];
qb += (uint8_t)(qb != 0);
// ceiling division
qb = ((uint16_t)qb * 256 + (pow3[5] - 1)) / pow3[5];
y[i].qs[j] = qb;
}
}
Expand Down
3 changes: 1 addition & 2 deletions gguf-py/gguf/quants.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ def __quantize_q1_3_rows(n: np.ndarray) -> np.ndarray:
q4 = np.sum(q4 * pow3.reshape((1, 4)), axis=1, keepdims=True)
q48 = q48 + (q12 * 81)
q = np.concatenate([q48, q4], axis=1)
q = ((q.astype(np.uint16) * 256) // 243).astype(np.uint8)
q = np.where(q != 0, q + 1, 0)
q = (((q.astype(np.uint16) * 256) + (243 - 1)) // 243).astype(np.uint8)

return q.reshape(__quantize_q1_3_shape_change(shape))

Expand Down
6 changes: 6 additions & 0 deletions tests/test-quantize-fns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@

constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_BITNET = 0.015625f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
constexpr float MAX_DOT_PRODUCT_ERROR_BITNET = 0.5f;

static const char* RESULT_STR[] = {"ok", "FAILED"};

Expand Down Expand Up @@ -144,6 +146,8 @@ int main(int argc, char * argv[]) {
if (qfns.from_float && qfns.to_float) {
const float total_error = total_quantization_error(qfns, test_size, test_data.data());
const float max_quantization_error =
type == GGML_TYPE_Q1_3 ? MAX_QUANTIZATION_TOTAL_ERROR_BITNET :
type == GGML_TYPE_Q2_2 ? MAX_QUANTIZATION_TOTAL_ERROR_BITNET :
type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
type == GGML_TYPE_IQ2_S ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
Expand All @@ -166,6 +170,8 @@ int main(int argc, char * argv[]) {
const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
? MAX_DOT_PRODUCT_ERROR_LOWBIT
: type == GGML_TYPE_Q2_2 || type == GGML_TYPE_Q1_3
? MAX_DOT_PRODUCT_ERROR_BITNET
: MAX_DOT_PRODUCT_ERROR;
failed = !(vec_dot_error < max_allowed_error);
num_failed += failed;
Expand Down

0 comments on commit 89dc3b2

Please sign in to comment.