Skip to content

Commit

Permalink
bump to vllm0.6.2 add explicit chat template (kserve#3964)
Browse files Browse the repository at this point in the history
* explicitly give a chat template

Signed-off-by: yxia216 <[email protected]>

* fix dummy model issue, fix python version smaller than 3.10, and formatting

Signed-off-by: yxia216 <[email protected]>

* fix vLLMModel

Signed-off-by: yxia216 <[email protected]>

* change the interface of CreateChatCompletionRequest

Signed-off-by: yxia216 <[email protected]>

* update dummy model's para

Signed-off-by: yxia216 <[email protected]>

* consitent with OpenAIGPTTokenizer and OpenAIGPTModel

Signed-off-by: yxia216 <[email protected]>

* give a chat template if there is no

Signed-off-by: yxia216 <[email protected]>

* update the response and update the readme

Signed-off-by: yxia216 <[email protected]>

* update the chat_template

Signed-off-by: yxia216 <[email protected]>

* update data

Signed-off-by: yxia216 <[email protected]>

* add test of chat temmplate for tokenizer

Signed-off-by: yxia216 <[email protected]>

* jinja2 template format

Signed-off-by: yxia216 <[email protected]>

* use a simpler chat template

---------

Signed-off-by: yxia216 <[email protected]>
  • Loading branch information
hustxiayang authored and bai.lu committed Dec 10, 2024
1 parent dec81bd commit 2ac3565
Show file tree
Hide file tree
Showing 13 changed files with 691 additions and 614 deletions.
2 changes: 1 addition & 1 deletion python/huggingface_server.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ARG POETRY_HOME=/opt/poetry
ARG POETRY_VERSION=1.8.3

# Install vllm
ARG VLLM_VERSION=0.6.1.post2
ARG VLLM_VERSION=0.6.2

RUN apt-get update -y && apt-get install gcc python3.10-venv python3-dev -y && apt-get clean && \
rm -rf /var/lib/apt/lists/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,9 @@ def build_generation_config(
return GenerationConfig(**kwargs)

def apply_chat_template(
self, messages: Iterable[ChatCompletionRequestMessage]
self,
messages: Iterable[ChatCompletionRequestMessage],
chat_template: Optional[str] = None,
) -> ChatPrompt:
"""
Given a list of chat completion messages, convert them to a prompt.
Expand All @@ -394,6 +396,7 @@ def apply_chat_template(
str,
self._tokenizer.apply_chat_template(
[m.model_dump() for m in messages],
chat_template=chat_template,
tokenize=False,
add_generation_prompt=True,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,13 @@ def request_output_to_completion_response(
def apply_chat_template(
self,
messages: Iterable[ChatCompletionRequestMessage,],
chat_template: Optional[str] = None,
):
return self.tokenizer.apply_chat_template(
conversation=messages, tokenize=False, add_generation_prompt=True
conversation=messages,
chat_template=chat_template,
tokenize=False,
add_generation_prompt=True,
)

async def _post_init(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,15 @@ async def healthy(self) -> bool:
def apply_chat_template(
self,
messages: Iterable[ChatCompletionRequestMessage,],
chat_template: Optional[str] = None,
) -> ChatPrompt:
"""
Given a list of chat completion messages, convert them to a prompt.
"""
return ChatPrompt(
prompt=self.openai_serving_completion.apply_chat_template(messages)
prompt=self.openai_serving_completion.apply_chat_template(
messages, chat_template
)
)

async def create_completion(
Expand Down
1,161 changes: 565 additions & 596 deletions python/huggingfaceserver/poetry.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions python/huggingfaceserver/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ packages = [
[tool.poetry.dependencies]
python = ">=3.9,<3.13"
kserve = { path = "../kserve", extras = ["storage"], develop = true }
transformers = "~4.43.3"
transformers = ">=4.45.0"
accelerate = "~0.32.0"
torch = "~2.4.0"
vllm = { version = "^0.6.1.post2", optional = true }
setuptools = {version = "^70.0.0", python = "3.12"} # setuptools is not part of python 3.12
vllm = { version = "^0.6.2", optional = true }
setuptools = {version = ">=70.0.0", python = "3.12"} # setuptools is not part of python 3.12

[tool.poetry.extras]
vllm = [
Expand Down
10 changes: 8 additions & 2 deletions python/huggingfaceserver/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,9 @@ async def test_bloom_chat_completion(bloom_model: HuggingfaceGenerativeModel):
messages=messages,
stream=False,
max_tokens=20,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(params=params, context={})
response = await bloom_model.create_chat_completion(request)
Expand Down Expand Up @@ -416,6 +419,9 @@ async def test_bloom_chat_completion_streaming(bloom_model: HuggingfaceGenerativ
messages=messages,
stream=True,
max_tokens=20,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(params=params, context={})
response = await bloom_model.create_chat_completion(request)
Expand Down Expand Up @@ -498,6 +504,6 @@ async def test_input_padding_with_pad_token_not_specified(
response = await openai_gpt_model.create_completion(request)
assert (
response.choices[0].text
== "west, and the sun sets in the west. \n the sun rises in the"
== "west , and the sun sets in the west . \n the sun rises in the"
)
assert "a member of the royal family." in response.choices[1].text
assert "a member of the royal family ." in response.choices[1].text
75 changes: 74 additions & 1 deletion python/huggingfaceserver/tests/test_vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from huggingfaceserver.vllm.vllm_completions import OpenAIServingCompletion
from huggingfaceserver.vllm.vllm_model import VLLMModel
from kserve.logging import logger
from kserve.protocol.rest.openai import ChatCompletionRequest, CompletionRequest
from kserve.protocol.rest.openai import (
ChatCompletionRequest,
CompletionRequest,
ChatPrompt,
)
from kserve.protocol.rest.openai.errors import OpenAIError
from kserve.protocol.rest.openai.types import (
CreateChatCompletionRequest,
Expand Down Expand Up @@ -98,6 +102,54 @@ def mock_load(self) -> bool:
mp.undo()


def compare_chatprompt_to_expected(actual, expected, fields_to_compare=None) -> bool:
if fields_to_compare is None:
fields_to_compare = [
"response_role",
"prompt",
]
for field in fields_to_compare:
if not getattr(actual, field) == getattr(expected, field):
logger.error(
"expected: %s\n got: %s",
getattr(expected, field),
getattr(actual, field),
)
return False
return True


@pytest.mark.asyncio()
class TestChatTemplate:
async def test_vllm_chat_completion_tokenization_facebook_opt_model(
self, vllm_opt_model
):
opt_model, _ = vllm_opt_model

messages = [
{
"role": "system",
"content": "You are a friendly chatbot who always responds in the style of a pirate",
},
{
"role": "user",
"content": "How many helicopters can a human eat in one sitting?",
},
]
chat_template = (
"{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}"
)
response = opt_model.apply_chat_template(messages, chat_template)

expected = ChatPrompt(
response_role="assistant",
prompt="You are a friendly chatbot who always responds in the style of a pirate</s>How many helicopters can a human eat in one sitting?</s>",
)
assert compare_chatprompt_to_expected(response, expected) is True


def compare_response_to_expected(actual, expected, fields_to_compare=None) -> bool:
if fields_to_compare is None:
fields_to_compare = [
Expand Down Expand Up @@ -160,6 +212,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
messages=messages,
stream=False,
max_tokens=10,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(params=params, context={})
response = await opt_model.create_chat_completion(request)
Expand Down Expand Up @@ -216,6 +271,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
messages=messages,
stream=False,
max_tokens=10,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
Expand Down Expand Up @@ -275,6 +333,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
messages=messages,
stream=True,
max_tokens=10,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
Expand Down Expand Up @@ -325,6 +386,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
max_tokens=10,
log_probs=True,
top_logprobs=2,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
Expand Down Expand Up @@ -635,6 +699,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
max_tokens=10,
log_probs=True,
top_logprobs=2,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
Expand Down Expand Up @@ -890,6 +957,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
messages=messages,
stream=True,
max_tokens=2048,
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
Expand Down Expand Up @@ -926,6 +996,9 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
stream=False,
max_tokens=10,
logit_bias={"1527": 50, "27449": 100},
chat_template="{% for message in messages %}"
"{{ message.content }}{{ eos_token }}"
"{% endfor %}",
)
request = ChatCompletionRequest(
request_id=request_id, params=params, context={}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import abstractmethod
from typing import AsyncIterator, Iterable, Union, cast
from typing import AsyncIterator, Iterable, Union, cast, Optional

from kserve.protocol.rest.openai.types import (
ChatCompletion,
Expand Down Expand Up @@ -53,7 +53,9 @@ class OpenAIChatAdapterModel(OpenAIModel):

@abstractmethod
def apply_chat_template(
self, messages: Iterable[ChatCompletionRequestMessage]
self,
messages: Iterable[ChatCompletionRequestMessage],
chat_template: Optional[str] = None,
) -> ChatPrompt:
"""
Given a list of chat completion messages, convert them to a prompt.
Expand Down Expand Up @@ -193,15 +195,16 @@ def completion_to_chat_completion_chunk(
)

async def create_chat_completion(
self, request: ChatCompletionRequest
self,
request: ChatCompletionRequest,
) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]:
params = request.params

if params.n != 1:
raise InvalidInput("n != 1 is not supported")

# Convert the messages into a prompt
chat_prompt = self.apply_chat_template(params.messages)
chat_prompt = self.apply_chat_template(params.messages, params.chat_template)
# Translate the chat completion request to a completion request
completion_params = self.chat_completion_params_to_completion_params(
params, chat_prompt.prompt
Expand Down
5 changes: 3 additions & 2 deletions python/kserve/kserve/protocol/rest/openai/types/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Steps to generate


```bash
curl https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml -o openapi-2.0.0.yaml
datamodel-codegen --input openapi-2.0.0.yaml --input-file-type openapi --output openapi.py --output-model-type pydantic_v2.BaseModel --use-double-quotes --collapse-root-models --enum-field-as-literal all --strict-nullable
datamodel-codegen --input openapi-2.0.0.yaml --input-file-type openapi --output openapi.py --output-model-type pydantic_v2.BaseModel --use-double-quotes --collapse-root-models --enum-field-as-literal all --strict-nullable```
Adapted from the generated `openapi.py`
13 changes: 13 additions & 0 deletions python/kserve/kserve/protocol/rest/openai/types/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2705,6 +2705,19 @@ class CreateChatCompletionRequest(BaseModel):
max_length=128,
min_length=1,
)
chat_template: Optional[str] = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
)


class RunStepObject(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion python/kserve/test/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncIterator, Callable, Iterable, List, Tuple, Union, cast
from typing import AsyncIterator, Callable, Iterable, List, Tuple, Union, cast, Optional
from unittest.mock import MagicMock, patch

import httpx
Expand Down Expand Up @@ -86,6 +86,7 @@ async def create_completion(
def apply_chat_template(
self,
messages: Iterable[ChatCompletionRequestMessage],
chat_template: Optional[str] = None,
) -> ChatPrompt:
return ChatPrompt(prompt="hello")

Expand Down
3 changes: 2 additions & 1 deletion test/e2e/data/opt_125m_input_generate.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
}
],
"temperature": 0,
"max_tokens": 20
"max_tokens": 20,
"chat_template": "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}"
}

0 comments on commit 2ac3565

Please sign in to comment.