Merge squash from allowed_token_ids branch
Signed-off-by: Jefferson Fialho <[email protected]>
fialhocoelho committed Jul 23, 2024
1 parent f25be18 commit c4fc889
Showing 12 changed files with 128 additions and 225 deletions.
2 changes: 1 addition & 1 deletion .buildkite/check-wheel-size.py
@@ -1,7 +1,7 @@
import os
import zipfile

MAX_SIZE_MB = 200
MAX_SIZE_MB = 250


def print_top_10_largest_files(zip_file):
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -6,7 +6,7 @@ numpy < 2.0.0
requests
tqdm
py-cpuinfo
transformers >= 4.42.4 # Required for Gemma 2 and for additional chat template parameters.
transformers >= 4.43.1 # Required for Chameleon and Llama 3.1 hotfix.
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
aiohttp
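
The transformers pin above is what enables both the Chameleon changes and the Llama 3.1 rope fixes further down in this diff. A quick environment check, illustrative only and not part of the commit, could look like this:

import transformers
from packaging.version import Version

# requirements-common.txt now requires transformers >= 4.43.1
assert Version(transformers.__version__) >= Version("4.43.1")
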
53 changes: 27 additions & 26 deletions tests/test_config.py
@@ -64,9 +64,8 @@ def test_get_sliding_window():


def test_rope_customization():
TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
TEST_ROPE_THETA = 16_000_000.0
LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}

llama_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct",
@@ -96,27 +95,29 @@ def test_rope_customization():
None) == TEST_ROPE_THETA
assert llama_model_config.max_model_len == 16384

longchat_model_config = ModelConfig(
"lmsys/longchat-13b-16k",
"lmsys/longchat-13b-16k",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
seed=0,
)
assert getattr(longchat_model_config.hf_config, "rope_scaling",
None) == LONGCHAT_ROPE_SCALING
assert longchat_model_config.max_model_len == 16384

longchat_model_config = ModelConfig(
"lmsys/longchat-13b-16k",
"lmsys/longchat-13b-16k",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
seed=0,
rope_scaling=TEST_ROPE_SCALING,
)
assert getattr(longchat_model_config.hf_config, "rope_scaling",
None) == TEST_ROPE_SCALING
assert longchat_model_config.max_model_len == 4096
# TODO: add these back when the rope configs are fixed
# LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}
# longchat_model_config = ModelConfig(
# "lmsys/longchat-13b-16k",
# "lmsys/longchat-13b-16k",
# tokenizer_mode="auto",
# trust_remote_code=False,
# dtype="float16",
# seed=0,
# )
# assert getattr(longchat_model_config.hf_config, "rope_scaling",
# None) == LONGCHAT_ROPE_SCALING
# assert longchat_model_config.max_model_len == 16384

# longchat_model_config = ModelConfig(
# "lmsys/longchat-13b-16k",
# "lmsys/longchat-13b-16k",
# tokenizer_mode="auto",
# trust_remote_code=False,
# dtype="float16",
# seed=0,
# rope_scaling=TEST_ROPE_SCALING,
# )
# assert getattr(longchat_model_config.hf_config, "rope_scaling",
# None) == TEST_ROPE_SCALING
# assert longchat_model_config.max_model_len == 4096
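
The rewritten test reflects the rope-scaling schema used by newer transformers releases: overrides are keyed by "rope_type" instead of the older "type". Below is a minimal sketch of the same override path, assuming the constructor arguments the test already uses; it is not part of the commit.

from vllm.config import ModelConfig

# Overriding rope scaling on a model config; the "rope_type" key replaces
# the pre-4.43 "type" key used by the old test.
model_config = ModelConfig(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    tokenizer_mode="auto",
    trust_remote_code=False,
    dtype="float16",
    seed=0,
    rope_scaling={"rope_type": "dynamic", "factor": 2.0},
)
assert model_config.hf_config.rope_scaling["factor"] == 2.0
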
75 changes: 75 additions & 0 deletions vllm/entrypoints/openai/logits_processors.py
@@ -0,0 +1,75 @@
from functools import lru_cache
from typing import Dict, FrozenSet, Iterable, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer

from vllm.sampling_params import LogitsProcessor


class AllowedTokenIdsLogitsProcessor:
    """Logits processor for constraining generated tokens to a
    specific set of token ids."""

    def __init__(self, allowed_ids: Iterable[int]):
        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
        self.mask: Optional[torch.Tensor] = None

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        if self.mask is None:
            self.mask = torch.full((logits.shape[-1], ),
                                   True,
                                   dtype=torch.bool,
                                   device=logits.device)
            self.mask[self.allowed_ids] = False
            self.allowed_ids = None
        logits.masked_fill_(self.mask, float("-inf"))
        return logits


@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
    allowed_token_ids: FrozenSet[int],
    vocab_size: int,
) -> LogitsProcessor:
    if not allowed_token_ids:
        raise ValueError("Empty allowed_token_ids provided")
    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains "
                         "out-of-vocab token id")
    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)


def get_logits_processors(
        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
        allowed_token_ids: Optional[List[int]],
        tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
    logits_processors = []
    if logit_bias:
        try:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            clamped_logit_bias: Dict[int, float] = {
                int(token_id): min(100.0, max(-100.0, bias))
                for token_id, bias in logit_bias.items()
            }
        except ValueError as exc:
            raise ValueError(
                "Found token_id in logit_bias that is not "
                "an integer or string representing an integer") from exc

        def logit_bias_logits_processor(token_ids: List[int],
                                        logits: torch.Tensor) -> torch.Tensor:
            for token_id, bias in clamped_logit_bias.items():
                logits[token_id] += bias
            return logits

        logits_processors.append(logit_bias_logits_processor)

    if allowed_token_ids is not None:
        logits_processors.append(
            _get_allowed_token_ids_logits_processor(
                frozenset(allowed_token_ids), tokenizer.vocab_size))

    return logits_processors
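
The new module is exercised end-to-end through the OpenAI-compatible server, but the helpers can also be driven directly. A minimal sketch, not part of the commit: FakeTokenizer is a hypothetical stand-in, which is enough here because get_logits_processors only consults tokenizer.vocab_size.

import torch

from vllm.entrypoints.openai.logits_processors import get_logits_processors


class FakeTokenizer:
    vocab_size = 8  # toy vocabulary, purely for illustration


processors = get_logits_processors(
    logit_bias={"3": 5.0},        # string token ids are accepted; biases are clamped to [-100, 100]
    allowed_token_ids=[2, 3, 5],  # every other token id will be masked to -inf
    tokenizer=FakeTokenizer(),
)

logits = torch.zeros(8)
for processor in processors:
    logits = processor([], logits)  # empty generated-token history

# Token 3 keeps its +5.0 bias; tokens outside {2, 3, 5} are now -inf.
print(logits)
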
61 changes: 17 additions & 44 deletions vllm/entrypoints/openai/protocol.py
@@ -5,9 +5,12 @@

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Annotated

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
@@ -213,30 +216,15 @@ class ChatCompletionRequest(OpenAIBaseModel):

# doc: end-chat-completion-extra-params

def to_sampling_params(self) -> SamplingParams:
def to_sampling_params(self,
tokenizer: PreTrainedTokenizer) -> SamplingParams:
# We now allow logprobs being true without top_logprobs.

logits_processors = None
if self.logit_bias:
logit_bias: Dict[int, float] = {}
try:
for token_id, bias in self.logit_bias.items():
# Convert token_id to integer before we add to LLMEngine
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias[int(token_id)] = min(100, max(-100, bias))
except ValueError as exc:
raise ValueError(f"Found token_id `{token_id}` in logit_bias "
f"but token_id must be an integer or string "
f"representing an integer") from exc

def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits

logits_processors = [logit_bias_logits_processor]
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=None,
tokenizer=tokenizer,
)

return SamplingParams(
n=self.n,
@@ -358,6 +346,7 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
allowed_token_ids: Optional[List[int]] = None
# doc: end-completion-sampling-params

# doc: begin-completion-extra-params
@@ -407,30 +396,14 @@

# doc: end-completion-extra-params

def to_sampling_params(self):
def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
echo_without_generation = self.echo and self.max_tokens == 0

logits_processors = None
if self.logit_bias:
logit_bias: Dict[int, float] = {}
try:
for token_id, bias in self.logit_bias.items():
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias[int(token_id)] = min(100, max(-100, bias))
except ValueError as exc:
raise ValueError(f"Found token_id `{token_id}` in logit_bias "
f"but token_id must be an integer or string "
f"representing an integer") from exc

def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits

logits_processors = [logit_bias_logits_processor]
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
tokenizer=tokenizer,
)

return SamplingParams(
n=self.n,
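
With the new allowed_token_ids field on CompletionRequest, a client can constrain generation through the completions endpoint. A hedged sketch against a locally running vLLM OpenAI-compatible server; the host, port, model name, and token ids below are placeholders, not values taken from the commit.

import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "prompt": "Answer with yes or no:",
        "max_tokens": 1,
        # Restrict sampling to a handful of token ids; the ids below are
        # made up and would normally come from the model's tokenizer.
        "allowed_token_ids": [9891, 912],
    },
)
print(response.json()["choices"][0]["text"])
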
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -132,7 +132,7 @@ async def create_chat_completion(

request_id = f"chat-{random_uuid()}"
try:
sampling_params = request.to_sampling_params()
sampling_params = request.to_sampling_params(tokenizer)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -93,7 +93,7 @@ async def create_completion(self, request: CompletionRequest,

tokenizer = await self.engine.get_tokenizer(lora_request)

sampling_params = request.to_sampling_params()
sampling_params = request.to_sampling_params(tokenizer)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
2 changes: 0 additions & 2 deletions vllm/model_executor/models/__init__.py
@@ -16,8 +16,6 @@
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
#TODO(ywang96): remove this when huggingface fixes the model repo
"ChameleonForCausalLM": ("chameleon", "ChameleonForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
3 changes: 1 addition & 2 deletions vllm/model_executor/models/chameleon.py
@@ -6,6 +6,7 @@
import torch.nn.functional as F
from PIL import Image
from torch import nn
from transformers import ChameleonConfig, ChameleonVQVAEConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
@@ -30,8 +31,6 @@
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.transformers_utils.configs import (ChameleonConfig,
ChameleonVQVAEConfig)
from vllm.utils import print_warning_once

from .interfaces import SupportsVision
9 changes: 4 additions & 5 deletions vllm/transformers_utils/config.py
@@ -5,10 +5,10 @@

from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChameleonConfig, ChatGLMConfig,
DbrxConfig, JAISConfig,
MedusaConfig, MLPSpeculatorConfig,
MPTConfig, RWConfig)
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
JAISConfig, MedusaConfig,
MLPSpeculatorConfig, MPTConfig,
RWConfig)

if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
@@ -18,7 +18,6 @@
logger = init_logger(__name__)

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chameleon": ChameleonConfig,
"chatglm": ChatGLMConfig,
"dbrx": DbrxConfig,
"mpt": MPTConfig,
4 changes: 0 additions & 4 deletions vllm/transformers_utils/configs/__init__.py
@@ -1,5 +1,3 @@
from vllm.transformers_utils.configs.chameleon import (ChameleonConfig,
ChameleonVQVAEConfig)
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.dbrx import DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
@@ -12,8 +10,6 @@
from vllm.transformers_utils.configs.mpt import MPTConfig

__all__ = [
"ChameleonConfig",
"ChameleonVQVAEConfig",
"ChatGLMConfig",
"DbrxConfig",
"MPTConfig",
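
The Chameleon-related edits above all follow from the transformers bump at the top of the diff: with transformers >= 4.43.1 providing ChameleonConfig and ChameleonVQVAEConfig natively, vLLM's vendored copies and the "chameleon" entry in _CONFIG_REGISTRY can be removed. A quick import check, offered as an assumption-labeled sketch rather than part of the commit:

from transformers import ChameleonConfig, ChameleonVQVAEConfig

# Default-construct the upstream configs just to confirm the new import path.
print(ChameleonConfig().model_type)       # "chameleon"
print(ChameleonVQVAEConfig().model_type)  # the VQ-VAE sub-config also resolves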