From e25fee57c2e69161bd261f5986dc5aeb198bbd42 Mon Sep 17 00:00:00 2001
From: Maximilien de Bayser
Date: Fri, 23 Aug 2024 10:12:44 -0300
Subject: [PATCH] [BugFix] Fix server crash on empty prompt (#7746)

Signed-off-by: Max de Bayser
---
 .../entrypoints/llm/test_prompt_validation.py |  9 ++++++++
 .../openai/test_prompt_validation.py          | 22 +++++++++++++++++++
 vllm/engine/llm_engine.py                     |  8 +++++++
 3 files changed, 39 insertions(+)
 create mode 100644 tests/entrypoints/llm/test_prompt_validation.py
 create mode 100644 tests/entrypoints/openai/test_prompt_validation.py

diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py
new file mode 100644
index 0000000000000..565dfa01346cc
--- /dev/null
+++ b/tests/entrypoints/llm/test_prompt_validation.py
@@ -0,0 +1,9 @@
+import pytest
+
+from vllm import LLM
+
+
+def test_empty_prompt():
+    llm = LLM(model="gpt2")
+    with pytest.raises(ValueError, match='Prompt cannot be empty'):
+        llm.generate([""])
diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py
new file mode 100644
index 0000000000000..0a573a0066d32
--- /dev/null
+++ b/tests/entrypoints/openai/test_prompt_validation.py
@@ -0,0 +1,22 @@
+# imports for prompt validation tests
+import re
+
+import openai
+import pytest
+
+from ...utils import RemoteOpenAIServer
+
+
+@pytest.mark.asyncio
+async def test_empty_prompt():
+    model_name = "gpt2"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+
+        with pytest.raises(openai.BadRequestError,
+                           match=re.compile('.+Prompt cannot be empty.+')):
+            await client.completions.create(model=model_name,
+                                            prompt="",
+                                            max_tokens=5,
+                                            temperature=0.0)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index f72902c372181..8c98b64181d06 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -591,6 +591,7 @@ def _add_processed_request(
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
     ) -> None:
+        self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
@@ -1647,3 +1648,10 @@ def is_encoder_decoder_model(self):
 
     def is_embedding_model(self):
         return self.model_config.is_embedding_model
+
+    def _validate_model_inputs(self, inputs: Union[LLMInputs,
+                                                   EncoderDecoderLLMInputs]):
+        prompt_key = "encoder_prompt_token_ids" \
+            if self.is_encoder_decoder_model() else "prompt_token_ids"
+        if not inputs.get(prompt_key):
+            raise ValueError("Prompt cannot be empty")
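
Note: a minimal sketch of how the fix surfaces to an OpenAI-compatible
client once this patch is applied. The base URL, API key, and served model
name are illustrative assumptions, not part of the patch:

    import openai

    # Point the official openai client at a locally running vLLM server
    # (assumed here to be serving "gpt2" at the default port).
    client = openai.OpenAI(base_url="http://localhost:8000/v1",
                           api_key="EMPTY")

    try:
        client.completions.create(model="gpt2", prompt="", max_tokens=5)
    except openai.BadRequestError as err:
        # With the patch, an empty prompt is rejected with HTTP 400 and a
        # "Prompt cannot be empty" message instead of crashing the server.
        print(err)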