Move quantized weight handling out of the Weights class
Quantized weights were loaded in the `Weights` class, but this was
getting unwieldy: every higher-level method for loading weights had
become a long conditional covering all the different quantizers.

This change moves loading of quantized weights out of the `Weights`
class. This is done by defining a simple `WeightsLoader` interface
that is implemented by `Exl2WeightsLoader`, `GPTQWeightsLoader`,
and `MarlinWeightsLoader`. These implementations are in the quantizers'
respective modules. The `Weights` class provides the low-level load
operations (such as loading tensors or sharded tensors), but delegates
loads that need quantizer-specific weight processing to a loader. The
loaders still use the low-level functionality provided by `Weights`.
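The delegation pattern described above can be sketched as follows. This is a minimal illustration, not the actual TGI interface: the method names, signatures, and tensor handling here are simplified assumptions.

```python
# Sketch of the WeightsLoader delegation pattern from the commit message.
# Method names and signatures are illustrative, not the real TGI API.
from abc import ABC, abstractmethod


class WeightsLoader(ABC):
    """Interface for loads that may need quantizer-specific processing."""

    @abstractmethod
    def get_weights_col(self, weights: "Weights", prefix: str):
        ...


class DefaultWeightsLoader(WeightsLoader):
    """Unquantized weights: a plain tensor load, no extra processing."""

    def get_weights_col(self, weights: "Weights", prefix: str):
        return weights.get_tensor(f"{prefix}.weight")


class Weights:
    """Provides low-level load operations; quantizer-specific weight
    processing is delegated to the configured loader."""

    def __init__(self, tensors: dict, weights_loader: WeightsLoader):
        self.tensors = tensors
        self.weights_loader = weights_loader

    def get_tensor(self, name: str):
        # Low-level operation used by the loaders themselves.
        return self.tensors[name]

    def get_weights_col(self, prefix: str):
        # Delegate: a GPTQ/Marlin/Exl2 loader would fetch and repack
        # its quantizer-specific tensors here instead.
        return self.weights_loader.get_weights_col(self, prefix)
```

A quantizer such as GPTQ would plug in by implementing `WeightsLoader` in its own module, so `Weights` never needs a per-quantizer conditional.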

I initially tried making a hierarchy where a class like `GPTQWeights`
would inherit from `Weights`. But it is not very flexible (e.g. does
not work well with the new weight storage mock used in tests) and
the implicit indirections made the code harder to follow.
danieldk committed Jul 9, 2024
1 parent 4c976fb commit 800ee99
Showing 24 changed files with 896 additions and 731 deletions.
8 changes: 7 additions & 1 deletion server/tests/utils/test_layers.py
@@ -2,6 +2,7 @@
 from text_generation_server.layers import (
     TensorParallelEmbedding,
 )
+from text_generation_server.utils.weights import DefaultWeightsLoader


 class ProcessGroup:
@@ -42,7 +43,12 @@ def get_shape(self, name: str):
 def test_weight_hub_files_offline_error():

     vocab_size = 17
-    weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
+    weights = Weights(
+        rank=0,
+        world_size=1,
+        vocab_size=vocab_size,
+        hidden_dim=256,
+    )
     embeddings = TensorParallelEmbedding("", weights)

     input_ids = torch.arange(vocab_size)