Factor out sharding of packed tensors
For Phi-3-Small I need to shard a packed QKV bias tensor, for which
I implemented the `Weights.get_packed_sharded` method. However, this
method can also replace the `Weights._get_qweight` method and the
custom sharding code from `Weights.get_weights_col_packed`.
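
As an illustration only (not part of this commit), here is a minimal sketch of how a model loader might call the new method to shard a packed QKV projection. The tensor names, head counts, and the `weights` handle (an initialized `Weights` instance) are hypothetical:

# Hypothetical usage sketch; none of the names below come from the commit.
num_heads = 32           # assumed number of attention heads
num_key_value_heads = 8  # assumed number of KV heads (grouped-query attention)

# Q, K and V are packed along the output dimension, so their relative sizes
# can be expressed as blocks proportional to the head counts.
block_sizes = [num_heads, num_key_value_heads, num_key_value_heads]

# 2-D weight of shape [out_features, in_features]: the packed dimension is 0,
# matching the unquantized path in the diff below.
qkv_weight = weights.get_packed_sharded(
    "model.layers.0.self_attn.query_key_value.weight",
    dim=0,
    block_sizes=block_sizes,
)

# 1-D bias: the packed dimension is also 0; this is the Phi-3-Small case
# mentioned in the commit message.
qkv_bias = weights.get_packed_sharded(
    "model.layers.0.self_attn.query_key_value.bias",
    dim=0,
    block_sizes=block_sizes,
)

For GPTQ- and Marlin-style packed tensors such as `qweight`, `qzeros`, and `scales`, the diff below calls the same method with `dim=1`, since those storage formats keep the packed output projections along the column dimension.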
danieldk committed Jun 20, 2024
1 parent f5a9837 commit 8fc4b84
Showing 1 changed file with 60 additions and 39 deletions.
server/text_generation_server/utils/weights.py: 60 additions & 39 deletions
@@ -130,29 +130,57 @@ def get_sharded(self, tensor_name: str, dim: int):
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)

def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
slice_ = self._get_slice(name)
total_size = slice_.get_shape()[1]
def get_packed_sharded(
self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
) -> torch.Tensor:
"""
Get a shard from a tensor that packs multiple tensors.
When a tensor packs multiple tensors (such as QKV or an up
projection + gate projection), sharding with `get_sharded` is not
safe since it would not split the packed tensors across shards.
This method shards a tensor, such that the packed tensors are
split across shards.
The columns are split in equally sized blocks when blocks is an `int`, or
in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
divide an input with dimensionality `1024` into `[512, 256, 256]`. This is
convenient for e.g. splitting QKV without knowing the storage details of
quantized weights.
"""
slice_ = self._get_slice(tensor_name)
total_size = slice_.get_shape()[dim]
block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

world_size = self.process_group.size()
rank = self.process_group.rank()

weights = []
tensors = []
block_offset = 0
for block_size in block_sizes:
assert (
block_size % world_size == 0
), f"Prepacked qkv cannot be sharded across {world_size} shards"
), f"Prepacked tensor cannot be sharded across {world_size} shards"
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
weights.append(slice_[:, block_offset + start : block_offset + stop])
if dim == 0:
tensor = slice_[block_offset + start : block_offset + stop]
elif dim == 1:
tensor = slice_[:, block_offset + start : block_offset + stop]
else:
raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
tensors.append(tensor)
block_offset += block_size
tensor = torch.cat(tensors, dim=dim)
tensor = tensor.to(device=self.device)

weight = torch.cat(weights, dim=1)
weight = weight.to(device=self.device)
return weight
# Avoid casting quantizer dtypes.
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
tensor = tensor.to(dtype=self.dtype)

return tensor

def get_weights_col_packed_qkv(
self,
@@ -185,16 +213,22 @@ def get_weights_col_packed(
from text_generation_server.layers.gptq import GPTQWeight

try:
qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
qweight = self.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)

gptq_params = self._get_gptq_params()

qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
scales = self._get_qweight(f"{prefix}.scales", block_sizes)
qzeros = self.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
scales = self.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=self.dtype)

if quantize == "gptq" and gptq_params.quant_method == "gptq":
@@ -237,13 +271,17 @@ def get_weights_col_packed(
if quant_method == "gptq":
gptq_params = self._get_gptq_params()
try:
qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
qweight = self.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)

scales = self._get_qweight(f"{prefix}.scales", block_sizes)
scales = self.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
g_idx = self.get_tensor(f"{prefix}.g_idx")
weight = repack_gptq_for_marlin(
qweight=qweight,
@@ -257,34 +295,17 @@ def get_weights_col_packed(
)

else:
B = self._get_qweight(f"{prefix}.B", block_sizes)
s = self._get_qweight(f"{prefix}.s", block_sizes)
B = self.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = self.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
else:
slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0]
block_sizes = _blocks_to_block_sizes(
total_size=total_size, blocks=block_sizes
weight = self.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)

world_size = self.process_group.size()
rank = self.process_group.rank()

tensors = []
block_offset = 0
for block_size in block_sizes:
assert (
block_size % world_size == 0
), f"Prepacked weights cannot be sharded across {world_size} shards"
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
tensor = slice_[block_offset + start : block_offset + stop]
tensors.append(tensor)
block_offset += block_size
weight = torch.cat(tensors, dim=0)
weight = weight.to(device=self.device)
weight = weight.to(dtype=self.dtype)
return weight

def get_weights_col(self, prefix: str, quantize: str):
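
For reference, the block-size computation described in the new docstring can be illustrated with a small self-contained sketch. This mirrors the documented behaviour of `_blocks_to_block_sizes` (whose implementation is not part of this diff) and is not copied from the repository:

from typing import List, Union

def blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    # Illustrative re-implementation of the documented behaviour: an int
    # yields equally sized blocks, a list yields blocks proportional to its
    # entries. Not the repository's helper.
    if isinstance(blocks, int):
        assert total_size % blocks == 0, "blocks must evenly divide the dimension"
        return [total_size // blocks] * blocks
    total_parts = sum(blocks)
    assert total_size % total_parts == 0, "block proportions must evenly divide the dimension"
    unit = total_size // total_parts
    return [part * unit for part in blocks]

# The example from the docstring: [2, 1, 1] divides 1024 into [512, 256, 256].
assert blocks_to_block_sizes(1024, [2, 1, 1]) == [512, 256, 256]
assert blocks_to_block_sizes(1024, 4) == [256, 256, 256, 256]

# Per-rank boundaries within each packed block, mirroring the arithmetic in
# `get_packed_sharded` above (the method additionally offsets each slice by a
# running `block_offset`). The process-group size and rank are hypothetical.
world_size, rank = 2, 1
for block_size in blocks_to_block_sizes(1024, [2, 1, 1]):
    assert block_size % world_size == 0
    shard_block_size = block_size // world_size
    start = rank * shard_block_size
    stop = (rank + 1) * shard_block_size
    print(f"block of {block_size}: this rank takes [{start}:{stop}] within the block")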
