
Commit 76885a4

Merge pull request bitsandbytes-foundation#1160 from matthewdouglas/quant4bit-blocksize4096

Fix 4bit quantization with blocksize = 4096
Titus-von-Koeller authored Apr 2, 2024
2 parents 2965c76 + a471456 commit 76885a4
Showing 4 changed files with 30 additions and 15 deletions.
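For context, the user-visible effect of this fix can be checked with a short quantize/dequantize round-trip. This is a hedged sketch, not part of the commit, assuming a CUDA device and the public quantize_4bit/dequantize_4bit API exercised by the updated test below:

    # Hedged repro sketch: before this fix, the blocksize=4096 launch
    # hard-coded the kernel's DATA_TYPE template argument to 0, so fp4/nf4
    # quantization at that blocksize did not go through the 4-bit code path.
    import torch
    import bitsandbytes.functional as F

    A1 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    qa, SA = F.quantize_4bit(A1, blocksize=4096, quant_type="nf4")
    A2 = F.dequantize_4bit(qa, SA, blocksize=4096, quant_type="nf4")

    # With the fix, the mean absolute round-trip error stays within the
    # bounds asserted by the updated test_4bit_quant below.
    print((A1 - A2).abs().float().mean().item())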
7 changes: 4 additions & 3 deletions bitsandbytes/functional.py
@@ -1087,11 +1087,12 @@ def get_4bit_type(typename, device=None, blocksize=64):
     if data is None:
         raise NotImplementedError(f"Typename {typename} not supported")

-    data = Tensor(data)
-    data /= data.abs().max()
+    data = torch.tensor(data, device=device)
+    data.div_(data.abs().max())
+
     assert data.numel() == 16

-    return data.to(device)
+    return data


 def quantize_fp4(
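The change above builds the 16-entry codebook directly on the requested device and normalizes it in place, instead of constructing it on CPU, dividing out of place, and copying to the device on return. A hedged usage sketch of the helper as it reads after this diff:

    # Hedged sketch: fetch the normalized nf4 codebook on the GPU.
    # get_4bit_type and its signature come from the hunk above; the
    # printed values are illustrative, not asserted by this commit.
    import bitsandbytes.functional as F

    code = F.get_4bit_type("nf4", device="cuda")
    assert code.numel() == 16               # one codebook entry per 4-bit index
    assert code.abs().max().item() == 1.0   # normalized by the in-place div_
    print(code)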
2 changes: 1 addition & 1 deletion csrc/ops.cu
@@ -58,7 +58,7 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
   num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;

   if(blocksize == 4096)
-    kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 2048)
     kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 1024)
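This one-template-argument change is the actual bug fix: the blocksize=4096 branch passed a literal 0 as the kernel's DATA_TYPE parameter instead of forwarding the caller's choice, so fp4/nf4 requests at that blocksize were quantized with the wrong data type. The launch shape itself is unchanged; as a hedged illustration, the grid arithmetic on the first line above is just a ceiling division, and each 4096-element block is covered by 1024 threads handling 4 values apiece:

    # Hedged sketch of the num_blocks arithmetic above: ceil-divide n by blocksize.
    def num_blocks(n: int, blocksize: int) -> int:
        blocks = n // blocksize
        return blocks if n % blocksize == 0 else blocks + 1

    # 4096 = 1024 threads * 4 values per thread, matching the <<<num_blocks, 1024>>> launch.
    assert num_blocks(1024 * 1024, 4096) == 256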
8 changes: 2 additions & 6 deletions install_cuda.py
@@ -77,9 +77,7 @@ def main():
download_path = "/tmp" # default download path

if len(sys.argv) < 2:
print(
"Usage: python install_cuda.py <version/all> [user/system] [download_path]"
)
print("Usage: python install_cuda.py <version/all> [user/system] [download_path]")
sys.exit(1)

version = sys.argv[1]
@@ -100,9 +98,7 @@ def main():
     elif version in cuda_versions:
         install_cuda(version, base_path, download_path)
     else:
-        print(
-            f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}"
-        )
+        print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}")
         sys.exit(1)


28 changes: 23 additions & 5 deletions tests/test_functional.py
@@ -1928,7 +1928,9 @@ def test_bench_dequantization():


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
def test_fp4_quant(dtype):
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(dtype, quant_type, blocksize):
vals = list(product([0, 1], repeat=4))

code = {}
@@ -1953,17 +1955,33 @@ def test_fp4_quant(dtype):
         code[idx] = result

     A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
-    qa, SA = F.quantize_fp4(A1, blocksize=64)
-    A2 = F.dequantize_fp4(qa, SA)
+    qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
+    A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)

     err = (A1 - A2).abs().float()
     relerr = (err / (A1.abs().float() + 1e-8)).mean()
     idx = err > 1.0
     err = err.mean()

     assert A2.dtype == dtype
-    assert err.item() < 0.1
-    assert relerr.item() < 0.28
+
+    # With larger block sizes, we can expect this to blow up.
+    # At blocksize>=1024, don't even bother looking at relerr.
+    if blocksize <= 64:
+        assert err.item() < 0.1
+        assert relerr.item() < 0.28
+    elif blocksize <= 256:
+        assert err.item() < 0.11
+        assert relerr.item() < 0.30
+    elif blocksize <= 512:
+        assert err.item() < 0.12
+        assert relerr.item() < 0.31
+    elif quant_type == "fp4":
+        # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
+        assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
+    else:
+        # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
+        assert err.item() < math.log2(blocksize) * 8e-2


@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
Expand Down
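A hedged note on the new bounds: for blocksize >= 1024 the allowed mean error scales with log2(blocksize), and the inline comments give the resulting thresholds. A quick illustrative check of that arithmetic (the test itself can be run with: pytest tests/test_functional.py -k test_4bit_quant):

    # Hedged check of the threshold comments above: fp4 allows
    # 0.08 + log2(blocksize) * 0.04, nf4 allows log2(blocksize) * 0.08.
    import math

    for bs in (1024, 2048, 4096):
        fp4_bound = 0.08 + math.log2(bs) * 4e-2
        nf4_bound = math.log2(bs) * 8e-2
        print(bs, round(fp4_bound, 2), round(nf4_bound, 2))
    # 1024 0.48 0.8 / 2048 0.52 0.88 / 4096 0.56 0.96, matching the comments.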
