🔊 Add TGIS response logs (#15)
This PR updates our grpc_server to add TGIS-style logs, similar to those produced in
https://github.com/IBM/text-generation-inference/blob/main/router/src/grpc_server.rs#L504-L512

This also disables vLLM's per-request logging so that we don't log each
request twice.
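
The mechanism (see the `vllm/tgis_utils/args.py` change below) is just to force the existing vLLM flag on during TGIS arg postprocessing; a minimal sketch, with a hypothetical standalone helper name:

```python
import argparse


def force_disable_request_logging(args: argparse.Namespace) -> argparse.Namespace:
    """Force vLLM's per-request logging off so each request is logged exactly
    once, by the TGIS-style response log (mirrors the args.py change below)."""
    if not args.disable_log_requests:
        args.disable_log_requests = True
    return args
```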

The timing info collected here is pretty rough: it doesn't plumb into the
LLMEngine; it just times the result generators to get the total time spent
in the engine. We could do better, but this is a start.
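
Concretely, the timing is just a thin async-generator wrapper around the engine stream plus a small dataclass of timestamps (both shown in the diff below); roughly:

```python
import dataclasses
import time
from typing import AsyncIterator


@dataclasses.dataclass
class Times:
    """Times (in seconds) at which a request starts and finishes."""
    request_start: float      # control enters Generate / GenerateStream
    engine_start: float = 0   # request handed to the vLLM engine
    end: float = 0            # stream from the vLLM engine closes


async def timed_generator(generator: AsyncIterator,
                          times: Times) -> AsyncIterator:
    """Record when the engine stream starts producing and when it finishes."""
    times.engine_start = time.time()
    async for val in generator:
        yield val
    times.end = time.time()
```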

Example logs:

```
INFO 04-09 21:51:01 logs.py:43] generate_stream{input=[b'This is the story of Obama ridin...'] prefix_id= input_chars=[70] params=sampling { } stopping { max_new_tokens: 200 min_new_tokens: 16 } response { } decoding { } tokenization_time=0.45ms queue_and_inference_time=1096.67ms time_per_token=5.48ms total_time=1097.12ms input_toks=16}: Streaming response generated 200 tokens before NOT_FINISHED, output 848 chars: b' California. The story is told i...'
INFO 04-09 21:51:08 logs.py:43] generate{input=[b'Lorem ipsum dolor sit amet, cons...', b'foooood man where is it'] prefix_id= input_chars=[469] params=sampling { } stopping { max_new_tokens: 20 min_new_tokens: 16 } response { } decoding { } tokenization_time=2.03ms queue_and_inference_time=122.23ms time_per_token=6.11ms total_time=124.26ms input_toks=124}: Sub-request 0 from batch of 2 generated 20 tokens before MAX_TOKENS, output 25 chars: b'?\\n\\n<!--\\n<!--\\n<!--\\n<!--\\n<!'
INFO 04-09 21:51:08 logs.py:43] generate{input=[b'Lorem ipsum dolor sit amet, cons...', b'foooood man where is it'] prefix_id= input_chars=[469] params=sampling { } stopping { max_new_tokens: 20 min_new_tokens: 16 } response { } decoding { } tokenization_time=2.07ms queue_and_inference_time=122.22ms time_per_token=6.11ms total_time=124.29ms input_toks=7}: Sub-request 1 from batch of 2 generated 20 tokens before MAX_TOKENS, output 70 chars: b"?\\nI don't know.\\nI don't know.\\nI ..."
```
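
As a sanity check, the timing fields above are simple differences of those timestamps; this runnable sketch reproduces the first example's numbers (the raw stamp values here are back-filled assumptions chosen to match that log):

```python
# Back-filled timestamps (seconds) consistent with the first example log above.
request_start, engine_start, end = 0.0, 0.00045, 1.09712
generated_token_count = 200

tokenization_time = engine_start - request_start           # also covers validation
queue_and_inference_time = end - engine_start
total_time = end - request_start
time_per_token = queue_and_inference_time / generated_token_count

print(f"tokenization_time={tokenization_time * 1e3:.2f}ms "
      f"queue_and_inference_time={queue_and_inference_time * 1e3:.2f}ms "
      f"time_per_token={time_per_token * 1e3:.2f}ms "
      f"total_time={total_time * 1e3:.2f}ms")
# -> tokenization_time=0.45ms queue_and_inference_time=1096.67ms
#    time_per_token=5.48ms total_time=1097.12ms
```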

---------

Signed-off-by: Joe Runde <[email protected]>
Signed-off-by: Joe Runde <[email protected]>
joerunde authored Apr 11, 2024
1 parent 4977313 commit 4ea1722
Showing 3 changed files with 152 additions and 14 deletions.
100 changes: 86 additions & 14 deletions vllm/entrypoints/grpc/grpc_server.py
@@ -1,4 +1,5 @@
import argparse
import dataclasses
import inspect
import time
import uuid
@@ -33,12 +34,23 @@
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.logger import init_logger
from vllm.sequence import Logprob
from vllm.tgis_utils import logs
from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
TypicalLogitsWarperWrapper)
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

logger = init_logger(__name__)

@dataclasses.dataclass
class Times:
"""Container tracking times (in seconds) when requests start and finish """
# When control enters Generate or GenerateStream
request_start: float
# When the request is sent to the vLLM engine
engine_start: float = 0
# When the stream from the vLLM engine closes
end: float = 0


def with_default(value: Any, default: Any) -> Any:
return value if value else default
@@ -99,6 +111,7 @@ async def _post_init(self):
@log_rpc_handler_errors
async def Generate(self, request: BatchedGenerationRequest,
context: ServicerContext) -> BatchedGenerationResponse:
start_time = time.time()
request_id = self.request_id(context)
sampling_params, deadline = await self._validate_and_convert_params(
request.params, context)
@@ -107,16 +120,23 @@ async def Generate(self, request: BatchedGenerationRequest,
request_count = len(request.requests)

generators = []
timing_infos = []
max_is_token_limit = [False] * request_count
for i, req in enumerate(request.requests):
input_ids, max_is_token_limit[i]\
= await self._validate_prompt_and_tokenize(
sampling_params, truncate_input_tokens, req.text, context)
timing_info = Times(request_start=start_time)
timing_infos.append(timing_info)
generators.append(
self.engine.generate(None,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=input_ids))
self.timed_generator(
# prompt is supplied for observability, the text is not
# re-tokenized when `prompt_token_ids` is supplied
self.engine.generate(prompt=req.text,
sampling_params=sampling_params,
request_id=f"{request_id}-{i}",
prompt_token_ids=input_ids),
timing_info))

# TODO handle cancellation
result_generator: AsyncIterator[Tuple[
@@ -140,21 +160,28 @@ async def Generate(self, request: BatchedGenerationRequest,
break

for i, res in enumerate(responses):
# Text prompt is not returned if only token_ids are passed
res.prompt = request.requests[i].text
response = self._convert_output(res.outputs[0], resp_options,
max_is_token_limit[i],
time_limit_reached)
responses[i] = self._convert_input_details(res, resp_options,
response = self._convert_input_details(res, resp_options,
sampling_params,
response)
if request_count == 1:
kind_log = "Request"
else:
kind_log = f"Sub-request {i} from batch of {request_count}"

self._log_unary_response(request=request, response=response,
times=timing_infos[i], kind_log=kind_log)
responses[i] = response

return BatchedGenerationResponse(responses=responses)

@log_rpc_handler_errors
async def GenerateStream(
self, request: SingleGenerationRequest,
context: ServicerContext) -> AsyncIterator[GenerationResponse]:
timing_info = Times(request_start=time.time())
request_id = self.request_id(context)
sampling_params, deadline = await self._validate_and_convert_params(
request.params, context)
@@ -165,24 +192,29 @@ async def GenerateStream(
sampling_params, truncate_input_tokens, request.request.text,
context)

result_generator = self.engine.generate(
prompt=None,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=input_ids,
result_generator = self.timed_generator(
self.engine.generate(
# prompt is supplied for observability, the text is not
# re-tokenized when `prompt_token_ids` is supplied
prompt=request.request.text,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=input_ids,
),
timing_info
)

resp_options = request.params.response

first = True
first_response = None
last_output_length = 0
last_token_count = 0
time_limit_reached = False
full_output = ""
#TODO handle cancellation
async for result in result_generator:
if first:
# Text prompt is not returned if only token_ids are passed
result.prompt = request.request.text
first_response = self._convert_input_details(
result, resp_options, sampling_params,
GenerationResponse())
@@ -204,6 +236,17 @@

last_output_length = len(output.text)
last_token_count = len(output.token_ids)
# Save full output for logging
full_output = output.text

# Edit up the first_response for logging purposes only
if first_response is None:
# We didn't output anything!
return
first_response.text = full_output
first_response.generated_token_count = last_token_count
self._log_streaming_response(request=request, response=first_response,
times=timing_info)

def _convert_input_details(
self, result: RequestOutput, resp_options: ResponseOptions,
@@ -482,6 +525,35 @@ async def _validate_prompt_and_tokenize(

return input_ids, max_is_token_limit

@staticmethod
def _log_unary_response(request: BatchedGenerationRequest,
response: GenerationResponse, times: Times,
kind_log: str):
logs.log_response(inputs=[r.text for r in request.requests],
response=response, params=request.params,
prefix_id=request.prefix_id, times=times,
kind_log=kind_log, method_str="generate",
logger=logger)

@staticmethod
def _log_streaming_response(request: SingleGenerationRequest,
response: GenerationResponse, times: Times):
logs.log_response(inputs=[request.request.text], response=response,
params=request.params, prefix_id=request.prefix_id,
times=times, kind_log="Streaming response",
method_str="generate_stream", logger=logger)


@staticmethod
async def timed_generator(generator: AsyncIterator[RequestOutput],
times: Times) -> AsyncIterator[RequestOutput]:
"""Injects some timing data around each result generator from the
LLMEngine"""
times.engine_start = time.time()
async for val in generator:
yield val
times.end = time.time()

@log_rpc_handler_errors
async def Tokenize(self, request: BatchedTokenizeRequest,
context: ServicerContext) -> BatchedTokenizeResponse:
4 changes: 4 additions & 0 deletions vllm/tgis_utils/args.py
@@ -114,5 +114,9 @@ def postprocess_tgis_args(args: argparse.Namespace) -> argparse.Namespace:
if args.max_logprobs < MAX_TOP_N_TOKENS + 1:
logger.info("Setting max_logprobs to %d", MAX_TOP_N_TOKENS + 1)
args.max_logprobs = MAX_TOP_N_TOKENS + 1
# Turn off vLLM per-request logging because the TGIS server logs each
# response
if not args.disable_log_requests:
args.disable_log_requests = True

return args
62 changes: 62 additions & 0 deletions vllm/tgis_utils/logs.py
@@ -0,0 +1,62 @@
"""Some methods for producing logs similar to TGIS"""
import logging
from typing import List

from google.protobuf import text_format

from vllm.entrypoints.grpc.pb.generation_pb2 import (GenerationResponse,
Parameters, StopReason)


def log_response(inputs: List[str], params: Parameters, prefix_id: str,
response: GenerationResponse, times, kind_log: str,
method_str: str, logger: logging.Logger):
"""Logs responses similar to how the TGIS server does"""
# This time contains both request validation and tokenization
tokenization_time = times.engine_start - times.request_start
llm_engine_time = times.end - times.engine_start
time_per_token = _safe_div(llm_engine_time, response.generated_token_count)
total_time = times.end - times.request_start
output_len = len(response.text)
short_output = _truncate(response.text, 32)
short_input = [_truncate(input_, 32) for input_ in inputs]
input_chars = sum(len(input_) for input_ in inputs)

paramstr = text_format.MessageToString(params, as_one_line=True)
span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
f"input_chars=[{input_chars}] params={paramstr} "
f"tokenization_time={tokenization_time * 1e3:.2f}ms "
f"queue_and_inference_time={llm_engine_time * 1e3:.2f}ms "
f"time_per_token={time_per_token * 1e3:.2f}ms "
f"total_time={total_time * 1e3:.2f}ms "
f"input_toks={response.input_token_count}}}")
stop_reason_str = StopReason.Name(response.stop_reason)

if response.stop_reason == StopReason.ERROR:
level = logging.ERROR
elif response.stop_reason in {
StopReason.CANCELLED, StopReason.TOKEN_LIMIT
}:
level = logging.WARN
else:
level = logging.INFO
logger.log(
level, f"{span_str}: {kind_log} generated "
f"{response.generated_token_count} tokens before "
f"{stop_reason_str}, output {output_len} chars: "
f"{short_output}")


def _truncate(text: str, len_: int) -> bytes:
"""Truncates a string and escapes control characters"""
text = f"{text:.{len_}}..." if len(text) > len_ else text
return text.encode("unicode_escape")


def _safe_div(a: float, b: float, *, default: float = 0.0) -> float:
"""Simple safe division with a default answer for divide-by-zero.
"""
try:
return a / b
except ZeroDivisionError:
return default
