feat: add quant to mixtral #1337

Merged 1 commit on Dec 12, 2023
@@ -434,8 +434,6 @@ def __init__(self, config, weights):
weights=weights,
)
self.max_past = config.sliding_window
if self.max_past is None:
raise ValueError("max_past cannot be None")

def forward(
self,
@@ -454,7 +452,7 @@ def forward(
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
else:
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(self.max_past, max_s)
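With quantized Mixtral checkpoints config.sliding_window can be None, so the constructor no longer rejects a missing value and the decode-time clamp only runs when a window is actually configured. A minimal sketch of the new control flow, with names mirroring the diff and standing in for the real tensors:

    # Sketch only: max_past is None when the model has no sliding window.
    def clamp_decode_length(max_s, max_past, prefill_cache_indices, slots):
        if prefill_cache_indices is not None:
            # Prefill: slice slots down to the sliding-window region of the kv tensor.
            slots = slots[prefill_cache_indices]
        elif max_past is not None:
            # Decode: paged attention wants the clamped length, flash attention the true one.
            max_s = min(max_past, max_s)
        return max_s, slots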
@@ -365,9 +365,9 @@ def __init__(self, prefix, config: MixtralConfig, weights):
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()
self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)

self.offsets = None
self.offsets_block_rows = 0
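The merged expert weights now stay in their loaded layout of (n_experts * ffn_dim, hidden_dim); the transpose moves to the matmul site in sparse_forward (see the .t() calls below), so the same tensors can also be reshaped per expert by dense_forward. A shape sketch with toy sizes, assuming _load_experts returns that merged layout:

    import torch

    # Toy sizes; per-shard ffn_dim is config.intermediate_size // world_size in the real model.
    n_experts, ffn_dim, hidden_dim = 8, 32, 16
    w1 = torch.empty(n_experts * ffn_dim, hidden_dim)  # merged layout assumed from _load_experts
    # Sparse path: stk.ops.sdd consumes the transposed view.
    assert w1.t().shape == (hidden_dim, n_experts * ffn_dim)
    # Dense path: reshape and permute to one (hidden_dim, ffn_dim) matrix per expert.
    assert w1.view(n_experts, ffn_dim, hidden_dim).permute(0, 2, 1).shape == (
        n_experts,
        hidden_dim,
        ffn_dim,
    )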
@@ -467,8 +467,7 @@ def indices_and_padded_bins(self, selected_experts: torch.Tensor):

return indices, bin_ids, bins, padded_bins, tokens_per_expert

@torch.inference_mode()
def forward(self, x: torch.Tensor) -> torch.Tensor:
def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
@@ -502,8 +501,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
self.act(stk.ops.sdd(x, self.w1, topo).data)
* stk.ops.sdd(x, self.w3, topo).data,
self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
* stk.ops.sdd(x, self.w3.t(), topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
@@ -534,6 +533,156 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

return x.view(*input_shape)

def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])

# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)

if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
all_probs,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
)
# Mask not selected experts
all_probs.scatter_(1, not_selected_experts, 0)

# Re-normalize
weights = all_probs / all_probs.sum(dim=1, keepdim=True)

# Expand to [num_experts, sequence_length, model_dim]
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])

# Permute to [num_experts, model_dim, ffn_dim]
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)
w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)

inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)

out = torch.bmm(
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
)
# Mask not selected experts
out *= weights.t().view(self.num_experts, -1, 1)

# Sum experts
out = out.sum(0)

# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)

return out
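dense_forward runs every expert on every token and relies on the router weights to cancel the unselected ones: rather than gathering the top_k experts, it takes the (num_experts - top_k) smallest probabilities, scatters zeros into them, and renormalizes the rest. That routing step in isolation, as a small sketch with toy sizes:

    import torch

    seq_len, num_experts, top_k = 5, 8, 2
    gate_logits = torch.randn(seq_len, num_experts)
    all_probs = torch.softmax(gate_logits, dim=1, dtype=torch.float)

    # Zero out the (num_experts - top_k) smallest router probabilities per token.
    _, not_selected = torch.topk(
        all_probs, num_experts - top_k, largest=False, sorted=False, dim=1
    )
    all_probs.scatter_(1, not_selected, 0)

    # Re-normalize so the surviving weights of each token sum to 1.
    weights = all_probs / all_probs.sum(dim=1, keepdim=True)
    assert torch.allclose(weights.sum(dim=1), torch.ones(seq_len))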

def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(x) > 256:
return self.sparse_forward(x)
# This is faster when there are only a few tokens
return self.dense_forward(x)
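forward now dispatches on batch size: the megablocks sparse path pays off for large prefill batches, while the dense path is cheaper for short decode batches, with 256 tokens as the crossover chosen here. Since both paths compute the same MoE, a quick parity check is possible; this is a sketch, not part of the PR, and assumes a constructed BlockSparseMoE instance on a CUDA device with megablocks installed:

    import torch

    def check_moe_paths(moe, hidden_dim=16, seq_len=4):
        # Compare the dense and sparse implementations on a tiny random batch.
        x = torch.randn(seq_len, hidden_dim, dtype=torch.float16, device="cuda")
        torch.testing.assert_close(
            moe.dense_forward(x), moe.sparse_forward(x), rtol=1e-2, atol=1e-2
        )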


class DenseMoE(nn.Module):
def __init__(self, prefix, config: MixtralConfig, weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size // weights.process_group.size()
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok

act = config.hidden_act
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
)
elif "silu" in act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[act]

# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

self.w1 = [
TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
)
for i in range(self.num_experts)
]
self.w3 = [
TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
)
for i in range(self.num_experts)
]
self.w2 = [
TensorParallelRowLinear.load(
config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
)
for i in range(self.num_experts)
]

self.process_group = weights.process_group

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])

# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)

if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
all_probs,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
)
# Mask not selected experts
all_probs.scatter_(1, not_selected_experts, 0)

# Re-normalize
weights = all_probs / all_probs.sum(dim=1, keepdim=True)

# Final output tensor
out = x.new_zeros(x.shape[0], self.hidden_dim)
for i in range(self.num_experts):
h = self.act(self.w1[i](x)) * self.w3[i](x)
h = self.w2[i](h, reduce=False)
# Add expert output to out with masking
out += h * weights[:, i].view(-1, 1)

# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)

return out
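DenseMoE is the quantization-friendly variant: each expert keeps its own TensorParallelColumnLinear / TensorParallelRowLinear, so quantized weights load through the regular layer machinery, every expert runs on every token, and the routing weights zero out the unselected contributions. The w2 call passes reduce=False so the cross-GPU all_reduce happens once on the summed output instead of once per expert. The core loop restated as a self-contained sketch, with plain nn.Linear standing in for the tensor-parallel layers (an illustrative assumption):

    import torch
    import torch.nn as nn

    def dense_moe(x, w1, w2, w3, weights, act=nn.functional.silu):
        # x: (seq_len, hidden_dim); w1/w2/w3: per-expert linears; weights: (seq_len, n_experts).
        out = x.new_zeros(x.shape[0], w2[0].out_features)
        for i in range(len(w1)):
            h = act(w1[i](x)) * w3[i](x)                   # gated MLP of expert i
            out += w2[i](h) * weights[:, i].view(-1, 1)    # routing weight is 0 if not selected
        return out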


class MixtralLayer(nn.Module):
def __init__(self, layer_id, config, weights):
@@ -543,9 +692,9 @@ def __init__(self, layer_id, config, weights):
self.self_attn = MixtralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.block_sparse_moe = BlockSparseMoE(
f"{prefix}.block_sparse_moe", config, weights
)

moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
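The layer picks the implementation at load time: with config.quantize unset it keeps the megablocks-backed BlockSparseMoE, and any quantization mode falls back to DenseMoE, presumably because the sparse kernels run on the raw merged weights and cannot use quantized linear layers. A sketch of the dispatch:

    def select_moe_cls(config):
        # Hypothetical helper mirroring the one-liner in the diff.
        return BlockSparseMoE if config.quantize is None else DenseMoE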

self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -591,9 +740,9 @@ def forward(
attn_output, res
)

block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)
moe_output = self.moe(normed_attn_res_output)

return block_sparse_moe_output, attn_res
return moe_output, attn_res


class MixtralModel(torch.nn.Module):
@@ -675,8 +824,6 @@ def __init__(self, config, weights):
weights=weights,
)
self.max_past = config.sliding_window
if self.max_past is None:
raise ValueError("max_past cannot be None")

def forward(
self,
@@ -695,7 +842,7 @@ def forward(
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
else:
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(self.max_past, max_s)
34 changes: 20 additions & 14 deletions server/text_generation_server/models/flash_mistral.py
@@ -136,9 +136,9 @@ def from_pb(
total_tokens = input_length + max_new_tokens - 1 + speculative_length

# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
needed_blocks = min(
math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS
)
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
if SLIDING_WINDOW_BLOCKS is not None:
needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
blocks += needed_blocks

needed_blocks_slots.append((needed_blocks, total_tokens))
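Block accounting in flash_mistral.py now treats the sliding-window cap as optional: the KV-cache block count is ceil(total_tokens / BLOCK_SIZE), clamped to SLIDING_WINDOW_BLOCKS only when a window is configured. A sketch of that computation, with the globals assumed to match the module's:

    import math

    def needed_kv_blocks(total_tokens, block_size, sliding_window_blocks=None):
        # Paged-attention blocks a request needs, capped by the sliding window if one exists.
        blocks = math.ceil(total_tokens / block_size)
        if sliding_window_blocks is not None:
            blocks = min(blocks, sliding_window_blocks)
        return blocks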
@@ -152,12 +152,13 @@ def from_pb(
slot_indices.append(request_slot_indices)

# Create tensor to slice into the kv tensor in prefill
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - SLIDING_WINDOW),
cumulative_length + input_length,
dtype=torch.int64,
)
prefill_cache_indices.append(request_prefill_cache_indices)
if SLIDING_WINDOW is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - SLIDING_WINDOW),
cumulative_length + input_length,
dtype=torch.int64,
)
prefill_cache_indices.append(request_prefill_cache_indices)

all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
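prefill_cache_indices only exists to keep the last SLIDING_WINDOW positions of each prefill in the KV cache, so when no window is configured the batch skips building those index tensors entirely and the model receives None. The per-request index construction under a window, as a sketch:

    import torch

    def window_prefill_indices(cumulative_length, input_length, sliding_window):
        # Positions within the flattened prefill whose KV entries are kept:
        # only the last `sliding_window` tokens of this request.
        start = cumulative_length + max(0, input_length - sliding_window)
        return torch.arange(start, cumulative_length + input_length, dtype=torch.int64)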
@@ -209,20 +210,24 @@ def from_pb(
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
prefill_cache_indices = torch.cat(prefill_cache_indices)
if SLIDING_WINDOW is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
prefill_cache_indices = prefill_cache_indices[0]
if SLIDING_WINDOW is not None:
prefill_cache_indices = prefill_cache_indices[0]

cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)

position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
prefill_cache_indices = prefill_cache_indices.to(device)
prefill_cache_indices = (
prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
@@ -314,8 +319,9 @@ def __init__(
config.quantize = quantize

# Set context windows
SLIDING_WINDOW = config.sliding_window
SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
if config.sliding_window is not None:
SLIDING_WINDOW = config.sliding_window
SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)

torch.distributed.barrier(group=self.process_group)
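SLIDING_WINDOW and SLIDING_WINDOW_BLOCKS are module-level globals read by from_pb, so they now default to None and are only assigned when the config defines a window; every consumer above checks for None. A sketch of the pattern (the BLOCK_SIZE value is an assumption):

    import math

    BLOCK_SIZE = 16            # assumed paged-attention block size
    SLIDING_WINDOW = None      # no window unless the config sets one
    SLIDING_WINDOW_BLOCKS = None

    def set_sliding_window(sliding_window):
        global SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
        if sliding_window is not None:
            SLIDING_WINDOW = sliding_window
            SLIDING_WINDOW_BLOCKS = math.ceil(sliding_window / BLOCK_SIZE)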

6 changes: 2 additions & 4 deletions server/text_generation_server/utils/layers.py
@@ -64,8 +64,6 @@
except ImportError:
pass

from typing import Optional

HAS_EETQ = False
try:
from EETQ import quant_weights, w8_a16_gemm
@@ -489,9 +487,9 @@ def load(cls, config, prefix: str, weights, bias: bool):
process_group=weights.process_group,
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input)
if self.process_group.size() > 1:
if self.process_group.size() > 1 and reduce:
torch.distributed.all_reduce(out, group=self.process_group)
return out
