Commit 0299c07: explicitly give a chat template
Signed-off-by: yxia216 <[email protected]>
hustxiayang committed Sep 30, 2024
1 parent 6289d40 commit 0299c07
Showing 4 changed files with 22 additions and 12 deletions.
@@ -384,7 +384,7 @@ def build_generation_config(
         return GenerationConfig(**kwargs)
 
     def apply_chat_template(
-        self, messages: Iterable[ChatCompletionRequestMessage]
+        self, messages: Iterable[ChatCompletionRequestMessage], chat_template: Optional[str] = None,
     ) -> ChatPrompt:
         """
         Given a list of chat completion messages, convert them to a prompt.
@@ -394,6 +394,7 @@ def apply_chat_template(
             str,
             self._tokenizer.apply_chat_template(
                 [m.model_dump() for m in messages],
+                chat_template=chat_template,
                 tokenize=False,
                 add_generation_prompt=True,
             ),
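This is the heart of the change: an explicit template string now flows through to the Hugging Face tokenizer, which otherwise falls back to its built-in tokenizer.chat_template. A minimal sketch of that override behavior, assuming transformers is installed and using "facebook/opt-125m" purely as an illustrative tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
messages = [{"role": "user", "content": "Hello"}]
template = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}"
prompt = tokenizer.apply_chat_template(
    messages,
    chat_template=template,  # an explicit template wins over tokenizer.chat_template
    tokenize=False,          # return the rendered string, as the server does
    add_generation_prompt=True,
)
print(prompt)  # "Hello</s>", assuming OPT's </s> eos token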
6 changes: 4 additions & 2 deletions python/huggingfaceserver/tests/test_model.py
@@ -390,7 +390,8 @@ async def test_bloom_chat_completion(bloom_model: HuggingfaceGenerativeModel):
         max_tokens=20,
     )
     request = ChatCompletionRequest(params=params, context={})
-    response = await bloom_model.create_chat_completion(request)
+    chat_template = "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
+    response = await bloom_model.create_chat_completion(request, chat_template)
     assert (
         response.choices[0].message.content
         == "The first thing you need to do is to get a good idea of what you are looking for."
@@ -417,7 +418,8 @@ async def test_bloom_chat_completion_streaming(bloom_model: HuggingfaceGenerativ
         max_tokens=20,
     )
     request = ChatCompletionRequest(params=params, context={})
-    response = await bloom_model.create_chat_completion(request)
+    chat_template = "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
+    response = await bloom_model.create_chat_completion(request, chat_template)
     output = ""
     async for chunk in response:
         output += chunk.choices[0].delta.content
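The template these bloom tests pass is plain Jinja, so its effect can be checked directly; a quick sketch with jinja2 (the "</s>" eos token is illustrative, since the real value comes from the tokenizer):

from jinja2 import Template

template = Template(
    "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}"
)
# Each message's content is concatenated, followed by the eos token.
print(template.render(messages=[{"role": "user", "content": "Hi"}], eos_token="</s>"))
# -> Hi</s>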
21 changes: 14 additions & 7 deletions python/huggingfaceserver/tests/test_vllm_model.py
@@ -162,7 +162,8 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
             max_tokens=10,
         )
         request = ChatCompletionRequest(params=params, context={})
-        response = await opt_model.create_chat_completion(request)
+        chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+        response = await opt_model.create_chat_completion(request, chat_template)
         expected = CreateChatCompletionResponse(
             id=request_id,
             choices=[
@@ -220,7 +221,8 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
         request = ChatCompletionRequest(
             request_id=request_id, params=params, context={}
         )
-        response = await opt_model.create_chat_completion(request)
+        chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+        response = await opt_model.create_chat_completion(request, chat_template)
         expected = CreateChatCompletionResponse(
             id=request_id,
             choices=[
@@ -279,7 +281,8 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
         request = ChatCompletionRequest(
             request_id=request_id, params=params, context={}
         )
-        response_iterator = await opt_model.create_chat_completion(request)
+        chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+        response_iterator = await opt_model.create_chat_completion(request, chat_template)
         completion = ""
         async for resp in response_iterator:
             assert len(resp.choices) == 1
@@ -329,7 +332,8 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
         request = ChatCompletionRequest(
             request_id=request_id, params=params, context={}
         )
-        response = await opt_model.create_chat_completion(request)
+        chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+        response = await opt_model.create_chat_completion(request, chat_template)
         expected = CreateChatCompletionResponse(
             id=request_id,
             choices=[
@@ -639,7 +643,8 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
         request = ChatCompletionRequest(
             request_id=request_id, params=params, context={}
         )
-        response_iterator = await opt_model.create_chat_completion(request)
+        chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+        response_iterator = await opt_model.create_chat_completion(request, chat_template)
         completion = ""
         log_probs = ChatCompletionChoiceLogprobs(
             content=[],
@@ -895,7 +900,8 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
             request_id=request_id, params=params, context={}
         )
         with pytest.raises(OpenAIError):
-            await opt_model.create_chat_completion(request)
+            chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+            await opt_model.create_chat_completion(request, chat_template)
 
     async def test_vllm_chat_completion_facebook_opt_model_with_logit_bias(
         self, vllm_opt_model
@@ -930,7 +936,8 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:
         request = ChatCompletionRequest(
             request_id=request_id, params=params, context={}
         )
-        response = await opt_model.create_chat_completion(request)
+        chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+        response = await opt_model.create_chat_completion(request, chat_template)
         expected = CreateChatCompletionResponse(
             id=request_id,
             choices=[
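Every vLLM test above passes the same ChatML-style template. A hedged jinja2 sketch of the prompt it renders for a single user message (note the stray space before <|im_start|>assistant, produced by the literal space between {% endfor %} and the final {% if %}):

from jinja2 import Template

chatml = Template(
    "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}"
    "{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}"
    "{% endfor %} {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}"
    "{{ '<|im_start|>assistant\n' }}{% endif %}"
)
print(chatml.render(messages=[{"role": "user", "content": "Hi"}], add_generation_prompt=True))
# <|im_start|>user
# Hi<|im_end|>
#  <|im_start|>assistant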
@@ -193,15 +193,15 @@ def completion_to_chat_completion_chunk(
         )
 
     async def create_chat_completion(
-        self, request: ChatCompletionRequest
+        self, request: ChatCompletionRequest, chat_template: str | None = None,
     ) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]:
         params = request.params
 
         if params.n != 1:
             raise InvalidInput("n != 1 is not supported")
 
         # Convert the messages into a prompt
-        chat_prompt = self.apply_chat_template(params.messages)
+        chat_prompt = self.apply_chat_template(params.messages, chat_template)
         # Translate the chat completion request to a completion request
         completion_params = self.chat_completion_params_to_completion_params(
             params, chat_prompt.prompt
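The default of None keeps the change backward compatible: existing callers get the tokenizer's built-in template, while new callers can override it per request. A hedged usage sketch, where model and request are hypothetical stand-ins for a loaded generative model and a ChatCompletionRequest as constructed in the tests above:

async def chat(model, request, template=None):
    # model/request are hypothetical stand-ins; see the tests above for real setup.
    # template=None preserves pre-commit behavior (the tokenizer's own template);
    # an explicit Jinja string overrides it for this request only.
    return await model.create_chat_completion(request, template)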
