Commit

Fix no speculation.
Narsil committed Dec 5, 2023
1 parent 9bf31fe · commit fdef00c
Showing 1 changed file with 4 additions and 4 deletions.
server/text_generation_server/models/flash_causal_lm.py (4 additions, 4 deletions)
@@ -427,7 +427,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
-        speculative_ids = self.speculative_ids[indices]
+        speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None

        start_slots = torch.tensor(start_slots, dtype=torch.int64)
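
For context, the guard added in filter matters because the batch holds None for speculative_ids when speculation is disabled, and subscripting None raises a TypeError. A minimal sketch of the failure mode and the guarded form (variable names are illustrative, not the exact TGI internals):

import torch

speculative_ids = None          # what the batch carries when speculation is off
indices = torch.tensor([0, 2])  # rows kept by filter()

# Without the guard, this would raise:
#   TypeError: 'NoneType' object is not subscriptable
# speculative_ids[indices]

# The guarded form lets None flow through unchanged:
filtered = speculative_ids[indices] if speculative_ids is not None else None
assert filtered is None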

@@ -483,7 +483,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
-            speculative_length = b.speculative_ids.shape[1]
+            speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
@@ -589,7 +589,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
            device=batches[0].next_token_chooser.device,
        )

-        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
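
The two concatenate changes follow the same pattern: fall back to a speculative length of 0, and skip the torch.cat entirely when the first batch carries no speculative ids, since calling it on a list of None values would fail. A rough sketch of that gate, using a stand-in class rather than the real FlashCausalLMBatch, and assuming (as the gate on the first batch suggests) that speculation is either enabled or disabled for all batches at once:

import torch

class _Batch:  # stand-in carrying only the field relevant here
    def __init__(self, speculative_ids):
        self.speculative_ids = speculative_ids

batches = [_Batch(None), _Batch(None)]  # speculation disabled everywhere

for b in batches:
    # Falls back to 0 instead of calling .shape on None.
    speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0

# torch.cat would reject a list of None, so the cat is gated on the first batch.
speculative_ids = (
    torch.cat([b.speculative_ids for b in batches], dim=0)
    if batches[0].speculative_ids is not None
    else None
)

assert speculative_length == 0
assert speculative_ids is None
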
@@ -980,7 +980,7 @@ def generate_token(

                if stop:
                    stopped = True
-                    left = n_accepted_ids - 1 - j
+                    left = index + n_accepted_ids - j - 1
                    break
                else:
                    stopped = False
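
The last hunk fixes a separate bug: left counts how many accepted tokens remain after the position where generation stopped, and j appears to iterate over absolute positions starting at index (per the surrounding loop), so the old formula was only correct for the first request in the batch. A quick worked check of the two formulas under that assumption:

def left_old(index, n_accepted_ids, j):
    return n_accepted_ids - 1 - j

def left_new(index, n_accepted_ids, j):
    return index + n_accepted_ids - j - 1

# First request in the batch (index == 0): both formulas agree.
assert left_old(0, 3, 1) == left_new(0, 3, 1) == 1

# A later request (index == 5) stopping at j == 6, i.e. its second accepted
# token out of three: the old formula goes negative, the new one reports the
# single remaining accepted token.
assert left_old(5, 3, 6) == -4
assert left_new(5, 3, 6) == 1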
