From 313d29f1f9164ff0b6dabea088bf3574f731bdb7 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 20 Jun 2024 19:28:29 +0000 Subject: [PATCH] fix: adjust so all tests pass --- server/tests/utils/test_weights.py | 118 ++++++++++++++--------------- 1 file changed, 56 insertions(+), 62 deletions(-) diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index 7d86f4f8475..cc92d199dd4 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -85,12 +85,12 @@ ), "weight.g_idx": torch.tensor([1.0], dtype=torch.int32), "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32), - "weight.scales": torch.tensor([8], dtype=torch.float16), + "weight.scales": torch.tensor([[8]], dtype=torch.float16), "gptq_bits": torch.tensor([8], dtype=torch.float32), "gptq_groupsize": torch.tensor([4], dtype=torch.float32), }, "test_get_multi_weights_col_packed_gptq": { - "col_packed.weight.qweight": torch.tensor( + "col_packed.qweight": torch.tensor( [ [1, 2], [3, 4], @@ -99,9 +99,9 @@ ], dtype=torch.int32, ), - "col_packed.weight.g_idx": torch.tensor([1.0], dtype=torch.int32), - "col_packed.weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32), - "col_packed.weight.scales": torch.tensor([8], dtype=torch.float16), + "col_packed.g_idx": torch.tensor([1.0], dtype=torch.int32), + "col_packed.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.int32), + "col_packed.scales": torch.tensor([[8]], dtype=torch.float16), "gptq_bits": torch.tensor([8], dtype=torch.float32), "gptq_groupsize": torch.tensor([4], dtype=torch.float32), }, @@ -117,7 +117,7 @@ ), "weight.q_scale": torch.tensor([8], dtype=torch.int32), "weight.q_invperm": torch.tensor([1.0], dtype=torch.int32), - "weight.q_scale_max": torch.tensor([8], dtype=torch.float16), + "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, "test_get_multi_weights_col_exl2": { @@ -143,11 +143,11 @@ }, "test_get_multi_weights_col_marlin": { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - "weight.s": torch.tensor([0.5], dtype=torch.float16), + "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), }, "test_get_multi_weights_col_packed_marlin": { - "col_packed.weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - "col_packed.weight.s": torch.tensor([0.5], dtype=torch.float16), + "col_packed.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), + "col_packed.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), }, } @@ -417,12 +417,12 @@ def test_get_multi_weights_row_gptq(): ) expected_weight = GPTQWeight( - qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), - qzeros=torch.tensor([[1.0], [2.0]], dtype=torch.float32), - scales=torch.tensor([8], dtype=torch.int32), - g_idx=torch.tensor([1.0], dtype=torch.float32), - bits=torch.tensor([8], dtype=torch.float32), - groupsize=torch.tensor([4], dtype=torch.float32), + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[1], [2]], dtype=torch.int32), + scales=torch.tensor([8.0], dtype=torch.float16), + g_idx=torch.tensor([1], dtype=torch.int32), + bits=8.0, + groupsize=4.0, use_exllama=False, ) @@ -456,12 +456,12 @@ def test_get_multi_weights_col_gptq(): ) expected_weight = GPTQWeight( - qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), - qzeros=torch.tensor([[1.0], [2.0]], dtype=torch.float32), - scales=torch.tensor([8], dtype=torch.int32), - g_idx=torch.tensor([1.0], dtype=torch.float32), - bits=torch.tensor([8], dtype=torch.float32), - groupsize=torch.tensor([4], dtype=torch.float32), + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[1], [2]], dtype=torch.int32), + scales=torch.tensor([[8.0]], dtype=torch.float16), + g_idx=torch.tensor([1], dtype=torch.int32), + bits=8.0, + groupsize=4.0, use_exllama=False, ) @@ -495,12 +495,12 @@ def test_get_multi_weights_col_packed_gptq(): ) expected_weight = GPTQWeight( - qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), - qzeros=torch.tensor([[1.0], [2.0]], dtype=torch.float32), - scales=torch.tensor([8], dtype=torch.int32), - g_idx=torch.tensor([1.0], dtype=torch.float32), - bits=torch.tensor([8], dtype=torch.float32), - groupsize=torch.tensor([4], dtype=torch.float32), + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[1], [2]], dtype=torch.int32), + scales=torch.tensor([[8.0]], dtype=torch.float16), + g_idx=torch.tensor([1], dtype=torch.int32), + bits=8.0, + groupsize=4.0, use_exllama=False, ) @@ -532,18 +532,21 @@ def test_get_multi_weights_row_exl2(): quantize=quantize, ) + scaled_scale_max = 0.3906 * 256 expected_weight = Exl2Weight( - q_weight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), + q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), q_scale=torch.tensor([8], dtype=torch.int32), - q_invperm=torch.tensor([1.0], dtype=torch.float32), - q_scale_max=8, - q_groups=torch.tensor([4], dtype=torch.int32), + q_invperm=torch.tensor([1], dtype=torch.int16), + q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16), + q_groups=torch.tensor([4], dtype=torch.int16), ) assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" - assert w.q_scale_max == expected_weight.q_scale_max + assert torch.allclose( + w.q_scale_max, expected_weight.q_scale_max + ), "q_scale_max mismatch" assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" @@ -561,25 +564,14 @@ def test_get_multi_weights_col_exl2(): prefix = "weight" quantize = "exl2" - w = weights.get_multi_weights_col( - prefix=prefix, - quantize=quantize, - dim=0, - ) - - expected_weight = Exl2Weight( - q_weight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), - q_scale=torch.tensor([8], dtype=torch.int32), - q_invperm=torch.tensor([1.0], dtype=torch.float32), - q_scale_max=8, - q_groups=torch.tensor([4], dtype=torch.int32), - ) - - assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" - assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" - assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch" - assert w.q_scale_max == expected_weight.q_scale_max - assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" + try: + w = weights.get_multi_weights_col( + prefixes=[prefix], + quantize=quantize, + dim=0, + ) + except ValueError as e: + assert e.args[0] == "get_multi_weights_col is not supported for exl2" def test_get_multi_weights_row_awq(): @@ -602,12 +594,12 @@ def test_get_multi_weights_row_awq(): ) expected_weight = GPTQWeight( - qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), - qzeros=torch.tensor([[1.0], [2.0]], dtype=torch.float32), - scales=torch.tensor([8], dtype=torch.int32), + qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32), + qzeros=torch.tensor([[1], [2]], dtype=torch.int32), + scales=torch.tensor([8.0], dtype=torch.float16), g_idx=None, - bits=torch.tensor([8], dtype=torch.float32), - groupsize=torch.tensor([4], dtype=torch.float32), + bits=8.0, + groupsize=4.0, use_exllama=False, ) @@ -654,7 +646,7 @@ def test_get_multi_weights_col_marlin(): "test_get_multi_weights_col_marlin", ], device="cpu", - dtype=torch.float32, + dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, ) @@ -663,14 +655,14 @@ def test_get_multi_weights_col_marlin(): quantize = "marlin" w = weights.get_multi_weights_col( - prefix=prefix, + prefixes=[prefix], quantize=quantize, dim=0, ) expected_weight = MarlinWeight( B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([0.5], dtype=torch.float16), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), ) assert torch.allclose(w.B, expected_weight.B), "B mismatch" @@ -683,7 +675,7 @@ def test_get_multi_weights_col_packed_marlin(): "test_get_multi_weights_col_packed_marlin", ], device="cpu", - dtype=torch.float32, + dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, ) @@ -692,15 +684,17 @@ def test_get_multi_weights_col_packed_marlin(): quantize = "marlin" w = weights.get_multi_weights_col( - prefix=prefix, + prefixes=[prefix], quantize=quantize, dim=0, ) expected_weight = MarlinWeight( B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), - s=torch.tensor([0.5], dtype=torch.float16), + s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), ) + print(expected_weight) + assert torch.allclose(w.B, expected_weight.B), "B mismatch" assert torch.allclose(w.s, expected_weight.s), "s mismatch"