
fixed issue with adding sampled token to logprobs
abf149 committed Nov 15, 2024
1 parent e5ef93d commit 01b9552
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions vllm/v1/core/scheduler.py
@@ -3,8 +3,6 @@
 from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
                     Tuple, Union)
 
-import torch
-
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
@@ -499,10 +497,15 @@ def update_from_output(
                 logprob_values[max_logprobs:-1] = (
                     [float('-inf')] *
                     (len(logprob_values) - 1 - max_logprobs))
-                logprob_values, indices = torch.sort(logprob_values,
-                                                     dim=-1)
-                logprob_token_ids = torch.gather(
-                    logprob_token_ids, 1, indices)
+
+                indices = sorted(range(len(logprob_values)),
+                                 key=lambda k: logprob_values[k],
+                                 reverse=True)
+                logprob_values = [logprob_values[i] for i in indices]
+                logprob_token_ids = [
+                    logprob_token_ids[i] for i in indices
+                ]
+
                 # There will be one more logprob than the user requested
                 logprob_cnt = max_logprobs + 1
 
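For reference, the following is a minimal, self-contained sketch of the pure-Python reordering introduced in the hunk above. The max_logprobs setting, the example logprob and token-id lists, and the convention that the sampled token's logprob sits in the last slot are illustrative assumptions, not values taken from the scheduler:

# Minimal sketch of the reordering above (illustrative values only).
max_logprobs = 2  # number of logprobs the user requested

# Candidate logprobs for one position; by assumption here, the sampled
# token's logprob has been appended as the final element.
logprob_values = [-0.5, -1.2, -2.3, -4.0, -3.1]
logprob_token_ids = [11, 42, 7, 99, 314]

# Mask everything beyond the top `max_logprobs`, keeping the sampled
# token (last element) intact, as in the hunk above.
logprob_values[max_logprobs:-1] = (
    [float('-inf')] * (len(logprob_values) - 1 - max_logprobs))

# Sort descending by value and reorder the token ids to match.
indices = sorted(range(len(logprob_values)),
                 key=lambda k: logprob_values[k],
                 reverse=True)
logprob_values = [logprob_values[i] for i in indices]
logprob_token_ids = [logprob_token_ids[i] for i in indices]

# One more logprob than requested: the sampled token rides along.
logprob_cnt = max_logprobs + 1
print(logprob_values[:logprob_cnt])     # [-0.5, -1.2, -3.1]
print(logprob_token_ids[:logprob_cnt])  # [11, 42, 314]

Sorting in plain Python replaces the earlier torch.sort / torch.gather calls, which is also why the module-level torch import could be dropped in the first hunk.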