Commit
Finishing nits + integration test
Narsil committed Sep 25, 2023
1 parent c35f39c commit 8ee9307
Showing 3 changed files with 87 additions and 39 deletions.
61 changes: 61 additions & 0 deletions integration-tests/models/test_flash_awq.py
@@ -0,0 +1,61 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_gptq_handle(launcher):
with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=2, quantize="awq") as handle:
yield handle


@pytest.fixture(scope="module")
async def flash_llama_gptq(flash_llama_gptq_handle):
await flash_llama_gptq_handle.health(300)
return flash_llama_gptq_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
response = await flash_llama_gptq.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)

assert response.details.generated_tokens == 10
assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
response = await flash_llama_gptq.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)

assert response.details.generated_tokens == 10
assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_load(
flash_llama_gptq, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_gptq, "Test request", max_new_tokens=10, n=4
)

assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])

assert responses == response_snapshot
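
For context, a minimal sketch of issuing the same kind of request against an already-running server started with `--quantize awq`, using the `text_generation` Python client. The endpoint URL is a placeholder and not part of the test suite, which instead relies on the project's `launcher`, `response_snapshot`, and `generate_load` fixtures.

from text_generation import Client

# Placeholder endpoint: assumes a local server launched with --quantize awq.
client = Client("http://127.0.0.1:8080")

response = client.generate(
    "Test request",
    max_new_tokens=10,
    decoder_input_details=True,
)
print(response.details.generated_tokens)  # the tests above assert this equals 10
print(response.generated_text)
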
29 changes: 12 additions & 17 deletions server/text_generation_server/utils/awq/quantize/qmodule.py
@@ -6,14 +6,14 @@
import awq_inference_engine # with CUDA kernels


class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)

def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
# class ScaledActivation(nn.Module):
# def __init__(self, module, scales):
# super().__init__()
# self.act = module
# self.scales = nn.Parameter(scales.data)
#
# def forward(self, x):
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)


class WQLinear(nn.Module):
@@ -32,11 +32,11 @@ def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0

self.register_buffer('qweight', qweight)
self.register_buffer('qzeros', qzeros)
self.register_buffer('scales', scales)
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
if bias:
self.register_buffer('bias', bias)
self.bias = bias
else:
self.bias = None

@@ -46,8 +46,3 @@ def forward(self, x):
out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
)
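
The qmodule.py change above replaces `register_buffer` calls with plain attribute assignment. As a minimal, self-contained illustration of the PyTorch semantics involved (not the project's stated rationale): registered buffers appear in `state_dict()` and follow `.to()` / `.cuda()` moves, while plain tensor attributes do not and are expected to already live on the right device.

import torch
import torch.nn as nn


class Buffered(nn.Module):
    def __init__(self, t: torch.Tensor):
        super().__init__()
        # A registered buffer is tracked by the module: serialized in state_dict()
        # and moved along with .to() / .cuda().
        self.register_buffer("qweight", t)


class Plain(nn.Module):
    def __init__(self, t: torch.Tensor):
        super().__init__()
        # A plain attribute is not tracked: absent from state_dict() and left
        # wherever the caller put it.
        self.qweight = t


t = torch.zeros(4, 4, dtype=torch.int32)
print(list(Buffered(t).state_dict()))  # ['qweight']
print(list(Plain(t).state_dict()))     # []
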
36 changes: 14 additions & 22 deletions server/text_generation_server/utils/weights.py
@@ -139,21 +139,16 @@ def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
try:
qweight = self._get_qweight(f"{prefix}.qweight")
except RuntimeError:
if quantize == "gptq":
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
else:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
)
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)

qzeros = self._get_qweight(f"{prefix}.qzeros")
scales = self._get_qweight(f"{prefix}.scales")
scales = scales.to(dtype=self.dtype)
try:
if quantize == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx")
except RuntimeError:
else:
g_idx = None

bits, groupsize = self._get_gptq_params()
@@ -185,14 +180,9 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
if quantize == "gptq":
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
else:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
)
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)

qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
@@ -201,12 +191,12 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)

try:
if quantize == "gptq":
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
except RuntimeError:
else:
g_idx = None

bits, groupsize = self._get_gptq_params()
@@ -233,7 +223,7 @@ def get_tensor_shard(self, var, dim):
return tensor

def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize in "gptq":
if quantize == "gptq":
use_exllama = True
bits, groupsize = self._get_gptq_params()

@@ -311,8 +301,10 @@ def get_multi_weights_row(self, prefix: str, quantize: str):

qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = None
use_exllama = False

weight = (qweight, qzeros, scales, None, bits, groupsize, None)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
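
For readers tracing the weights.py changes, here is a hedged sketch of the seven-field tuple the quantized branch of `get_multi_weights_row` returns; the type alias and `describe` helper below are hypothetical additions for illustration only, not part of the repository.

from typing import Optional, Tuple

import torch

# Field order matches the tuple built in the diff above:
# (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
QuantizedWeight = Tuple[
    torch.Tensor,            # qweight: packed integer weights
    torch.Tensor,            # qzeros: packed zero points
    torch.Tensor,            # scales: per-group dequantization scales
    Optional[torch.Tensor],  # g_idx: GPTQ group index; None for AWQ
    int,                     # bits
    int,                     # groupsize
    bool,                    # use_exllama: False for AWQ, since exllama kernels expect GPTQ packing
]


def describe(weight: QuantizedWeight) -> str:
    # Hypothetical helper: only illustrates unpacking the tuple.
    qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
    kind = "gptq" if g_idx is not None else "awq"
    return f"{kind}: {bits}-bit, groupsize={groupsize}, exllama={use_exllama}"
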
