diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 5a4f5be0631..b97866f77d7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -434,8 +434,6 @@ def __init__(self, config, weights):
             weights=weights,
         )
         self.max_past = config.sliding_window
-        if self.max_past is None:
-            raise ValueError("max_past cannot be None")
 
     def forward(
         self,
@@ -454,7 +452,7 @@ def forward(
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        else:
+        elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             max_s = min(self.max_past, max_s)
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 76ebc6b8f67..ff2ed9fdfdd 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -365,9 +365,9 @@ def __init__(self, prefix, config: MixtralConfig, weights):
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
+        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
         self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
-        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()
+        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
 
         self.offsets = None
         self.offsets_block_rows = 0
@@ -467,8 +467,7 @@ def indices_and_padded_bins(self, selected_experts: torch.Tensor):
 
         return indices, bin_ids, bins, padded_bins, tokens_per_expert
 
-    @torch.inference_mode()
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         x: (sequence_length, model_dim)
         gate_logits: (sequence_length, n_experts)
@@ -502,8 +501,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # (top_k * sequence_length + padding, ffn_dim * n_experts)
         x = stk.Matrix(
             topo.size(),
-            self.act(stk.ops.sdd(x, self.w1, topo).data)
-            * stk.ops.sdd(x, self.w3, topo).data,
+            self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
+            * stk.ops.sdd(x, self.w3.t(), topo).data,
             topo.row_indices,
             topo.column_indices,
             topo.offsets,
@@ -534,6 +533,156 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         return x.view(*input_shape)
 
+    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gate_logits: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        # gate_logits: (sequence_length, n_experts)
+        gate_logits = self.gate(x)
+        # all_probs: (sequence_length, n_experts) and upcast for softmax
+        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+        if self.top_k < self.num_experts:
+            _, not_selected_experts = torch.topk(
+                all_probs,
+                self.num_experts - self.top_k,
+                largest=False,
+                sorted=False,
+                dim=1,
+            )
+            # Mask not selected experts
+            all_probs.scatter_(1, not_selected_experts, 0)
+
+        # Re-normalize
+        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+
+        # Expand to [num_experts, sequence_length, model_dim]
+        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
+
+        # Permute to [num_experts, model_dim, ffn_dim]
+        w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
+            0, 2, 1
+        )
+        w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
+            0, 2, 1
+        )
+
+        inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
+
+        out = torch.bmm(
+            inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
+        )
+        # Mask not selected experts
+        out *= weights.t().view(self.num_experts, -1, 1)
+
+        # Sum experts
+        out = out.sum(0)
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if len(x) > 256:
+            return self.sparse_forward(x)
+        # This is faster when there is not a lot of tokens
+        return self.dense_forward(x)
+
+
+class DenseMoE(nn.Module):
+    def __init__(self, prefix, config: MixtralConfig, weights):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size // weights.process_group.size()
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        act = config.hidden_act
+        if "gelu" in act:
+            self.act = lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
+        elif "silu" in act:
+            self.act = torch.nn.functional.silu
+        else:
+            self.act = ACT2FN[act]
+
+        # gating
+        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+        self.w1 = [
+            TensorParallelColumnLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+        self.w3 = [
+            TensorParallelColumnLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+        self.w2 = [
+            TensorParallelRowLinear.load(
+                config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
+            )
+            for i in range(self.num_experts)
+        ]
+
+        self.process_group = weights.process_group
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (sequence_length, model_dim)
+        gate_logits: (sequence_length, n_experts)
+        """
+        # optional reshape
+        input_shape = x.shape
+        x = x.view(-1, input_shape[-1])
+
+        # gate_logits: (sequence_length, n_experts)
+        gate_logits = self.gate(x)
+        # all_probs: (sequence_length, n_experts) and upcast for softmax
+        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+
+        if self.top_k < self.num_experts:
+            _, not_selected_experts = torch.topk(
+                all_probs,
+                self.num_experts - self.top_k,
+                largest=False,
+                sorted=False,
+                dim=1,
+            )
+            # Mask not selected experts
+            all_probs.scatter_(1, not_selected_experts, 0)
+
+        # Re-normalize
+        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
+
+        # Final output tensor
+        out = x.new_zeros(x.shape[0], self.hidden_dim)
+        for i in range(self.num_experts):
+            h = self.act(self.w1[i](x)) * self.w3[i](x)
+            h = self.w2[i](h, reduce=False)
+            # Add expert output to out with masking
+            out += h * weights[:, i].view(-1, 1)
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
 
 class MixtralLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
@@ -543,9 +692,9 @@ def __init__(self, layer_id, config, weights):
         self.self_attn = MixtralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.block_sparse_moe = BlockSparseMoE(
-            f"{prefix}.block_sparse_moe", config, weights
-        )
+
+        moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+        self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -591,9 +740,9 @@ def forward(
             attn_output, res
         )
 
-        block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)
+        moe_output = self.moe(normed_attn_res_output)
 
-        return block_sparse_moe_output, attn_res
+        return moe_output, attn_res
 
 
 class MixtralModel(torch.nn.Module):
@@ -675,8 +824,6 @@ def __init__(self, config, weights):
             weights=weights,
         )
         self.max_past = config.sliding_window
-        if self.max_past is None:
-            raise ValueError("max_past cannot be None")
 
     def forward(
         self,
@@ -695,7 +842,7 @@ def forward(
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        else:
+        elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             max_s = min(self.max_past, max_s)
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 0fad5aa8205..abe07c305f7 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -136,9 +136,9 @@ def from_pb(
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
 
             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
-            needed_blocks = min(
-                math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS
-            )
+            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+            if SLIDING_WINDOW_BLOCKS is not None:
+                needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
             blocks += needed_blocks
             needed_blocks_slots.append((needed_blocks, total_tokens))
 
@@ -152,12 +152,13 @@ def from_pb(
             slot_indices.append(request_slot_indices)
 
             # Create tensor to slice into the kv tensor in prefill
-            request_prefill_cache_indices = torch.arange(
-                cumulative_length + max(0, input_length - SLIDING_WINDOW),
-                cumulative_length + input_length,
-                dtype=torch.int64,
-            )
-            prefill_cache_indices.append(request_prefill_cache_indices)
+            if SLIDING_WINDOW is not None:
+                request_prefill_cache_indices = torch.arange(
+                    cumulative_length + max(0, input_length - SLIDING_WINDOW),
+                    cumulative_length + input_length,
+                    dtype=torch.int64,
+                )
+                prefill_cache_indices.append(request_prefill_cache_indices)
 
             all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
             no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
@@ -209,12 +210,14 @@ def from_pb(
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
-            prefill_cache_indices = torch.cat(prefill_cache_indices)
+            if SLIDING_WINDOW is not None:
+                prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
-            prefill_cache_indices = prefill_cache_indices[0]
+            if SLIDING_WINDOW is not None:
+                prefill_cache_indices = prefill_cache_indices[0]
 
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
@@ -222,7 +225,9 @@ def from_pb(
 
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
-        prefill_cache_indices = prefill_cache_indices.to(device)
+        prefill_cache_indices = (
+            prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
+        )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         input_lengths_tensor = torch.tensor(
             input_lengths, dtype=torch.int32, device=device
@@ -314,8 +319,9 @@ def __init__(
         config.quantize = quantize
 
         # Set context windows
-        SLIDING_WINDOW = config.sliding_window
-        SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
+        if config.sliding_window is not None:
+            SLIDING_WINDOW = config.sliding_window
+            SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
 
         torch.distributed.barrier(group=self.process_group)
 
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 77e2fdb6cf4..011a9382cc7 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -64,8 +64,6 @@
 except ImportError:
     pass
 
-from typing import Optional
-
 HAS_EETQ = False
 try:
     from EETQ import quant_weights, w8_a16_gemm
@@ -489,9 +487,9 @@ def load(cls, config, prefix: str, weights, bias: bool):
             process_group=weights.process_group,
         )
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
-        if self.process_group.size() > 1:
+        if self.process_group.size() > 1 and reduce:
             torch.distributed.all_reduce(out, group=self.process_group)
         return out
 