Move quantized weight handling out of the Weights class
Quantized weights were loaded in the `Weights` class, but this was
getting quite unwieldy: every higher-level method for loading weights
was a long conditional covering all the different quantizers.
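
For illustration, each of those methods looked roughly like the sketch
below. This is a paraphrase of the pattern being removed, not the exact
code; the `quantize` argument and the per-quantizer branch bodies are
assumptions based on the loaders described below.

def get_multi_weights_row(self, prefix: str, quantize: str):
    # One branch per quantizer; every loading method repeated this shape.
    if quantize == "gptq":
        ...  # GPTQ-specific tensor names and post-processing
    elif quantize == "exl2":
        ...  # exl2-specific tensor names and post-processing
    elif quantize == "marlin":
        ...  # Marlin-specific tensor names and post-processing
    else:
        # Unquantized path: a plain sharded tensor load.
        return self.get_sharded(f"{prefix}.weight", dim=1)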

This change moves loading of quantized weights out of the `Weights`
class. This is done by defining a simple `WeightsLoader` interface
that is implemented by `Exl2WeightsLoader`, `GPTQWeightsLoader`,
and `MarlinWeightsLoader`. These implementations are in the quantizers'
respective modules. The `Weights` class provides the low-level load
operations (such as loading tensors or sharded tensors), but delegates
loads that need quantizer-specific weight processing to a loader. The
loaders still use the low-level functionality provided by `Weights`.
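
Concretely, the interface is a small set of loading hooks. A minimal
sketch is shown below: the method names match the `Exl2WeightsLoader`
in the diff, but the abstract-base-class form and exact signatures are
assumptions, not the verbatim module contents.

from abc import ABC, abstractmethod
from typing import List, Union

class WeightsLoader(ABC):
    """Loads quantizer-specific weights using a `Weights` instance for
    low-level tensor access."""

    @abstractmethod
    def get_weights_col_packed(
        self, weights: "Weights", prefix: str, block_sizes: Union[int, List[int]]
    ): ...

    @abstractmethod
    def get_weights_col(self, weights: "Weights", prefix: str): ...

    @abstractmethod
    def get_multi_weights_col(
        self, weights: "Weights", prefixes: List[str], dim: int
    ): ...

    @abstractmethod
    def get_multi_weights_row(self, weights: "Weights", prefix: str): ...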

I initially tried making a hierarchy where a class like `GPTQWeights`
would inherit from `Weights`. But that approach is not very flexible
(e.g. it does not work well with the new weight storage mock used in
tests), and the implicit indirection made the code harder to follow.
danieldk committed Jul 8, 2024
1 parent 5c7c9f1 commit eca19d4
Showing 17 changed files with 817 additions and 707 deletions.
128 changes: 70 additions & 58 deletions server/tests/utils/test_weights.py

Large diffs are not rendered by default.

58 changes: 58 additions & 0 deletions server/text_generation_server/layers/exl2.py
@@ -1,6 +1,9 @@
import torch
from typing import List, Union
from dataclasses import dataclass

from text_generation_server.utils.weights import WeightsLoader, Weights


@dataclass
class Exl2Weight:
@@ -21,3 +24,58 @@ def __post_init__(self):
    @property
    def device(self) -> torch.device:
        return self.q_weight.device


class Exl2WeightsLoader(WeightsLoader):
    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        raise RuntimeError("Column-packed weights are not supported for exl2")

    def get_weights_col(self, weights: Weights, prefix: str):
        try:
            q_weight = weights.get_tensor(f"{prefix}.q_weight")
        except RuntimeError:
            raise RuntimeError(
                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
            )

        q_scale = weights.get_tensor(f"{prefix}.q_scale")
        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
        q_groups = weights.get_tensor(f"{prefix}.q_groups")

        return Exl2Weight(
            q_weight=q_weight,
            q_scale=q_scale,
            q_invperm=q_invperm,
            q_scale_max=q_scale_max,
            q_groups=q_groups,
        )

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        raise ValueError("get_multi_weights_col is not supported for exl2")

    def get_multi_weights_row(self, weights: Weights, prefix: str):
        try:
            q_weight = weights.get_tensor(f"{prefix}.q_weight")
        except RuntimeError:
            raise RuntimeError(
                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
            )

        q_scale = weights.get_tensor(f"{prefix}.q_scale")
        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
        q_groups = weights.get_tensor(f"{prefix}.q_groups")

        return Exl2Weight(
            q_weight=q_weight,
            q_scale=q_scale,
            q_invperm=q_invperm,
            q_scale_max=q_scale_max,
            q_groups=q_groups,
        )
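
A hedged usage sketch of how a loader is wired in follows. The
`weights_loader` constructor parameter and the delegating wrapper on
`Weights` are assumptions consistent with the commit message above, not
the verbatim TGI API.

from text_generation_server.layers.exl2 import Exl2WeightsLoader
from text_generation_server.utils.weights import Weights

# filenames, device, dtype, and process_group are prepared elsewhere;
# the quantizer-specific loader is passed in (parameter name assumed).
weights = Weights(
    filenames,
    device=device,
    dtype=dtype,
    process_group=process_group,
    weights_loader=Exl2WeightsLoader(),
)

# Higher-level code keeps calling Weights; Weights delegates the
# quantizer-specific processing to the loader, passing itself along
# for low-level tensor access.
w = weights.get_multi_weights_row("model.layers.0.self_attn.o_proj")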
