Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TPU][Quantization] TPU W8A8 #11785

Merged
merged 73 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from 60 commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
3b0c8a6
w8a8 working
robertgshaw2-neuralmagic Oct 11, 2024
36fc1db
format
robertgshaw2-neuralmagic Oct 11, 2024
d83c04c
added all kernels
robertgshaw2-neuralmagic Oct 11, 2024
af9d0f4
format
robertgshaw2-neuralmagic Oct 11, 2024
0f9fd21
working on cuda
robertgshaw2-neuralmagic Oct 12, 2024
7b3203f
added mixed precision directory
robertgshaw2-neuralmagic Oct 12, 2024
bf50fa4
formatting
robertgshaw2-neuralmagic Oct 12, 2024
226ef52
cache current state - w8a16 running oom
robertgshaw2-neuralmagic Oct 12, 2024
bb7c741
[TPU] Ensure torch._sync(param) is called after param.data.copy_()
WoosukKwon Oct 16, 2024
cf842bd
yapf
WoosukKwon Oct 17, 2024
67039bc
[TPU] Correctly profile peak memory usage
WoosukKwon Oct 17, 2024
0695f77
Upgrade PyTorch XLA
WoosukKwon Oct 17, 2024
11cf82f
Merge branch 'main' into tpu-peak-mem
WoosukKwon Oct 17, 2024
e016e38
stash
robertgshaw2-neuralmagic Oct 20, 2024
717b859
Merge branch 'main' into compressed-tensors-tpu
robertgshaw2-neuralmagic Oct 20, 2024
c848735
proper merge
robertgshaw2-neuralmagic Oct 20, 2024
1539915
add mixed precision
robertgshaw2-neuralmagic Oct 20, 2024
f00412a
format
robertgshaw2-neuralmagic Oct 20, 2024
b0a6b70
stash
robertgshaw2-neuralmagic Oct 20, 2024
e812d7e
Merge branch 'tpu-peak-mem' into compressed-tensors-tpu
robertgshaw2-neuralmagic Oct 20, 2024
764dda1
stash
robertgshaw2-neuralmagic Oct 20, 2024
87b2ae6
remove name
robertgshaw2-neuralmagic Oct 20, 2024
e813ff8
revert woosuk change
robertgshaw2-neuralmagic Oct 20, 2024
8cfaa1b
format
robertgshaw2-neuralmagic Oct 20, 2024
bbc9741
update
robertgshaw2-neuralmagic Oct 21, 2024
eb3f39e
fix nit
robertgshaw2-neuralmagic Oct 21, 2024
bb2fbe1
update
robertgshaw2-neuralmagic Oct 21, 2024
14ccb90
fix spurious
robertgshaw2-neuralmagic Oct 21, 2024
4092be2
stash branch for brittany
robertgshaw2-neuralmagic Oct 23, 2024
1aaa628
Merge branch 'main' into tpu-w8a8
robertgshaw2-neuralmagic Jan 6, 2025
48aa54b
revert
robertgshaw2-neuralmagic Jan 7, 2025
4efe915
fix
robertgshaw2-neuralmagic Jan 7, 2025
e98b79c
updated
robertgshaw2-neuralmagic Jan 7, 2025
5a89668
reduce cruft
robertgshaw2-neuralmagic Jan 7, 2025
57cbf5c
reduce cruft
robertgshaw2-neuralmagic Jan 7, 2025
3451c4d
updated
robertgshaw2-neuralmagic Jan 7, 2025
0c2e62a
update comment
robertgshaw2-neuralmagic Jan 7, 2025
172c9ca
revert spurious change
robertgshaw2-neuralmagic Jan 7, 2025
938ca81
remove cruft
robertgshaw2-neuralmagic Jan 7, 2025
9e18911
cruft reduction
robertgshaw2-neuralmagic Jan 7, 2025
5f58ec7
update docs
robertgshaw2-neuralmagic Jan 7, 2025
af9f298
added integration test
robertgshaw2-neuralmagic Jan 7, 2025
6fe2f62
updated
robertgshaw2-neuralmagic Jan 7, 2025
f2c0beb
Add bias back
robertgshaw2-neuralmagic Jan 7, 2025
8b29718
add bias support
robertgshaw2-neuralmagic Jan 7, 2025
1e2a373
updated
robertgshaw2-neuralmagic Jan 7, 2025
2a359ef
stash
robertgshaw2-neuralmagic Jan 7, 2025
f7e8975
Merge branch 'main' into remove-async-stream
robertgshaw2-neuralmagic Jan 7, 2025
0d4c3fd
fix
robertgshaw2-neuralmagic Jan 7, 2025
57340d2
update
robertgshaw2-neuralmagic Jan 7, 2025
38291d5
trigger test in CI
robertgshaw2-neuralmagic Jan 7, 2025
ead1e94
fix AZP
robertgshaw2-neuralmagic Jan 7, 2025
cea5e54
fixed!
robertgshaw2-neuralmagic Jan 7, 2025
940ddde
Merge branch 'tpu-w8a8' of https://github.com/neuralmagic/vllm into t…
robertgshaw2-neuralmagic Jan 7, 2025
84a5b29
fix azp adju
robertgshaw2-neuralmagic Jan 7, 2025
a1d7b4a
make docker command look better on gh
robertgshaw2-neuralmagic Jan 7, 2025
2b4ecfd
remove torch warnings
robertgshaw2-neuralmagic Jan 7, 2025
186c108
stash
robertgshaw2-neuralmagic Jan 7, 2025
7e8598a
Merge branch 'tpu-w8a8' of https://github.com/neuralmagic/vllm into t…
robertgshaw2-neuralmagic Jan 7, 2025
de773cd
fix AZP
robertgshaw2-neuralmagic Jan 7, 2025
3a53d7d
merged
robertgshaw2-neuralmagic Jan 7, 2025
0be5f69
added
robertgshaw2-neuralmagic Jan 7, 2025
cb69ba7
fix formatting
robertgshaw2-neuralmagic Jan 7, 2025
3896f6c
remove comment
robertgshaw2-neuralmagic Jan 7, 2025
33e1e13
formatted
robertgshaw2-neuralmagic Jan 7, 2025
dde72d6
add llama to ci
robertgshaw2-neuralmagic Jan 7, 2025
d7a9c93
Merge branch 'main' into tpu-w8a8
robertgshaw2-neuralmagic Jan 7, 2025
db9f795
Update supported_hardware.md
robertgshaw2-neuralmagic Jan 7, 2025
09ad869
Update supported_hardware.md
robertgshaw2-neuralmagic Jan 7, 2025
b74c88a
ixed docs build
robertgshaw2-neuralmagic Jan 8, 2025
da4369e
Merge branch 'tpu-w8a8' of https://github.com/neuralmagic/vllm into t…
robertgshaw2-neuralmagic Jan 8, 2025
5ddcac2
Merge branch 'main' into tpu-w8a8
robertgshaw2-neuralmagic Jan 8, 2025
f353c43
fix CI
robertgshaw2-neuralmagic Jan 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion .buildkite/run-tpu-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,13 @@ remove_docker_container
# For HF_TOKEN.
source /etc/environment
# Run a simple end-to-end example.
docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
docker run --privileged --net host --shm-size=16G -it \
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu \
/bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
&& python3 -m pip install pytest \
&& python3 -m pip install lm_eval[api]==0.4.4 \
&& pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \
&& pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
&& python3 /workspace/vllm/tests/tpu/test_compilation.py \
&& python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
&& python3 /workspace/vllm/examples/offline_inference_tpu.py"
2 changes: 1 addition & 1 deletion docs/source/features/quantization/supported_hardware.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ The table below shows the compatibility of various quantization implementations
- ✗
- ✗
- ✅︎
-
- ✅︎
- ✗
* - FP8 (W8A8)
- ✗
Expand Down
49 changes: 49 additions & 0 deletions tests/tpu/test_quantization_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from dataclasses import dataclass

import lm_eval
import pytest

TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03


@dataclass
class GSM8KAccuracyTestConfig:
model_name: str
excepted_value: float

def get_model_args(self) -> str:
return (f"pretrained={self.model_name},"
"max_model_len=4096,max_num_seqs=128,tensor_parallel_size=4")


# NOTE: Accuracy scores measured on GPUs.
ACCURACY_CONFIGS = [
GSM8KAccuracyTestConfig(
model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
excepted_value=0.76), # no bias
# NOTE(rob): We cannot re-initialize VLLM in the same process for TPU,
# so only one of these tests can run in a single call to pytest. As
# a follow up, move this into the LM-EVAL section of the CI.
# GSM8KAccuracyTestConfig(
# model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
# excepted_value=0.66), # bias in QKV layers
]


@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):

results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(),
tasks="gsm8k",
batch_size="auto",
)

EXPECTED_VALUE = config.excepted_value
measured_value = results["results"][TASK][FILTER]
assert (measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Set

import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_int8_linear, convert_to_channelwise)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
Expand All @@ -18,6 +17,7 @@


class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set()

def __init__(self, strategy: str, is_static_input_scheme: bool,
input_symmetric: bool):
Expand All @@ -30,74 +30,25 @@ def get_min_capability(cls) -> int:
# turing and up
return 75

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Cutlass kernels need transposed weight.
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)

# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(self.logical_widths) > 1
if is_fused_module and self.strategy == QuantizationStrategy.TENSOR:
ws_channelwise = convert_to_channelwise(layer.weight_scale,
self.logical_widths)
layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
else:
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
# INPUT SCALE
if self.is_static_input_scheme:
if self.input_symmetric:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
layer.input_zero_point = None
else:
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = layer.input_zero_point.to(dtype=torch.int32)
range_max = (layer.input_scale *
(int8_traits.max - azps)).max()
range_min = (layer.input_scale *
(int8_traits.min - azps)).min()

scale = (range_max - range_min) / (int8_traits.max -
int8_traits.min)
layer.input_scale = Parameter(scale, requires_grad=False)

# AZP loaded as int8 but used as int32
azp = (int8_traits.min -
range_min / scale).to(dtype=torch.int32)
layer.input_zero_point = Parameter(azp, requires_grad=False)

else:
layer.input_scale = None
layer.input_zero_point = None

# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
if not self.input_symmetric:
azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if self.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = layer.input_zero_point * azp_adj

layer.azp_adj = azp_adj
else:
layer.azp_adj = None

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.logical_widths = output_partition_sizes
layer.logical_widths = output_partition_sizes

scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=self.input_symmetric)

kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config)

if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW8A8Int8",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)

# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
Expand Down Expand Up @@ -140,12 +91,18 @@ def create_weights(self, layer: torch.nn.Module,
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)

self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
w_s_param_name="weight_scale",
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj")

# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return apply_int8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
input_zero_point=layer.input_zero_point,
azp_adj=layer.azp_adj,
bias=bias)
return self.kernel.apply_weights(layer, x, bias)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.kernels import (
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks)
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels import (
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import List, Optional, Type
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE for reviewer - this file is not changed, it is just moved


import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.exllama import (
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.machete import (
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501
MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501
MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class ScaledMMLinearLayerConfig:
is_channelwise: bool
is_static_input_scheme: bool
input_symmetric: bool


class ScaledMMLinearKernel(ABC):

@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
raise NotImplementedError

@classmethod
@abstractmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
raise NotImplementedError

def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
w_s_param_name: str, i_s_param_name: str,
i_zp_param_name: str, azp_adj_param_name: str) -> None:
assert self.can_implement(c)
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
self.i_s_name = i_s_param_name
self.i_zp_name = i_zp_param_name
self.azp_adj_name = azp_adj_param_name

@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
raise NotImplementedError

@abstractmethod
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
raise NotImplementedError

def _get_weight_params(
self, layer: torch.nn.Module
) -> Tuple[torch.Tensor, # weight
torch.Tensor, # weight_scale
Optional[torch.Tensor], # input_scale,
Optional[torch.Tensor], # input_zp
Optional[torch.Tensor], # azp_adj
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
getattr(layer, self.i_s_name),
getattr(layer, self.i_zp_name),
getattr(layer, self.azp_adj_name),
)
Loading
Loading