TGIS metrics (#18)
This PR implements a subset of the metrics from the TGIS image. I tried
to make sure that everything our current ops dashboard relies on is
supported; a rough sketch of how these might be declared follows the list.
These are:

- tgi_tokenize_request_tokens 
- tgi_tokenize_request_input_count 
- tgi_request_input_count 
- tgi_request_failure 
- tgi_request_queue_duration 
- tgi_queue_size 
- tgi_batch_current_size 
- tgi_batch_inference_duration 
- tgi_request_input_length 
- tgi_request_generated_tokens
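
For orientation, the snippet below is a rough sketch of how metrics with these names could be declared with prometheus_client. The actual definitions live in the new vllm/tgis_utils/metrics.py (ServiceMetrics / TGISStatLogger), which is not expanded in this diff, so the metric types, label names, and help strings here are assumptions.

```python
# Rough sketch only: the real declarations are in vllm/tgis_utils/metrics.py
# and may differ in type, labels, buckets, and help text.
from prometheus_client import Counter, Gauge, Histogram

tgi_tokenize_request_tokens = Histogram(
    "tgi_tokenize_request_tokens", "Tokens produced per tokenize request")
tgi_tokenize_request_input_count = Counter(
    "tgi_tokenize_request_input_count", "Inputs across tokenize requests")
tgi_request_input_count = Counter(
    "tgi_request_input_count", "Inputs across generate requests")
tgi_request_failure = Counter(
    "tgi_request_failure", "Failed requests, labelled by reason", ["err"])
tgi_request_queue_duration = Histogram(
    "tgi_request_queue_duration", "Seconds a request spent in the queue")
tgi_queue_size = Gauge(
    "tgi_queue_size", "Requests currently waiting in the engine queue")
tgi_batch_current_size = Gauge(
    "tgi_batch_current_size", "Size of the current running batch")
tgi_batch_inference_duration = Histogram(
    "tgi_batch_inference_duration", "Seconds per batch inference step",
    ["method"])
tgi_request_input_length = Histogram(
    "tgi_request_input_length", "Input length in tokens per request")
tgi_request_generated_tokens = Histogram(
    "tgi_request_generated_tokens", "Generated tokens per request")
```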

---------

Signed-off-by: Joe Runde <[email protected]>
joerunde authored Apr 18, 2024
1 parent 1613074 commit 8c548e4
Showing 3 changed files with 209 additions and 60 deletions.
114 changes: 61 additions & 53 deletions vllm/entrypoints/grpc/grpc_server.py
@@ -1,5 +1,4 @@
import argparse
import dataclasses
import inspect
import time
import uuid
@@ -37,19 +36,12 @@
from vllm.tgis_utils import logs
from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
TypicalLogitsWarperWrapper)
from vllm.tgis_utils.metrics import (FailureReasonLabel, ServiceMetrics,
TGISStatLogger)
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

logger = init_logger(__name__)

@dataclasses.dataclass
class Times:
"""Container tracking times (in seconds) when requests start and finish """
# When control enters Generate or GenerateStream
request_start: float
# When the request is sent to the vLLM engine
engine_start: float = 0
# When the stream from the vLLM engine closes
end: float = 0
service_metrics = ServiceMetrics()


def with_default(value: Any, default: Any) -> Any:
@@ -63,7 +55,13 @@ async def _handle_exception(e: Exception, func, *args, **kwargs):
if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check
context = kwargs.get("context", None) or args[-1]
logger.exception(f"{func.__name__} caused GPU OOM error")
service_metrics.count_request_failure(FailureReasonLabel.OOM)
await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e))
else:
if "generate" in func.__name__.lower():
service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
else:
service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
logger.exception(f"{func.__name__} failed")
raise e
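
The failure paths above suggest a single labelled counter keyed by a failure-reason enum. Below is a rough sketch; the FailureReasonLabel values are inferred from the call sites in this diff (OOM, GENERATE, UNKNOWN, plus VALIDATION further down), and the real ServiceMetrics.count_request_failure in the unshown vllm/tgis_utils/metrics.py may differ.

```python
# Rough sketch: reason labels inferred from the call sites in this diff;
# the real enum and counter live in vllm/tgis_utils/metrics.py.
from enum import Enum

from prometheus_client import Counter


class FailureReasonLabel(str, Enum):
    VALIDATION = "validation"
    GENERATE = "generate"
    OOM = "oom"
    UNKNOWN = "unknown"


_request_failure = Counter("tgi_request_failure",
                           "Failed requests, labelled by reason", ["err"])


def count_request_failure(reason: FailureReasonLabel) -> None:
    # Bump the per-reason failure counter that the ops dashboard graphs
    _request_failure.labels(err=reason.value).inc()
```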

@@ -108,10 +106,20 @@ async def _post_init(self):
self.tokenizer_group = await self.engine.get_tokenizer_group()
self.tokenizer = await self.engine.get_tokenizer()

# Swap in the special TGIS stats logger
vllm_stat_logger = self.engine.engine.stat_logger
tgis_stats_logger = TGISStatLogger(
vllm_stat_logger=vllm_stat_logger,
max_sequence_len=self.config.max_model_len)
# 🌶️🌶️🌶️ sneaky sneak
self.engine.engine.stat_logger = tgis_stats_logger
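
The swap in _post_init replaces the engine's stat logger so that TGIS-named series (queue size, batch size, inference duration, and so on) are emitted alongside vLLM's own stats. A minimal sketch of such a delegating logger is shown below; the wrapped logger's log(stats) interface and the attribute names on stats are assumptions here, and the real TGISStatLogger lives in the unshown metrics module.

```python
# Minimal sketch of a delegating stat logger; the real TGISStatLogger in
# vllm/tgis_utils/metrics.py wires up the full set of tgi_* series.
from prometheus_client import Gauge


class DelegatingTGISStatLogger:

    def __init__(self, vllm_stat_logger, max_sequence_len: int):
        self._wrapped = vllm_stat_logger
        self._max_sequence_len = max_sequence_len
        self._queue_size = Gauge("tgi_queue_size",
                                 "Requests currently waiting")
        self._batch_size = Gauge("tgi_batch_current_size",
                                 "Size of the current running batch")

    def log(self, stats) -> None:
        # Keep vLLM's own periodic logging behaviour intact...
        self._wrapped.log(stats)
        # ...and mirror selected values into TGIS-named gauges
        # (the attribute names on `stats` are assumptions).
        self._queue_size.set(getattr(stats, "num_waiting", 0))
        self._batch_size.set(getattr(stats, "num_running", 0))
```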


@log_rpc_handler_errors
async def Generate(self, request: BatchedGenerationRequest,
context: ServicerContext) -> BatchedGenerationResponse:
start_time = time.time()
service_metrics.count_generate_request(len(request.requests))
request_id = self.request_id(context)
sampling_params, deadline = await self._validate_and_convert_params(
request.params, context)
@@ -120,23 +128,19 @@ async def Generate(self, request: BatchedGenerationRequest,
request_count = len(request.requests)

generators = []
timing_infos = []
max_is_token_limit = [False] * request_count
for i, req in enumerate(request.requests):
input_ids, max_is_token_limit[i]\
= await self._validate_prompt_and_tokenize(
sampling_params, truncate_input_tokens, req.text, context)
timing_info = Times(request_start=start_time)
timing_infos.append(timing_info)
generators.append(
self.timed_generator(
# prompt is supplied for observability, the text is not
# re-tokenized when `prompt_token_ids` is supplied
self.engine.generate(prompt=req.text,
sampling_params=sampling_params,
request_id=f"{request_id}-{i}",
prompt_token_ids=input_ids),
timing_info))
# prompt is supplied for observability, the text is not
# re-tokenized when `prompt_token_ids` is supplied
self.engine.generate(prompt=req.text,
sampling_params=sampling_params,
request_id=f"{request_id}-{i}",
prompt_token_ids=input_ids),
)

# TODO handle cancellation
result_generator: AsyncIterator[Tuple[
@@ -151,6 +155,7 @@ async def Generate(self, request: BatchedGenerationRequest,
# await self.engine.abort(f"{request_id}-{i}")
# return self.create_error_response("Client disconnected")
responses[i] = res
service_metrics.observe_queue_time(res)

if deadline is not None and time.time(
) >= deadline and None not in responses:
@@ -173,7 +178,8 @@ async def Generate(self, request: BatchedGenerationRequest,
kind_log = f"Sub-request {i} from batch of {request_count}"

self._log_unary_response(request=request, response=response,
times=timing_infos[i], kind_log=kind_log)
start_time=start_time, engine_response=res,
kind_log=kind_log)
responses[i] = response

return BatchedGenerationResponse(responses=responses)
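
The service_metrics.observe_queue_time(res) call above presumably reads the queue time that vLLM records on each finished RequestOutput. A rough sketch of what it might do, assuming the tgi_request_queue_duration histogram from the PR description:

```python
# Rough sketch: feed vLLM's recorded queue time into a histogram; the real
# ServiceMetrics.observe_queue_time is in the unshown vllm/tgis_utils/metrics.py.
from prometheus_client import Histogram

from vllm import RequestOutput

_queue_duration = Histogram("tgi_request_queue_duration",
                            "Seconds a request spent in the queue")


def observe_queue_time(engine_response: RequestOutput) -> None:
    metrics = engine_response.metrics
    if metrics is not None and metrics.time_in_queue is not None:
        _queue_duration.observe(metrics.time_in_queue)
```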
@@ -182,7 +188,8 @@ async def Generate(self, request: BatchedGenerationRequest,
async def GenerateStream(
self, request: SingleGenerationRequest,
context: ServicerContext) -> AsyncIterator[GenerationResponse]:
timing_info = Times(request_start=time.time())
start_time = time.time()
service_metrics.count_generate_request()
request_id = self.request_id(context)
sampling_params, deadline = await self._validate_and_convert_params(
request.params, context)
@@ -193,16 +200,13 @@ async def GenerateStream(
sampling_params, truncate_input_tokens, request.request.text,
context)

result_generator = self.timed_generator(
self.engine.generate(
# prompt is supplied for observability, the text is not
# re-tokenized when `prompt_token_ids` is supplied
prompt=request.request.text,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=input_ids,
),
timing_info
result_generator = self.engine.generate(
# prompt is supplied for observability, the text is not
# re-tokenized when `prompt_token_ids` is supplied
prompt=request.request.text,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=input_ids,
)

resp_options = request.params.response
@@ -213,9 +217,12 @@ async def GenerateStream(
last_token_count = 0
time_limit_reached = False
full_output = ""
last_engine_response = None
#TODO handle cancellation
async for result in result_generator:
last_engine_response = result
if first:
service_metrics.observe_queue_time(result)
first_response = self._convert_input_details(
result, resp_options, sampling_params,
GenerationResponse())
@@ -247,7 +254,8 @@ async def GenerateStream(
first_response.text = full_output
first_response.generated_token_count = last_token_count
self._log_streaming_response(request=request, response=first_response,
times=timing_info)
start_time=start_time,
engine_response=last_engine_response)

def _convert_input_details(
self, result: RequestOutput, resp_options: ResponseOptions,
@@ -314,6 +322,7 @@ async def _validate_and_convert_params(
try:
validate_params(params, self.max_max_new_tokens)
except ValueError as tgis_validation_error:
service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
await context.abort(StatusCode.INVALID_ARGUMENT,
str(tgis_validation_error))

@@ -396,6 +405,7 @@ async def _validate_and_convert_params(
except ValueError as vllm_validation_error:
# There may be validation cases caught by vLLM that are not covered
# by the TGIS api validation
service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
await context.abort(StatusCode.INVALID_ARGUMENT,
str(vllm_validation_error))

@@ -528,36 +538,32 @@ async def _validate_prompt_and_tokenize(

@staticmethod
def _log_unary_response(request: BatchedGenerationRequest,
response: GenerationResponse, times: Times,
kind_log: str):
response: GenerationResponse,
engine_response: RequestOutput,
start_time: float, kind_log: str):
logs.log_response(inputs=[r.text for r in request.requests],
response=response, params=request.params,
prefix_id=request.prefix_id, times=times,
kind_log=kind_log, method_str="generate",
logger=logger)
prefix_id=request.prefix_id,
engine_response=engine_response,
start_time=start_time, kind_log=kind_log,
method_str="generate", logger=logger)

@staticmethod
def _log_streaming_response(request: SingleGenerationRequest,
response: GenerationResponse, times: Times):
response: GenerationResponse,
engine_response: RequestOutput,
start_time: float):
logs.log_response(inputs=[request.request.text], response=response,
params=request.params, prefix_id=request.prefix_id,
times=times, kind_log="Streaming response",
engine_response=engine_response,
start_time=start_time, kind_log="Streaming response",
method_str="generate_stream", logger=logger)


@staticmethod
async def timed_generator(generator: AsyncIterator[RequestOutput],
times: Times) -> AsyncIterator[RequestOutput]:
"""Injects some timing data around each result generator from the
LLMEngine"""
times.engine_start = time.time()
async for val in generator:
yield val
times.end = time.time()

@log_rpc_handler_errors
async def Tokenize(self, request: BatchedTokenizeRequest,
context: ServicerContext) -> BatchedTokenizeResponse:
service_metrics.observe_tokenization_request(request)
#TODO implement these
if request.return_offsets:
await context.abort(StatusCode.INVALID_ARGUMENT,
@@ -578,7 +584,9 @@ async def Tokenize(self, request: BatchedTokenizeRequest,
tokens=None if not request.return_tokens else
self.tokenizer.convert_ids_to_tokens(token_ids)))

return BatchedTokenizeResponse(responses=responses)
response = BatchedTokenizeResponse(responses=responses)
service_metrics.observe_tokenization_response(response)
return response
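
The two observe_tokenization_* hooks added to Tokenize plausibly feed the tgi_tokenize_request_* series from the PR description. A rough sketch, assuming the request carries its texts in .requests and each per-input response exposes a token_count field (field names taken from the TGIS protos and the surrounding code, not from the unshown metrics module):

```python
# Rough sketch of the tokenization observers; the real methods live on
# ServiceMetrics in vllm/tgis_utils/metrics.py.
from prometheus_client import Counter, Histogram

_tokenize_input_count = Counter("tgi_tokenize_request_input_count",
                                "Inputs across tokenize requests")
_tokenize_tokens = Histogram("tgi_tokenize_request_tokens",
                             "Tokens produced per tokenize request")


def observe_tokenization_request(request) -> None:
    # One tokenize RPC can carry several texts
    _tokenize_input_count.inc(len(request.requests))


def observe_tokenization_response(response) -> None:
    # Record the token count produced for each input
    for resp in response.responses:
        _tokenize_tokens.observe(resp.token_count)
```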

@log_rpc_handler_errors
async def ModelInfo(self, request: ModelInfoRequest,
19 changes: 12 additions & 7 deletions vllm/tgis_utils/logs.py
@@ -4,19 +4,23 @@

from google.protobuf import text_format

from vllm import RequestOutput
from vllm.entrypoints.grpc.pb.generation_pb2 import (GenerationResponse,
Parameters, StopReason)


def log_response(inputs: List[str], params: Parameters, prefix_id: str,
response: GenerationResponse, times, kind_log: str,
method_str: str, logger: logging.Logger):
response: GenerationResponse, engine_response: RequestOutput,
start_time: float, kind_log: str, method_str: str,
logger: logging.Logger):
"""Logs responses similar to how the TGIS server does"""
# This time contains both request validation and tokenization
tokenization_time = times.engine_start - times.request_start
llm_engine_time = times.end - times.engine_start
time_per_token = _safe_div(llm_engine_time, response.generated_token_count)
total_time = times.end - times.request_start
tokenization_time = engine_response.metrics.arrival_time - start_time
inference_time = (engine_response.metrics.last_token_time -
engine_response.metrics.first_scheduled_time)
queue_time = engine_response.metrics.time_in_queue
time_per_token = _safe_div(inference_time, response.generated_token_count)
total_time = engine_response.metrics.last_token_time - start_time
output_len = len(response.text)
short_output = _truncate(response.text, 32)
short_input = [_truncate(input_, 32) for input_ in inputs]
@@ -26,7 +30,8 @@ def log_response(inputs: List[str], params: Parameters, prefix_id: str,
span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
f"input_chars=[{input_chars}] params={paramstr} "
f"tokenization_time={tokenization_time * 1e3:.2f}ms "
f"queue_and_inference_time={llm_engine_time * 1e3:.2f}ms "
f"queue_time={queue_time * 1e3:.2f}ms "
f"inference_time={inference_time * 1e3:.2f}ms "
f"time_per_token={time_per_token * 1e3:.2f}ms "
f"total_time={total_time * 1e3:.2f}ms "
f"input_toks={response.input_token_count}}}")
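
The rewritten log_response derives all of its timings from the RequestOutput.metrics object that vLLM attaches to each finished request. A small worked example with made-up timestamps, mirroring the arithmetic above:

```python
# Worked example with made-up timestamps (seconds), mirroring log_response.
start_time = 100.00            # when the RPC handler was entered
arrival_time = 100.03          # engine_response.metrics.arrival_time
first_scheduled_time = 100.08  # engine_response.metrics.first_scheduled_time
last_token_time = 100.58       # engine_response.metrics.last_token_time
time_in_queue = 0.05           # engine_response.metrics.time_in_queue
generated_token_count = 25

tokenization_time = arrival_time - start_time                    # ~0.03 s
inference_time = last_token_time - first_scheduled_time          # ~0.50 s
queue_time = time_in_queue                                       # 0.05 s
time_per_token = inference_time / max(generated_token_count, 1)  # ~0.02 s (as _safe_div)
total_time = last_token_time - start_time                        # ~0.58 s
print(f"tokenization_time={tokenization_time * 1e3:.2f}ms "
      f"queue_time={queue_time * 1e3:.2f}ms "
      f"inference_time={inference_time * 1e3:.2f}ms "
      f"time_per_token={time_per_token * 1e3:.2f}ms "
      f"total_time={total_time * 1e3:.2f}ms")
```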