From 6084d417f89782ab1ad8e0ccf2d4d63c32dff6b3 Mon Sep 17 00:00:00 2001
From: Travis Johnson
Date: Wed, 8 May 2024 16:00:13 -0600
Subject: [PATCH] format: make mypy happy (#24)

`format.sh` now has mypy checks after pulling in upstream changes.
This PR makes the mypy suggested modifications to our code.

---------

Signed-off-by: Travis Johnson
---
 vllm/entrypoints/grpc/grpc_server.py  | 23 ++++++++++++-----------
 vllm/entrypoints/openai/api_server.py |  1 -
 vllm/tgis_utils/args.py               |  8 ++++----
 vllm/tgis_utils/logs.py               |  9 ++++-----
 4 files changed, 20 insertions(+), 21 deletions(-)

diff --git a/vllm/entrypoints/grpc/grpc_server.py b/vllm/entrypoints/grpc/grpc_server.py
index ebec0bbcf03f6..58281cc65c5d5 100644
--- a/vllm/entrypoints/grpc/grpc_server.py
+++ b/vllm/entrypoints/grpc/grpc_server.py
@@ -14,7 +14,7 @@
 from vllm import (AsyncLLMEngine, CompletionOutput, RequestOutput,
                   SamplingParams)
 from vllm.config import ModelConfig
-from vllm.entrypoints.grpc.pb import generation_pb2_grpc
+from vllm.entrypoints.grpc.pb import generation_pb2_grpc  # type: ignore
 # yapf: disable
 from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
                                                      BatchedGenerationResponse,
@@ -54,7 +54,7 @@ async def _handle_exception(e: Exception, func, *args, **kwargs):
     if not isinstance(e, AbortError):
         if type(e).__name__ == "torch.cuda.OutOfMemoryError":  #TODO check
             context = kwargs.get("context", None) or args[-1]
-            logger.exception(f"{func.__name__} caused GPU OOM error")
+            logger.exception("%s caused GPU OOM error", func.__name__)
             service_metrics.count_request_failure(FailureReasonLabel.OOM)
             await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e))
         else:
@@ -62,7 +62,7 @@ async def _handle_exception(e: Exception, func, *args, **kwargs):
                 service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
             else:
                 service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
-        logger.exception(f"{func.__name__} failed")
+        logger.exception("%s failed", func.__name__)
     raise e
 
 
@@ -298,7 +298,7 @@ def _convert_output(self,
             text=output.text[text_start_offset:],
             generated_token_count=len(output.token_ids),
             stop_reason=stop_reason,
-            stop_sequence=stop_sequence,
+            stop_sequence=stop_sequence if stop_sequence else '',
         )
 
         if resp_options.generated_tokens:
@@ -416,7 +416,8 @@ async def _validate_and_convert_params(
 
     @staticmethod
    def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
-                        time_limit_reached: bool) -> Tuple['StopReason', str]:
+                        time_limit_reached: bool
+                        ) -> Tuple[StopReason.ValueType, Optional[str]]:
         finish_reason = output.finish_reason
         stop_sequence = None
         if finish_reason is None:
@@ -436,20 +437,20 @@ def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
                 stop_sequence = stop_str_or_tok
             else:
                 logger.warning(
-                    f"Unexpected stop_reason type: {type(stop_str_or_tok)}"
+                    "Unexpected stop_reason type: %s", type(stop_str_or_tok)
                 )
         elif finish_reason == "abort":
             stop_reason = StopReason.CANCELLED
         else:
-            logger.warning(f"Unrecognized finish_reason: {finish_reason}")
+            logger.warning("Unrecognized finish_reason: %s", finish_reason)
             stop_reason = StopReason.CANCELLED
 
         return stop_reason, stop_sequence
 
     def _convert_tokens(
             self,
-            token_ids: list[int],
-            logprobs_list: Optional[list[Dict[int, Logprob]]],
+            token_ids: List[int],
+            logprobs_list: Optional[List[Dict[int, Logprob]]],
             include_logprobs: bool,
             include_ranks: bool,
             top_n_tokens: int,
@@ -502,7 +503,7 @@ async def _validate_prompt_and_tokenize(
         #                    "max_length": truncate_input_tokens} \
         #     if truncate_input_tokens is not None else {
         #     "truncation": True, "max_length": max_model_len + 1}
-        tokenize_kwargs = {}
+        tokenize_kwargs: Dict[str, Any] = {}
 
         input_ids = await self.tokenizer_group.encode_async(
             prompt, **tokenize_kwargs)
@@ -664,6 +665,6 @@ async def start_grpc_server(engine: AsyncLLMEngine,
     server.add_insecure_port(listen_on)
 
     await server.start()
-    logger.info(f"gRPC Server started at {listen_on}")
+    logger.info("gRPC Server started at %s", listen_on)
 
     return server
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 3b9392bc73ccd..5d1baf42bd2d9 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1,4 +1,3 @@
-import argparse
 import asyncio
 import importlib
 import inspect
diff --git a/vllm/tgis_utils/args.py b/vllm/tgis_utils/args.py
index 8c5c558746432..33480e1323db5 100644
--- a/vllm/tgis_utils/args.py
+++ b/vllm/tgis_utils/args.py
@@ -129,10 +129,10 @@ def postprocess_tgis_args(args: argparse.Namespace) -> argparse.Namespace:
     if args.max_batch_size is not None:
         # Existing MAX_BATCH_SIZE settings in TGIS configs may not necessarily
         # be best for vLLM so we'll just log a warning for now
-        logger.warn(
-            f"max_batch_size is set to {args.max_batch_size} but will be "
-            f"ignored for now. max_num_seqs can be used if this is still "
-            f"needed.")
+        logger.warning(
+            "max_batch_size is set to %d but will be ignored for now."
+            "max_num_seqs can be used if this is still needed.",
+            args.max_batch_size)
 
     if args.tls_cert_path:
         args.ssl_certfile = args.tls_cert_path
diff --git a/vllm/tgis_utils/logs.py b/vllm/tgis_utils/logs.py
index 9b3f41bef77aa..6cf81accac508 100644
--- a/vllm/tgis_utils/logs.py
+++ b/vllm/tgis_utils/logs.py
@@ -45,11 +45,10 @@ def log_response(inputs: List[str], params: Parameters, prefix_id: str,
         level = logging.WARN
     else:
         level = logging.INFO
-    logger.log(
-        level, f"{span_str}: {kind_log} generated "
-        f"{response.generated_token_count} tokens before "
-        f"{stop_reason_str}, output {output_len} chars: "
-        f"{short_output}")
+    logger.log(level,
+               "%s: %s generated %d tokens before %s, output %d chars: %s",
+               span_str, kind_log, response.generated_token_count,
+               stop_reason_str, output_len, short_output)
 
 
 def _truncate(text: str, len_: int) -> bytes:
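
Note for reviewers: the sketch below is illustrative only and is not part of the patch. Using a hypothetical worker() function, it shows the two recurring patterns this diff applies: lazy %-style logger calls instead of f-strings, and an explicit annotation on an empty dict so mypy has element types to work with.

    import logging
    from typing import Any, Dict

    logger = logging.getLogger(__name__)


    def worker(name: str) -> None:
        # Lazy %-style formatting defers interpolation until the record is
        # actually emitted, which is what logging lint checks ask for.
        logger.info("%s started", name)

        # A bare `{}` gives mypy nothing to infer element types from, so the
        # annotation states them explicitly (mirrors tokenize_kwargs above).
        kwargs: Dict[str, Any] = {}
        kwargs["truncation"] = True
        logger.info("%s finished with kwargs %s", name, kwargs)


    if __name__ == "__main__":
        logging.basicConfig(level=logging.INFO)
        worker("tokenizer")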