From b84cafae9b40d3446d87c3bd4b877eebb1a54623 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 17 Apr 2024 04:41:12 -0400 Subject: [PATCH] feat: accept list as prompt and use first string (#1702) This PR allows the `CompletionRequest.prompt` to be sent as a string or array of strings. When an array is sent the first value will be used if it's a string; otherwise the according error will be thrown Fixes: https://github.com/huggingface/text-generation-inference/issues/1690 Similar to: https://github.com/vllm-project/vllm/pull/323/files --- clients/python/text_generation/types.py | 21 + docs/source/basic_tutorials/launcher.md | 9 + integration-tests/conftest.py | 27 +- ...t_flash_llama_completion_many_prompts.json | 38 ++ ..._llama_completion_many_prompts_stream.json | 602 ++++++++++++++++++ ..._flash_llama_completion_single_prompt.json | 20 + .../models/test_completion_prompts.py | 109 ++++ launcher/src/main.rs | 6 + router/src/lib.rs | 33 +- router/src/main.rs | 4 + router/src/server.rs | 426 ++++++++++--- 11 files changed, 1188 insertions(+), 107 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json create mode 100644 integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json create mode 100644 integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_single_prompt.json create mode 100644 integration-tests/models/test_completion_prompts.py diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index deb987c533c..cfa2a9ed7b3 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -59,6 +59,17 @@ class ChatCompletionComplete(BaseModel): usage: Optional[Any] = None +class CompletionComplete(BaseModel): + # Index of the chat completion + index: int + # Message associated with the chat completion + text: str + # Log probabilities for the chat completion + logprobs: Optional[Any] + # Reason for completion + finish_reason: str + + class Function(BaseModel): name: Optional[str] arguments: str @@ -104,6 +115,16 @@ class ChatComplete(BaseModel): usage: Any +class Completion(BaseModel): + # Completion details + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[CompletionComplete] + + class ChatRequest(BaseModel): # Model identifier model: str diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index d9b272dbeaa..de7c995dc1c 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -398,6 +398,15 @@ Options: -e, --env Display a lot of information about your runtime environment +``` +## MAX_CLIENT_BATCH_SIZE +```shell + --max-client-batch-size + Control the maximum number of inputs that a client can send in a single request + + [env: MAX_CLIENT_BATCH_SIZE=] + [default: 4] + ``` ## HELP ```shell diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index e8ce0d842d7..cf0f498dae5 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -9,6 +9,7 @@ import math import time import random +import re from docker.errors import NotFound from typing import Optional, List, Dict @@ -26,6 +27,7 @@ ChatComplete, ChatCompletionChunk, ChatCompletionComplete, + Completion, ) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) @@ -69,17 +71,22 @@ def convert_data(data): data = json.loads(data) if 
isinstance(data, Dict) and "choices" in data: choices = data["choices"] - if ( - isinstance(choices, List) - and len(choices) >= 1 - and "delta" in choices[0] - ): - return ChatCompletionChunk(**data) + if isinstance(choices, List) and len(choices) >= 1: + if "delta" in choices[0]: + return ChatCompletionChunk(**data) + if "text" in choices[0]: + return Completion(**data) return ChatComplete(**data) if isinstance(data, Dict): return Response(**data) if isinstance(data, List): + if ( + len(data) > 0 + and "object" in data[0] + and data[0]["object"] == "text_completion" + ): + return [Completion(**d) for d in data] return [Response(**d) for d in data] raise NotImplementedError @@ -161,6 +168,9 @@ def eq_details(details: Details, other: Details) -> bool: ) ) + def eq_completion(response: Completion, other: Completion) -> bool: + return response.choices[0].text == other.choices[0].text + def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool: return ( response.choices[0].message.content == other.choices[0].message.content @@ -184,6 +194,11 @@ def eq_response(response: Response, other: Response) -> bool: if not isinstance(snapshot_data, List): snapshot_data = [snapshot_data] + if isinstance(serialized_data[0], Completion): + return len(snapshot_data) == len(serialized_data) and all( + [eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)] + ) + if isinstance(serialized_data[0], ChatComplete): return len(snapshot_data) == len(serialized_data) and all( [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)] diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json new file mode 100644 index 00000000000..3fe0a4826d5 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json @@ -0,0 +1,38 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 1, + "logprobs": null, + "text": " PR for more information?" 
+ }, + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "text": "le Business Incubator is providing a workspace" + }, + { + "finish_reason": "length", + "index": 2, + "logprobs": null, + "text": " severely flawed and often has a substandard" + }, + { + "finish_reason": "length", + "index": 3, + "logprobs": null, + "text": "hd20220811-" + } + ], + "created": 1713284455, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native", + "usage": { + "completion_tokens": 36, + "prompt_tokens": 8, + "total_tokens": 44 + } +} diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json new file mode 100644 index 00000000000..702d48f4106 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json @@ -0,0 +1,602 @@ +[ + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "hd" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "aho" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "2" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "2" + } + 
], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": "2" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "ima" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "." + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "." + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": "." + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "\n" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": " Sarah" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": " Yes" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " And" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "i" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "'" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "," + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " what" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": 
"text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "'" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": "s" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": " Moh" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " is" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": "m" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": " Room" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "s" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " the" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": " tired" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 0, + "logprobs": null, + "text": ":" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": "'" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " capital" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": " of" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + 
"finish_reason": "", + "index": 0, + "logprobs": null, + "text": " She" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 1, + "logprobs": null, + "text": " scale" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 2, + "logprobs": null, + "text": " of" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + }, + { + "choices": [ + { + "finish_reason": "", + "index": 3, + "logprobs": null, + "text": " being" + } + ], + "created": 1713284431, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native" + } +] diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_single_prompt.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_single_prompt.json new file mode 100644 index 00000000000..8b001a09412 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_single_prompt.json @@ -0,0 +1,20 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "text": " PR for flake8" + } + ], + "created": 1713284454, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "2.0.0-native", + "usage": { + "completion_tokens": 5, + "prompt_tokens": 6, + "total_tokens": 11 + } +} diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py new file mode 100644 index 00000000000..cafa8ea6847 --- /dev/null +++ b/integration-tests/models/test_completion_prompts.py @@ -0,0 +1,109 @@ +import pytest +import requests +import json +from aiohttp import ClientSession + +from text_generation.types import ( + Completion, +) + + +@pytest.fixture(scope="module") +def flash_llama_completion_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_completion(flash_llama_completion_handle): + await flash_llama_completion_handle.health(300) + return flash_llama_completion_handle.client + + +# NOTE: since `v1/completions` is a deprecated inferface/endpoint we do not provide a convience +# method for it. Instead, we use the `requests` library to make the HTTP request directly. 
+ + +def test_flash_llama_completion_single_prompt( + flash_llama_completion, response_snapshot +): + response = requests.post( + f"{flash_llama_completion.base_url}/v1/completions", + json={ + "model": "tgi", + "prompt": "Say this is a test", + "max_tokens": 5, + "seed": 0, + }, + headers=flash_llama_completion.headers, + stream=False, + ) + response = response.json() + assert len(response["choices"]) == 1 + + assert response == response_snapshot + + +def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot): + response = requests.post( + f"{flash_llama_completion.base_url}/v1/completions", + json={ + "model": "tgi", + "prompt": ["Say", "this", "is", "a"], + "max_tokens": 10, + "seed": 0, + }, + headers=flash_llama_completion.headers, + stream=False, + ) + response = response.json() + assert len(response["choices"]) == 4 + + all_indexes = [choice["index"] for choice in response["choices"]] + all_indexes.sort() + assert all_indexes == [0, 1, 2, 3] + + assert response == response_snapshot + + +async def test_flash_llama_completion_many_prompts_stream( + flash_llama_completion, response_snapshot +): + request = { + "model": "tgi", + "prompt": [ + "What color is the sky?", + "Is water wet?", + "What is the capital of France?", + "def mai", + ], + "max_tokens": 10, + "seed": 0, + "stream": True, + } + + url = f"{flash_llama_completion.base_url}/v1/completions" + + chunks = [] + async with ClientSession(headers=flash_llama_completion.headers) as session: + async with session.post(url, json=request) as response: + # iterate over the stream + async for chunk in response.content.iter_any(): + # remove "data:" + chunk = chunk.decode().split("\n\n") + # remove "data:" if present + chunk = [c.replace("data:", "") for c in chunk] + # remove empty strings + chunk = [c for c in chunk if c] + # parse json + chunk = [json.loads(c) for c in chunk] + + for c in chunk: + chunks.append(Completion(**c)) + assert "choices" in c + assert 0 <= c["choices"][0]["index"] <= 4 + + assert response.status == 200 + assert chunks == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 2bbfc017eae..f79dbd28646 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -416,6 +416,10 @@ struct Args { /// Display a lot of information about your runtime environment #[clap(long, short, action)] env: bool, + + /// Control the maximum number of inputs that a client can send in a single request + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, } #[derive(Debug)] @@ -1088,6 +1092,8 @@ fn spawn_webserver( // Start webserver tracing::info!("Starting Webserver"); let mut router_args = vec![ + "--max-client-batch-size".to_string(), + args.max_client_batch_size.to_string(), "--max-concurrent-requests".to_string(), args.max_concurrent_requests.to_string(), "--max-best-of".to_string(), diff --git a/router/src/lib.rs b/router/src/lib.rs index ddb28848162..2395e3e29a9 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -155,6 +155,8 @@ pub struct Info { pub max_batch_size: Option, #[schema(example = "2")] pub validation_workers: usize, + #[schema(example = "32")] + pub max_client_batch_size: usize, /// Router Info #[schema(example = "0.5.0")] pub version: &'static str, @@ -280,6 +282,34 @@ fn default_parameters() -> GenerateParameters { } } +mod prompt_serde { + use serde::{self, Deserialize, Deserializer}; + use serde_json::Value; + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let value 
= Value::deserialize(deserializer)?; + match value { + Value::String(s) => Ok(vec![s]), + Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom( + "Empty array detected. Do not use an empty array for the prompt.", + )), + Value::Array(arr) => arr + .iter() + .map(|v| match v { + Value::String(s) => Ok(s.to_owned()), + _ => Err(serde::de::Error::custom("Expected a string")), + }) + .collect(), + _ => Err(serde::de::Error::custom( + "Expected a string or an array of strings", + )), + } + } +} + #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] pub struct CompletionRequest { /// UNUSED @@ -289,7 +319,8 @@ pub struct CompletionRequest { /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] - pub prompt: String, + #[serde(deserialize_with = "prompt_serde::deserialize")] + pub prompt: Vec, /// The maximum number of tokens that can be generated in the chat completion. #[serde(default)] diff --git a/router/src/main.rs b/router/src/main.rs index a224dd4a8b8..0869f41f9a5 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -81,6 +81,8 @@ struct Args { messages_api_enabled: bool, #[clap(long, env, default_value_t = false)] disable_grammar_support: bool, + #[clap(default_value = "4", long, env)] + max_client_batch_size: usize, } #[tokio::main] @@ -115,6 +117,7 @@ async fn main() -> Result<(), RouterError> { ngrok_edge, messages_api_enabled, disable_grammar_support, + max_client_batch_size, } = args; // Launch Tokio runtime @@ -422,6 +425,7 @@ async fn main() -> Result<(), RouterError> { tokenizer_config, messages_api_enabled, disable_grammar_support, + max_client_batch_size, ) .await?; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index f92028da3ca..79c48f4a5d3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -18,6 +18,7 @@ use crate::{ CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, }; use crate::{FunctionDefinition, ToolCall, ToolType}; +use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; @@ -25,8 +26,8 @@ use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::middleware::OtelAxumLayer; -use futures::stream::FuturesUnordered; use futures::stream::StreamExt; +use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::Stream; use futures::TryStreamExt; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; @@ -37,7 +38,9 @@ use std::sync::atomic::AtomicBool; use std::sync::Arc; use text_generation_client::{ShardInfo, ShardedClient}; use tokenizers::Tokenizer; +use tokio::select; use tokio::signal; +use tokio::sync::oneshot; use tokio::time::Instant; use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::{info_span, instrument, Instrument}; @@ -163,6 +166,15 @@ async fn generate( Json(req): Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let span = tracing::Span::current(); + generate_internal(infer, ComputeType(compute_type), Json(req), span).await +} + +async fn generate_internal( + infer: Extension, + ComputeType(compute_type): ComputeType, + Json(req): Json, + span: tracing::Span, +) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let start_time = Instant::now(); metrics::increment_counter!("tgi_request_count"); @@ -361,12 +373,13 @@ async fn generate_stream( HeaderMap, Sse>>, ) { + let span = 
tracing::Span::current(); let on_message_callback = |stream_token: StreamResponse| { let event = Event::default(); event.json_data(stream_token).unwrap() }; let (headers, response_stream) = - generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await; + generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await; let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); (headers, sse) } @@ -376,8 +389,8 @@ async fn generate_stream_internal( ComputeType(compute_type): ComputeType, Json(req): Json, on_message_callback: impl Fn(StreamResponse) -> Event, + span: tracing::Span, ) -> (HeaderMap, impl Stream>) { - let span = tracing::Span::current(); let start_time = Instant::now(); metrics::increment_counter!("tgi_request_count"); @@ -583,6 +596,7 @@ async fn completions( Extension(info): Extension, Json(req): Json, ) -> Result)> { + let span = tracing::Span::current(); metrics::increment_counter!("tgi_request_count"); let stream = req.stream; @@ -602,100 +616,299 @@ async fn completions( )); } - // build the request passing some parameters - let generate_request = GenerateRequest { - inputs: req.prompt.to_string(), - parameters: GenerateParameters { - best_of: None, - temperature: req.temperature, - repetition_penalty: req.repetition_penalty, - frequency_penalty: req.frequency_penalty, - top_k: None, - top_p: req.top_p, - typical_p: None, - do_sample: true, - max_new_tokens, - return_full_text: None, - stop: Vec::new(), - truncate: None, - watermark: false, - details: true, - decoder_input_details: !stream, - seed, - top_n_tokens: None, - grammar: None, - }, - }; + if req.prompt.len() > info.max_client_batch_size { + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + return Err(( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: format!( + "Number of prompts exceeds the maximum allowed batch size of {}", + info.max_client_batch_size + ), + error_type: "batch size exceeded".to_string(), + }), + )); + } + + let generate_requests: Vec = req + .prompt + .iter() + .map(|prompt| GenerateRequest { + inputs: prompt.to_string(), + parameters: GenerateParameters { + best_of: None, + temperature: req.temperature, + repetition_penalty: req.repetition_penalty, + frequency_penalty: req.frequency_penalty, + top_k: None, + top_p: req.top_p, + typical_p: None, + do_sample: true, + max_new_tokens, + return_full_text: None, + stop: Vec::new(), + truncate: None, + watermark: false, + details: true, + decoder_input_details: !stream, + seed, + top_n_tokens: None, + grammar: None, + }, + }) + .collect(); + + let mut x_compute_type = None; + let mut x_compute_characters = 0u32; + let mut x_accel_buffering = None; if stream { - let on_message_callback = move |stream_token: StreamResponse| { - let event = Event::default(); + let mut response_streams = FuturesOrdered::new(); + for (index, generate_request) in generate_requests.into_iter().enumerate() { + let model_id = info.model_id.clone(); + let system_fingerprint = + format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); + let infer_clone = infer.clone(); + let compute_type_clone = compute_type.clone(); + let span_clone = span.clone(); + + // Create a future for each generate_stream_internal call. 
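+            // For each prompt, the callback below tags every streamed token with that prompt's
+            // index, and `generate_stream_internal` runs in its own spawned task: the response
+            // headers are sent back once over a oneshot channel, while the SSE events are
+            // forwarded over an unbounded mpsc channel. The per-prompt receivers are merged
+            // further down into a single SSE stream returned to the client.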
+ let generate_future = async move { + let on_message_callback = move |stream_token: StreamResponse| { + let event = Event::default(); + + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + + event + .json_data(CompletionCompleteChunk { + id: "".to_string(), + object: "text_completion".to_string(), + created: current_time, + + choices: vec![CompletionComplete { + finish_reason: "".to_string(), + index: index as u32, + logprobs: None, + text: stream_token.token.text, + }], + + model: model_id.clone(), + system_fingerprint: system_fingerprint.clone(), + }) + .map_or_else(|_e| Event::default(), |data| data) + }; + + let (header_tx, header_rx) = oneshot::channel(); + let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel(); + + tokio::spawn(async move { + let (header_map, sse) = generate_stream_internal( + infer_clone.clone(), + compute_type_clone.clone(), + Json(generate_request), + on_message_callback, + span_clone.clone(), + ) + .await; - let current_time = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_else(|_| std::time::Duration::from_secs(0)) - .as_secs(); + // send and dont wait for response + let _ = header_tx.send(header_map); - event - .json_data(CompletionCompleteChunk { - id: "".to_string(), - object: "text_completion".to_string(), - created: current_time, - - choices: vec![CompletionComplete { - finish_reason: "".to_string(), - index: 0, - logprobs: None, - text: stream_token.token.text, - }], - - model: info.model_id.clone(), - system_fingerprint: format!( - "{}-{}", - info.version, - info.docker_label.unwrap_or("native") - ), - }) - .map_or_else( - |e| { - println!("Failed to serialize CompletionCompleteChunk: {:?}", e); - Event::default() - }, - |data| data, + // pin an emit messages to the sse_tx + let mut sse = Box::pin(sse); + while let Some(event) = sse.next().await { + if sse_tx.send(event).is_err() { + tracing::error!("Failed to send event. 
Receiver dropped."); + break; + } + } + }); + + (header_rx, sse_rx) + }; + response_streams.push_back(generate_future); + } + + let mut all_rxs = vec![]; + + while let Some((header_rx, sse_rx)) = response_streams.next().await { + all_rxs.push(sse_rx); + + // get the headers from the first response of each stream + let headers = header_rx.await.map_err(|e| { + tracing::error!("Failed to get headers: {:?}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to get headers".to_string(), + error_type: "headers".to_string(), + }), ) - }; + })?; + if x_compute_type.is_none() { + x_compute_type = headers + .get("x-compute-type") + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()); + + x_accel_buffering = headers + .get("x-accel-buffering") + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()); + } + x_compute_characters += headers + .get("x-compute-characters") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.parse().ok()) + .unwrap_or(0); + } - let (headers, response_stream) = generate_stream_internal( - infer, - compute_type, - Json(generate_request), - on_message_callback, - ) - .await; + let mut headers = HeaderMap::new(); + if let Some(x_compute_type) = x_compute_type { + headers.insert("x-compute-type", x_compute_type.parse().unwrap()); + } + headers.insert("x-compute-characters", x_compute_characters.into()); + if let Some(x_accel_buffering) = x_accel_buffering { + headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap()); + } - let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); + // now sink the sse streams into a single stream and remove the ones that are done + let stream: AsyncStream, _> = async_stream::stream! { + loop { + let mut i = 0; + while i < all_rxs.len() { + let rx = &mut all_rxs[i]; + select! 
{ + Some(event) = rx.recv() => { + yield event; + } + else => { + all_rxs.remove(i); + continue; // skip the increment to handle the next element at the same index + } + } + i += 1; // only increment when no element was removed + } + + if all_rxs.is_empty() { + break; + } + } + }; + + let sse = Sse::new(stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { - let (headers, Json(generation)) = generate( - Extension(infer), - Extension(compute_type), - Json(generate_request), - ) - .await?; - let current_time = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); - let details = generation.details.ok_or(( - // this should never happen but handle if details are missing unexpectedly - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "No details in generation".to_string(), - error_type: "no details".to_string(), - }), - ))?; + let responses = FuturesUnordered::new(); + for (index, generate_request) in generate_requests.into_iter().enumerate() { + let infer_clone = infer.clone(); + let compute_type_clone = compute_type.clone(); + let span_clone = span.clone(); + let response_future = async move { + let result = generate_internal( + Extension(infer_clone), + compute_type_clone, + Json(generate_request), + span_clone, + ) + .await; + result.map(|(headers, generation)| (index, headers, generation)) + }; + responses.push(response_future); + } + let generate_responses = responses.try_collect::>().await?; + + let mut prompt_tokens = 0u32; + let mut completion_tokens = 0u32; + let mut total_tokens = 0u32; + + let mut x_compute_time = 0u32; + let mut x_total_time = 0u32; + let mut x_validation_time = 0u32; + let mut x_queue_time = 0u32; + let mut x_inference_time = 0u32; + let mut x_time_per_token = 0u32; + let mut x_prompt_tokens = 0u32; + let mut x_generated_tokens = 0u32; + + let choices = generate_responses + .into_iter() + .map(|(index, headers, Json(generation))| { + let details = generation.details.ok_or(( + // this should never happen but handle if details are missing unexpectedly + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "No details in generation".to_string(), + error_type: "no details".to_string(), + }), + ))?; + + if x_compute_type.is_none() { + x_compute_type = headers + .get("x-compute-type") + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()); + } + + // accumulate headers and usage from each response + x_compute_time += headers + .get("x-compute-time") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_compute_characters += headers + .get("x-compute-characters") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_total_time += headers + .get("x-total-time") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_validation_time += headers + .get("x-validation-time") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_queue_time += headers + .get("x-queue-time") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_inference_time += headers + .get("x-inference-time") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_time_per_token += headers + .get("x-time-per-token") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_prompt_tokens += headers + .get("x-prompt-tokens") + .and_then(|v| v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + x_generated_tokens += headers + .get("x-generated-tokens") + .and_then(|v| 
v.to_str().ok()?.parse().ok()) + .unwrap_or(0); + + prompt_tokens += details.prefill.len() as u32; + completion_tokens += details.generated_tokens; + total_tokens += details.prefill.len() as u32 + details.generated_tokens; + + Ok(CompletionComplete { + finish_reason: details.finish_reason.to_string(), + index: index as u32, + logprobs: None, + text: generation.generated_text, + }) + }) + .collect::, _>>() + .map_err(|(status, Json(err))| (status, Json(err)))?; let response = Completion { id: "".to_string(), @@ -707,19 +920,30 @@ async fn completions( info.version, info.docker_label.unwrap_or("native") ), - choices: vec![CompletionComplete { - finish_reason: details.finish_reason.to_string(), - index: 0, - logprobs: None, - text: generation.generated_text, - }], + choices, usage: Usage { - prompt_tokens: details.prefill.len() as u32, - completion_tokens: details.generated_tokens, - total_tokens: details.prefill.len() as u32 + details.generated_tokens, + prompt_tokens, + completion_tokens, + total_tokens, }, }; + // headers similar to `generate` but aggregated + let mut headers = HeaderMap::new(); + if let Some(x_compute_type) = x_compute_type { + headers.insert("x-compute-type", x_compute_type.parse().unwrap()); + } + headers.insert("x-compute-characters", x_compute_characters.into()); + headers.insert("x-total-time", x_total_time.into()); + headers.insert("x-validation-time", x_validation_time.into()); + headers.insert("x-queue-time", x_queue_time.into()); + headers.insert("x-inference-time", x_inference_time.into()); + headers.insert("x-time-per-token", x_time_per_token.into()); + headers.insert("x-prompt-tokens", x_prompt_tokens.into()); + headers.insert("x-generated-tokens", x_generated_tokens.into()); + if let Some(x_accel_buffering) = x_accel_buffering { + headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap()); + } Ok((headers, Json(response)).into_response()) } } @@ -764,6 +988,7 @@ async fn chat_completions( Extension(info): Extension, Json(req): Json, ) -> Result)> { + let span = tracing::Span::current(); metrics::increment_counter!("tgi_request_count"); let ChatRequest { @@ -901,17 +1126,14 @@ async fn chat_completions( compute_type, Json(generate_request), on_message_callback, + span, ) .await; let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { - let (headers, Json(generation)) = generate( - Extension(infer), - Extension(compute_type), - Json(generate_request), - ) - .await?; + let (headers, Json(generation)) = + generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?; let current_time = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -1008,6 +1230,7 @@ async fn vertex_compatibility( Extension(compute_type): Extension, Json(req): Json, ) -> Result)> { + let span = tracing::Span::current(); metrics::increment_counter!("tgi_request_count"); // check that theres at least one instance @@ -1039,10 +1262,11 @@ async fn vertex_compatibility( }; async { - generate( + generate_internal( Extension(infer.clone()), - Extension(compute_type.clone()), + compute_type.clone(), Json(generate_request), + span.clone(), ) .await .map(|(_, Json(generation))| generation.generated_text) @@ -1154,6 +1378,7 @@ pub async fn run( tokenizer_config: HubTokenizerConfig, messages_api_enabled: bool, grammar_support: bool, + max_client_batch_size: usize, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -1330,6 +1555,7 @@ pub async fn run( 
max_waiting_tokens, max_batch_size, validation_workers, + max_client_batch_size, version: env!("CARGO_PKG_VERSION"), sha: option_env!("VERGEN_GIT_SHA"), docker_label: option_env!("DOCKER_LABEL"),