From 60dc62dc9e53428912953276e0d12a034b353fb6 Mon Sep 17 00:00:00 2001
From: Roy
Date: Mon, 4 Dec 2023 04:59:18 +0800
Subject: [PATCH] add custom server params (#1868)

---
 vllm/entrypoints/openai/api_server.py | 4 ++++
 vllm/entrypoints/openai/protocol.py   | 4 ++++
 vllm/sampling_params.py               | 1 +
 3 files changed, 9 insertions(+)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 39ea750aa9dc1..7b94e1b52a5fd 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -253,8 +253,10 @@ async def create_chat_completion(request: ChatCompletionRequest,
             n=request.n,
             presence_penalty=request.presence_penalty,
             frequency_penalty=request.frequency_penalty,
+            repetition_penalty=request.repetition_penalty,
             temperature=request.temperature,
             top_p=request.top_p,
+            min_p=request.min_p,
             stop=request.stop,
             stop_token_ids=request.stop_token_ids,
             max_tokens=request.max_tokens,
@@ -497,9 +499,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             best_of=request.best_of,
             presence_penalty=request.presence_penalty,
             frequency_penalty=request.frequency_penalty,
+            repetition_penalty=request.repetition_penalty,
             temperature=request.temperature,
             top_p=request.top_p,
             top_k=request.top_k,
+            min_p=request.min_p,
             stop=request.stop,
             stop_token_ids=request.stop_token_ids,
             ignore_eos=request.ignore_eos,
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 2aa567cb87034..7a86a19c4bf80 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -75,6 +75,8 @@ class ChatCompletionRequest(BaseModel):
     spaces_between_special_tokens: Optional[bool] = True
     add_generation_prompt: Optional[bool] = True
     echo: Optional[bool] = False
+    repetition_penalty: Optional[float] = 1.0
+    min_p: Optional[float] = 0.0
 
 
 class CompletionRequest(BaseModel):
@@ -102,6 +104,8 @@ class CompletionRequest(BaseModel):
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
+    repetition_penalty: Optional[float] = 1.0
+    min_p: Optional[float] = 0.0
 
 
 class LogProbs(BaseModel):
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index f33a33d1fd64d..38b7c0b531bd2 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -149,6 +149,7 @@ def __init__(
             # Zero temperature means greedy sampling.
             self.top_p = 1.0
             self.top_k = -1
+            self.min_p = 0.0
             self._verify_greedy_sampling()
 
     def _verify_args(self) -> None:
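
For reference, a minimal client-side sketch of the new parameters once this patch is applied. The server URL, model name, prompt, and other sampling values below are illustrative assumptions; only the repetition_penalty and min_p fields correspond to the request attributes added in this change.

import requests

# Query a locally running vLLM OpenAI-compatible server (URL and model are assumptions).
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",          # placeholder model name
        "prompt": "The capital of France is",
        "max_tokens": 32,
        "temperature": 0.8,
        "repetition_penalty": 1.2,  # added by this patch; values > 1.0 discourage repetition (default 1.0)
        "min_p": 0.05,              # added by this patch; the default 0.0 leaves the filter disabled
    },
)
print(response.json()["choices"][0]["text"])

Because the new fields default to 1.0 and 0.0 in the request models, existing clients that omit them keep the previous behavior.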