From 5159d030a98553a8bf16a76a1dc63e04a67f9b5c Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 12 Dec 2024 17:12:08 -0500 Subject: [PATCH] bitsandbytes: upgrade and enable CUDA Graphs for 4bit by default --- .../source/basic_tutorials/preparing_model.md | 2 +- docs/source/index.md | 2 +- launcher/src/main.rs | 11 ++----- server/poetry.lock | 13 +++++---- server/pyproject.toml | 2 +- server/text_generation_server/layers/bnb.py | 29 ++++--------------- 6 files changed, 18 insertions(+), 41 deletions(-) diff --git a/docs/source/basic_tutorials/preparing_model.md b/docs/source/basic_tutorials/preparing_model.md index 456ade44282..2f726247394 100644 --- a/docs/source/basic_tutorials/preparing_model.md +++ b/docs/source/basic_tutorials/preparing_model.md @@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects. ## Quantization -TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [Marlin](https://github.com/IST-DASLab/marlin), [EETQ](https://github.com/NetEase-FuXi/EETQ), [EXL2](https://github.com/turboderp/exllamav2), and [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq`, `awq`, `marlin`, `exl2`, `eetq` or `fp8` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). Similarly, when using AWQ quantization, you need to point to one of [these models](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization) +TGI supports [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes#bitsandbytes), [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [Marlin](https://github.com/IST-DASLab/marlin), [EETQ](https://github.com/NetEase-FuXi/EETQ), [EXL2](https://github.com/turboderp/exllamav2), and [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq`, `awq`, `marlin`, `exl2`, `eetq` or `fp8` depending on the quantization technique you wish to use. When using GPTQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). Similarly, when using AWQ quantization, you need to point to one of [these models](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization) ## RoPE Scaling diff --git a/docs/source/index.md b/docs/source/index.md index 9a6e1774c3c..44f4498dc56 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -12,7 +12,7 @@ Text Generation Inference implements many optimizations and features, such as: - Token streaming using Server-Sent Events (SSE) - Continuous batching of incoming requests for increased total throughput - Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures -- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323) +- Quantization with [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323) - [Safetensors](https://github.com/huggingface/safetensors) weight loading - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) - Logits warper (temperature scaling, top-p, top-k, repetition penalty) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index fb6ba2b2554..e48628a0c4a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -2075,15 +2075,8 @@ fn main() -> Result<(), LauncherError> { let cuda_graphs = match (&args.cuda_graphs, &quantize) { (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(), #[allow(deprecated)] - ( - None, - Some( - Quantization::Bitsandbytes - | Quantization::BitsandbytesNf4 - | Quantization::BitsandbytesFp4, - ), - ) => { - tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); + (None, Some(Quantization::Bitsandbytes)) => { + tracing::warn!("Bitsandbytes 8bit doesn't work with cuda graphs, deactivating them"); vec![] } (None, Some(Quantization::Exl2)) => { diff --git a/server/poetry.lock b/server/poetry.lock index 7cf440dd3e1..f5ad247cd09 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "accelerate" @@ -290,22 +290,23 @@ tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] [[package]] name = "bitsandbytes" -version = "0.43.3" +version = "0.45.0" description = "k-bit optimizers and matrix multiplication routines." optional = true python-versions = "*" files = [ - {file = "bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:cc99507c352be0715098b2c7577b690dd158972dc4ea10c7495bac104c7c79f0"}, - {file = "bitsandbytes-0.43.3-py3-none-win_amd64.whl", hash = "sha256:257f6552f2144748a84e6c44e1f7a98f3da888f675ed74e18fd7f7eb13c6cafa"}, + {file = "bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:0f0323de1ff1fdf8383e79bdad1283516a4c05a6fd2b44a363bf4e059422305b"}, + {file = "bitsandbytes-0.45.0-py3-none-win_amd64.whl", hash = "sha256:ebbf96e0ecb466716a65ecdeaef3fa1983575447b9ab66b74e5211892507c6ff"}, ] [package.dependencies] numpy = "*" torch = "*" +typing_extensions = ">=4.8.0" [package.extras] benchmark = ["matplotlib", "pandas"] -test = ["scipy"] +test = ["lion_pytorch", "scipy"] [[package]] name = "certifi" @@ -4097,4 +4098,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "c7fdcff2b752cd3beb3995c1ecd15f0f4d9b4e117048b06ab991c6d0e0c86ff3" +content-hash = "767757fffcf7bec05a8a60dcfe2a3c7d258f26efac3004f3d24c8d543b462413" diff --git a/server/pyproject.toml b/server/pyproject.toml index 0d56e9c7375..ca577c40c01 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.4" typer = "^0.12.5" accelerate = {version = "^1.1.0", optional = true} -bitsandbytes = { version = "^0.43.0", optional = true } +bitsandbytes = { version = "^0.45.0", optional = true } safetensors = "^0.4.5" loguru = "^0.7.2" opentelemetry-api = "^1.27.0" diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index 791d9b6d8c6..78cecd9a168 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -20,23 +20,16 @@ def __init__( weight, bias, has_fp16_weights=True, - memory_efficient_backward=False, threshold=0.0, index=None, ): super().__init__() - assert ( - not memory_efficient_backward - ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" self.state = bnb.MatmulLtState() self.index = index # Necessary for stacked layers self.state.threshold = threshold self.state.has_fp16_weights = has_fp16_weights - self.state.memory_efficient_backward = memory_efficient_backward - if threshold > 0.0 and not has_fp16_weights: - self.state.use_pool = True self.weight = Int8Params( weight.data, @@ -63,12 +56,9 @@ def forward(self, x: torch.Tensor): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights: - if self.state.CB is not None and self.state.CxB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB + if not self.state.has_fp16_weights and self.state.CB is not None: + self.weight.data = self.state.CB + return out @@ -106,19 +96,12 @@ def forward(self, x: torch.Tensor): if self.bias is not None and self.bias.dtype != x.dtype: self.bias.data = self.bias.data.to(x.dtype) - if getattr(self.weight, "quant_state", None) is None: - print( - "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." - ) inp_dtype = x.dtype if self.compute_dtype is not None: x = x.to(self.compute_dtype) bias = None if self.bias is None else self.bias.to(self.compute_dtype) - out = bnb.matmul_4bit( - x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state - ) - out = out.to(inp_dtype) - - return out + return bnb.matmul_4bit( + x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state + ).to(inp_dtype)