diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py
index 4651f8d2020..7d86f4f8475 100644
--- a/server/tests/utils/test_weights.py
+++ b/server/tests/utils/test_weights.py
@@ -65,12 +65,43 @@
                 [5, 6],
                 [7, 8],
             ],
-            dtype=torch.float32,
+            dtype=torch.int32,
         ),
-        "weight.g_idx": torch.tensor([1.0], dtype=torch.float32),
-        "weight.qzeros": torch.tensor([[1.0], [2.0]], dtype=torch.float32),
-        "weight.scales": torch.tensor([8], dtype=torch.int32),
-        #
+        "weight.g_idx": torch.tensor([1], dtype=torch.int32),
+        "weight.qzeros": torch.tensor([[1], [2]], dtype=torch.int32),
+        "weight.scales": torch.tensor([8], dtype=torch.float16),
+        "gptq_bits": torch.tensor([8], dtype=torch.float32),
+        "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
+    },
+    "test_get_multi_weights_col_gptq": {
+        "weight.qweight": torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.int32,
+        ),
+        "weight.g_idx": torch.tensor([1], dtype=torch.int32),
+        "weight.qzeros": torch.tensor([[1], [2]], dtype=torch.int32),
+        "weight.scales": torch.tensor([8], dtype=torch.float16),
+        "gptq_bits": torch.tensor([8], dtype=torch.float32),
+        "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
+    },
+    "test_get_multi_weights_col_packed_gptq": {
+        "col_packed.weight.qweight": torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.int32,
+        ),
+        "col_packed.weight.g_idx": torch.tensor([1], dtype=torch.int32),
+        "col_packed.weight.qzeros": torch.tensor([[1], [2]], dtype=torch.int32),
+        "col_packed.weight.scales": torch.tensor([8], dtype=torch.float16),
         "gptq_bits": torch.tensor([8], dtype=torch.float32),
         "gptq_groupsize": torch.tensor([4], dtype=torch.float32),
     },
@@ -82,16 +113,38 @@
                 [5, 6],
                 [7, 8],
             ],
-            dtype=torch.float32,
+            dtype=torch.int32,
         ),
         "weight.q_scale": torch.tensor([8], dtype=torch.int32),
-        "weight.q_invperm": torch.tensor([1.0], dtype=torch.float32),
-        "weight.q_scale_max": 8,
-        "weight.q_groups": torch.tensor([4], dtype=torch.int32),
+        "weight.q_invperm": torch.tensor([1], dtype=torch.int32),
+        "weight.q_scale_max": torch.tensor([8], dtype=torch.float16),
+        "weight.q_groups": torch.tensor([4], dtype=torch.int16),
+    },
+    "test_get_multi_weights_col_exl2": {
+        "weight.q_weight": torch.tensor(
+            [
+                [1, 2],
+                [3, 4],
+                [5, 6],
+                [7, 8],
+            ],
+            dtype=torch.int32,
+        ),
+        "weight.q_scale": torch.tensor([8], dtype=torch.int32),
+        "weight.q_invperm": torch.tensor([1], dtype=torch.int32),
+        "weight.q_scale_max": torch.tensor([8], dtype=torch.float16),
+        "weight.q_groups": torch.tensor([4], dtype=torch.int16),
     },
     "test_get_multi_weights_row_marlin": {
-        "weight.scales": torch.tensor([8], dtype=torch.float16),
         "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
         "weight.s": torch.tensor([0.5], dtype=torch.float16),
     },
+    "test_get_multi_weights_col_marlin": {
+        "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        "weight.s": torch.tensor([0.5], dtype=torch.float16),
+    },
+    "test_get_multi_weights_col_packed_marlin": {
+        "col_packed.weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        "col_packed.weight.s": torch.tensor([0.5], dtype=torch.float16),
+    },
 }
@@ -380,6 +433,87 @@ def test_get_multi_weights_row_gptq():
     assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"


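+# The column tests below mirror the row variants above. The expected tensors
+# reuse the fixture dtypes on purpose: torch.allclose raises on mismatched
+# dtypes instead of promoting, so any fixture/expectation drift fails loudly.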
"test_get_multi_weights_col_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefixes = ["weight"] + quantize = "gptq" + + w = weights.get_multi_weights_col( + prefixes=prefixes, + quantize=quantize, + dim=0, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), + qzeros=torch.tensor([[1.0], [2.0]], dtype=torch.float32), + scales=torch.tensor([8], dtype=torch.int32), + g_idx=torch.tensor([1.0], dtype=torch.float32), + bits=torch.tensor([8], dtype=torch.float32), + groupsize=torch.tensor([4], dtype=torch.float32), + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + +def test_get_multi_weights_col_packed_gptq(): + weights = MockWeights( + [ + "test_get_multi_weights_col_packed_gptq", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefixes = ["col_packed"] + quantize = "gptq" + + w = weights.get_multi_weights_col( + prefixes=prefixes, + quantize=quantize, + dim=0, + ) + + expected_weight = GPTQWeight( + qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), + qzeros=torch.tensor([[1.0], [2.0]], dtype=torch.float32), + scales=torch.tensor([8], dtype=torch.int32), + g_idx=torch.tensor([1.0], dtype=torch.float32), + bits=torch.tensor([8], dtype=torch.float32), + groupsize=torch.tensor([4], dtype=torch.float32), + use_exllama=False, + ) + + assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch" + assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch" + assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch" + assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" + assert w.bits == expected_weight.bits, "bits mismatch" + assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" + + def test_get_multi_weights_row_exl2(): weights = MockWeights( [ @@ -414,6 +547,41 @@ def test_get_multi_weights_row_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" +def test_get_multi_weights_col_exl2(): + weights = MockWeights( + [ + "test_get_multi_weights_col_exl2", + ], + device="cpu", + dtype=torch.float32, + process_group=dummy_process_group, + dummy_fs=dummy_file_system, + ) + + prefix = "weight" + quantize = "exl2" + + w = weights.get_multi_weights_col( + prefix=prefix, + quantize=quantize, + dim=0, + ) + + expected_weight = Exl2Weight( + q_weight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]), + q_scale=torch.tensor([8], dtype=torch.int32), + q_invperm=torch.tensor([1.0], dtype=torch.float32), + q_scale_max=8, + q_groups=torch.tensor([4], dtype=torch.int32), + ) + + assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch" + assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch" + assert torch.allclose(w.q_invperm, 
+def test_get_multi_weights_col_exl2():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_exl2",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefix = "weight"
+    quantize = "exl2"
+
+    w = weights.get_multi_weights_col(
+        prefixes=[prefix],
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = Exl2Weight(
+        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
+        q_scale=torch.tensor([8], dtype=torch.int32),
+        q_invperm=torch.tensor([1], dtype=torch.int32),
+        q_scale_max=torch.tensor([8], dtype=torch.float16),
+        q_groups=torch.tensor([4], dtype=torch.int16),
+    )
+
+    assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
+    assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
+    assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
+    assert torch.allclose(w.q_scale_max, expected_weight.q_scale_max), "q_scale_max mismatch"
+    assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
+
+
 def test_get_multi_weights_row_awq():
     weights = MockWeights(
         [
@@ -478,3 +649,61 @@ def test_get_multi_weights_row_marlin():

     assert torch.allclose(w.B, expected_weight.B), "B mismatch"
     assert torch.allclose(w.s, expected_weight.s), "s mismatch"
+
+
+def test_get_multi_weights_col_marlin():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_marlin",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefix = "weight"
+    quantize = "marlin"
+
+    w = weights.get_multi_weights_col(
+        prefixes=[prefix],
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = MarlinWeight(
+        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        s=torch.tensor([0.5], dtype=torch.float16),
+    )
+
+    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
+    assert torch.allclose(w.s, expected_weight.s), "s mismatch"
+
+
+def test_get_multi_weights_col_packed_marlin():
+    weights = MockWeights(
+        [
+            "test_get_multi_weights_col_packed_marlin",
+        ],
+        device="cpu",
+        dtype=torch.float32,
+        process_group=dummy_process_group,
+        dummy_fs=dummy_file_system,
+    )
+
+    prefix = "col_packed.weight"
+    quantize = "marlin"
+
+    w = weights.get_multi_weights_col(
+        prefixes=[prefix],
+        quantize=quantize,
+        dim=0,
+    )
+
+    expected_weight = MarlinWeight(
+        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
+        s=torch.tensor([0.5], dtype=torch.float16),
+    )
+
+    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
+    assert torch.allclose(w.s, expected_weight.s), "s mismatch"