From fdef00c27e250d58dfd452d958a003f9c8966c7a Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 5 Dec 2023 22:02:22 +0000
Subject: [PATCH] Fix no speculation.

---
 server/text_generation_server/models/flash_causal_lm.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5e12e67c18b..3f9c21b2040 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -427,7 +427,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
-        speculative_ids = self.speculative_ids[indices]
+        speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
@@ -483,7 +483,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
-            speculative_length = b.speculative_ids.shape[1]
+            speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
@@ -589,7 +589,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             device=batches[0].next_token_chooser.device,
         )
 
-        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
 
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
@@ -980,7 +980,7 @@ def generate_token(
 
                 if stop:
                     stopped = True
-                    left = n_accepted_ids - 1 - j
+                    left = index + n_accepted_ids - j - 1
                     break
                 else:
                     stopped = False
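
For reference, the sketch below is a minimal, self-contained illustration of the pattern the patch applies: every use of `speculative_ids` is guarded so that batches produced without speculative decoding, where the field is `None`, can still be filtered and concatenated. It is not the real `FlashCausalLMBatch`; the `MiniBatch` class and its fields are invented for the example, and only the `if ... is not None` guards mirror the diff above.

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class MiniBatch:
    # Hypothetical, stripped-down stand-in for FlashCausalLMBatch:
    # only the fields needed to show the optional speculative_ids handling.
    input_ids: torch.Tensor                   # [batch_size]
    speculative_ids: Optional[torch.Tensor]   # [batch_size, spec_len] or None

    def filter(self, indices: List[int]) -> "MiniBatch":
        # Index speculative_ids only when speculation is enabled,
        # mirroring the guarded indexing in the first hunk.
        return MiniBatch(
            input_ids=self.input_ids[indices],
            speculative_ids=(
                self.speculative_ids[indices]
                if self.speculative_ids is not None
                else None
            ),
        )

    @classmethod
    def concatenate(cls, batches: List["MiniBatch"]) -> "MiniBatch":
        # Skip the concat entirely when speculation is disabled,
        # as the patch does in FlashCausalLMBatch.concatenate.
        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )
        return cls(
            input_ids=torch.cat([b.input_ids for b in batches], dim=0),
            speculative_ids=speculative_ids,
        )


if __name__ == "__main__":
    # With speculation disabled, both paths work without touching None.
    a = MiniBatch(torch.tensor([1, 2, 3]), speculative_ids=None)
    b = MiniBatch(torch.tensor([4, 5]), speculative_ids=None)
    merged = MiniBatch.concatenate([a, b])
    kept = merged.filter([0, 3])
    print(kept.input_ids, kept.speculative_ids)  # tensor([1, 4]) None
```

The final hunk is independent of the `None` guards: it recomputes `left`, the number of accepted tokens remaining after the one that triggered stopping. The corrected expression `index + n_accepted_ids - j - 1` suggests that `j` is an absolute position offset by `index` (e.g. iterating over `range(index, index + n_accepted_ids)`) rather than a 0-based counter, so the offset has to be added back in; this reading is inferred from the diff, not shown in it.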