support deploy info

Jintao-Huang committed Aug 4, 2024
1 parent fec2685 commit 0d20b2d
Showing 6 changed files with 114 additions and 21 deletions.
2 changes: 2 additions & 0 deletions docs/source/LLM/命令行参数.md
@@ -397,6 +397,8 @@ deploy parameters inherit from the infer parameters; in addition, the following parameters are added:
- `--api_key`: Defaults to `None`, i.e. requests are not subjected to api_key verification.
- `--ssl_keyfile`: Defaults to `None`.
- `--ssl_certfile`: Defaults to `None`.
- `--verbose`: Whether to print the request content. Defaults to `True`.
- `--log_interval`: The interval, in seconds, at which statistics are printed. Defaults to `0`, meaning statistics are not printed.

## web-ui Parameters

2 changes: 2 additions & 0 deletions docs/source_en/LLM/Command-line-parameters.md
@@ -398,6 +398,8 @@ deploy parameters inherit from infer parameters, with the following added parame
- `--api_key`: Defaults to `None`, meaning the request will not be subjected to api_key verification.
- `--ssl_keyfile`: Default is `None`.
- `--ssl_certfile`: Default is `None`.
- `--verbose`: Whether to print the request content. Defaults to `True`.
- `--log_interval`: The interval, in seconds, at which statistics are printed. Defaults to `0`, meaning no statistics are printed.

## web-ui Parameters

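A minimal sketch of how the two new options might be set when starting the server programmatically. The model name is a placeholder, and the top-level `swift.llm` re-exports and the `deploy_main` entry point are assumptions rather than part of this diff; on the command line the same settings would look something like `--verbose false --log_interval 20`.

```python
# Sketch only: both new options are plain DeployArguments fields added in this
# commit; 'qwen2-7b-instruct' is a placeholder model_type, and `deploy_main`
# is assumed to be the programmatic deploy entry point.
from swift.llm import DeployArguments, deploy_main

args = DeployArguments(
    model_type='qwen2-7b-instruct',  # placeholder
    verbose=False,     # do not log request_info for every incoming request
    log_interval=20,   # log aggregated throughput statistics every 20 seconds
)
deploy_main(args)
```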
70 changes: 64 additions & 6 deletions swift/llm/deploy.py
@@ -28,7 +28,47 @@

logger = get_logger()

app = FastAPI()
global_stats = {}
default_global_stats = {
'num_prompt_tokens': 0,
'num_generated_tokens': 0,
'num_samples': 0,
'runtime': 0.,
'samples/s': 0.,
'tokens/s': 0.
}


async def _log_stats_hook(log_interval: int):
global global_stats
while True:
global_stats = default_global_stats.copy()
t = time.perf_counter()
await asyncio.sleep(log_interval)
runtime = time.perf_counter() - t
global_stats['runtime'] = runtime
global_stats['samples/s'] = global_stats['num_samples'] / runtime
global_stats['tokens/s'] = global_stats['num_generated_tokens'] / runtime
for k, v in global_stats.items():
global_stats[k] = round(v, 8)
logger.info(global_stats)


def _update_stats(response) -> None:
usage_info = response.usage
global_stats['num_prompt_tokens'] += usage_info.prompt_tokens
global_stats['num_generated_tokens'] += usage_info.completion_tokens
global_stats['num_samples'] += 1


async def lifespan(app: FastAPI):
global _args
if _args.log_interval > 0:
asyncio.create_task(_log_stats_hook(_args.log_interval))
yield


app = FastAPI(lifespan=lifespan)
_args: Optional[DeployArguments] = None
model = None
llm_engine = None
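The lifespan hook above resets a module-level stats dict, sleeps for `log_interval` seconds, and then derives `samples/s` and `tokens/s` from whatever the request handlers accumulated in that window via `_update_stats`. A self-contained sketch of the same pattern, independent of swift and with a fake workload standing in for the request handlers:

```python
import asyncio
import time

# Shared counters that "handlers" bump and a background task turns into rates.
stats = {'num_samples': 0, 'num_generated_tokens': 0}


async def log_stats(log_interval: float) -> None:
    while True:
        for k in stats:
            stats[k] = 0  # start a fresh measurement window
        start = time.perf_counter()
        await asyncio.sleep(log_interval)
        runtime = time.perf_counter() - start
        print({
            **stats,
            'samples/s': round(stats['num_samples'] / runtime, 8),
            'tokens/s': round(stats['num_generated_tokens'] / runtime, 8),
        })


async def fake_handler() -> None:
    # Stands in for _update_stats() being called after each finished request.
    await asyncio.sleep(0.1)
    stats['num_samples'] += 1
    stats['num_generated_tokens'] += 32


async def main() -> None:
    logger_task = asyncio.create_task(log_stats(log_interval=2.0))
    for _ in range(50):
        await fake_handler()
    logger_task.cancel()


asyncio.run(main())
```

As in the commit, the window is reset at the top of each cycle, so the logged rates describe only the most recent `log_interval` seconds rather than a lifetime average.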
@@ -216,7 +256,8 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
generation_config.stop.append(token_str)
request_info['generation_config'] = generation_config
request_info.update({'seed': request.seed, 'stream': request.stream})
logger.info(request_info)
if _args.verbose:
logger.info(request_info)

generate_kwargs = {}
if _args.vllm_enable_lora and request.model != _args.model_type:
@@ -284,11 +325,14 @@ async def _generate_full():
choices.append(choice)
response = CompletionResponse(
model=request.model, choices=choices, usage=usage_info, id=request_id, created=created_time)
if _args.log_interval > 0:
_update_stats(response)
return response

async def _generate_stream():
print_idx_list = [[0] for _ in range(request.n)]
total_res = ['' for _ in range(request.n)]
response = None
async for result in result_generator:
num_prompt_tokens = len(result.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in result.outputs)
@@ -328,6 +372,8 @@ async def _generate_stream():
response = CompletionStreamResponse(
model=request.model, choices=choices, usage=usage_info, id=request_id, created=created_time)
yield f'data:{json.dumps(asdict(response), ensure_ascii=False)}\n\n'
if _args.log_interval > 0 and response is not None:
_update_stats(response)
yield 'data:[DONE]\n\n'

if request.stream:
@@ -368,7 +414,8 @@ async def inference_lmdeploy_async(request: Union[ChatCompletionRequest, Complet
generation_config = LmdeployGenerationConfig(**kwargs)
request_info['generation_config'] = generation_config
request_info.update({'seed': request.seed, 'stream': request.stream})
logger.info(request_info)
if _args.verbose:
logger.info(request_info)

generator = await llm_engine.get_generator(False, created_time)
images = inputs.pop('images', None) or []
@@ -418,6 +465,8 @@ async def _generate_full():
)]
response = CompletionResponse(
model=request.model, choices=choices, usage=usage_info, id=request_id, created=created_time)
if _args.log_interval > 0:
_update_stats(response)
return response

async def _generate_stream():
@@ -428,6 +477,7 @@ async def _generate_stream():
async_iter = generator.async_stream_infer(
session_id=created_time, **inputs, stream_output=True, gen_config=generation_config).__aiter__()
is_finished = False
response = None
while not is_finished:
try:
output = await async_iter.__anext__()
@@ -469,6 +519,8 @@ async def _generate_stream():
response = CompletionStreamResponse(
model=request.model, choices=choices, usage=usage_info, id=request_id, created=created_time)
yield f'data:{json.dumps(asdict(response), ensure_ascii=False)}\n\n'
if _args.log_interval > 0 and response is not None:
_update_stats(response)
yield 'data:[DONE]\n\n'

if request.stream:
@@ -540,7 +592,8 @@ async def inference_pt_async(request: Union[ChatCompletionRequest, CompletionReq
request_info['generation_config'] = generation_config
stop = (_args.stop_words or []) + (getattr(request, 'stop') or [])
request_info.update({'seed': request.seed, 'stop': stop, 'stream': request.stream})
logger.info(request_info)
if _args.verbose:
logger.info(request_info)

adapter_kwargs = {}
if _args.lora_request_list is not None:
@@ -597,6 +650,8 @@ async def _generate_full():
)]
response = CompletionResponse(
model=request.model, choices=choices, usage=usage_info, id=request_id, created=created_time)
if _args.log_interval > 0:
_update_stats(response)
return response

def _generate_stream():
@@ -613,6 +668,7 @@ def _generate_stream():
print_idx = 0
response = ''
is_finished = False
resp = None
while not is_finished:
try:
response, _ = next(gen)
@@ -651,6 +707,8 @@ def _generate_stream():
resp = CompletionStreamResponse(
model=request.model, choices=choices, usage=usage_info, id=request_id, created=created_time)
yield f'data:{json.dumps(asdict(resp), ensure_ascii=False)}\n\n'
if _args.log_interval > 0 and resp is not None:
_update_stats(resp)
yield 'data:[DONE]\n\n'

if request.stream:
@@ -660,7 +718,7 @@ def _generate_stream():


@app.post('/v1/chat/completions')
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request) -> ChatCompletionResponse:
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
global _args
assert _args is not None
if request.stop is None:
@@ -674,7 +732,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re


@app.post('/v1/completions')
async def create_completion(request: CompletionRequest, raw_request: Request) -> CompletionResponse:
async def create_completion(request: CompletionRequest, raw_request: Request):
global _args
assert _args is not None
if request.stop is None:
4 changes: 2 additions & 2 deletions swift/llm/utils/__init__.py
@@ -1,8 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .argument import (AppUIArguments, DeployArguments, EvalArguments, ExportArguments, InferArguments, PtArguments,
RLHFArguments, RomeArguments, SftArguments, WebuiArguments, is_adapter, swift_to_peft_format)
from .client_utils import (compat_openai, convert_to_base64, decode_base64, get_model_list_client, inference_client,
inference_client_async)
from .client_utils import (compat_openai, convert_to_base64, decode_base64, get_model_list_client,
get_model_list_client_async, inference_client, inference_client_async)
from .dataset import (DATASET_MAPPING, DatasetName, HfDataset, get_dataset, get_dataset_from_repo,
load_dataset_from_local, load_ms_dataset, register_dataset, register_dataset_info,
register_local_dataset, sample_dataset)
2 changes: 2 additions & 0 deletions swift/llm/utils/argument.py
@@ -1409,6 +1409,8 @@ class DeployArguments(InferArguments):
ssl_certfile: Optional[str] = None

owned_by: str = 'swift'
verbose: bool = True # Whether to log request_info
log_interval: int = 0 # Interval for printing global statistics

def __post_init__(self):
super().__post_init__()
55 changes: 42 additions & 13 deletions swift/llm/utils/client_utils.py
@@ -34,6 +34,21 @@ def get_model_list_client(host: str = '127.0.0.1', port: str = '8000', api_key:
return from_dict(ModelList, resp_obj)


async def get_model_list_client_async(host: str = '127.0.0.1',
port: str = '8000',
api_key: str = 'EMPTY',
**kwargs) -> ModelList:
url = kwargs.pop('url', None)
if url is None:
url = f'http://{host}:{port}/v1'
url = url.rstrip('/')
url = f'{url}/models'
async with aiohttp.ClientSession() as session:
async with session.get(url, **_get_request_kwargs(api_key)) as resp:
resp_obj = await resp.json()
return from_dict(ModelList, resp_obj)


def _parse_stream_data(data: bytes) -> Optional[str]:
data = data.decode(encoding='utf-8')
data = data.strip()
@@ -177,27 +192,27 @@ def _pre_inference_client(model_type: str,
tools: Optional[List[Dict[str, Union[str, Dict]]]] = None,
tool_choice: Optional[Union[str, Dict]] = 'auto',
*,
model_list: Optional[ModelList] = None,
is_chat_request: Optional[bool] = None,
is_multimodal: Optional[bool] = None,
request_config: Optional[XRequestConfig] = None,
host: str = '127.0.0.1',
port: str = '8000',
api_key: str = 'EMPTY',
**kwargs) -> Tuple[str, Dict[str, Any], bool]:
if images is None:
images = []
model_list = get_model_list_client(host, port, **kwargs)
for model in model_list.data:
if model_type == model.id:
_is_chat = model.is_chat
is_multimodal = model.is_multimodal
break
else:
raise ValueError(f'model_type: {model_type}, model_list: {[model.id for model in model_list.data]}')

if is_chat_request is None:
is_chat_request = _is_chat
assert is_chat_request is not None, (
'Please set the `is_chat_request` parameter to indicate whether the model is a chat model.')
if model_list is not None:
for model in model_list.data:
if model_type == model.id:
if is_chat_request is None:
is_chat_request = model.is_chat
if is_multimodal is None:
is_multimodal = model.is_multimodal
break
else:
raise ValueError(f'model_type: {model_type}, model_list: {[model.id for model in model_list.data]}')
assert is_chat_request is not None and is_multimodal is not None
data = {k: v for k, v in request_config.__dict__.items() if not k.startswith('__')}
url = kwargs.pop('url', None)
if url is None:
@@ -238,6 +253,7 @@ def inference_client(
tool_choice: Optional[Union[str, Dict]] = 'auto',
*,
is_chat_request: Optional[bool] = None,
is_multimodal: Optional[bool] = None,
request_config: Optional[XRequestConfig] = None,
host: str = '127.0.0.1',
port: str = '8000',
@@ -247,6 +263,10 @@
Iterator[CompletionStreamResponse]]:
if request_config is None:
request_config = XRequestConfig()
model_list = None
if is_chat_request is None or is_multimodal is None:
model_list = get_model_list_client(host, port, **kwargs)

url, data, is_chat_request = _pre_inference_client(
model_type,
query,
@@ -255,7 +275,9 @@
images,
tools,
tool_choice,
model_list=model_list,
is_chat_request=is_chat_request,
is_multimodal=is_multimodal,
request_config=request_config,
host=host,
port=port,
@@ -302,6 +324,7 @@ async def inference_client_async(
tool_choice: Optional[Union[str, Dict]] = 'auto',
*,
is_chat_request: Optional[bool] = None,
is_multimodal: Optional[bool] = None,
request_config: Optional[XRequestConfig] = None,
host: str = '127.0.0.1',
port: str = '8000',
@@ -311,6 +334,10 @@
AsyncIterator[CompletionStreamResponse]]:
if request_config is None:
request_config = XRequestConfig()
model_list = None
if is_chat_request is None or is_multimodal is None:
model_list = await get_model_list_client_async(host, port, **kwargs)

url, data, is_chat_request = _pre_inference_client(
model_type,
query,
@@ -319,7 +346,9 @@
images,
tools,
tool_choice,
model_list=model_list,
is_chat_request=is_chat_request,
is_multimodal=is_multimodal,
request_config=request_config,
host=host,
port=port,
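With these changes the `/v1/models` round trip happens only when the caller leaves `is_chat_request` or `is_multimodal` unset, and the async path uses the new `get_model_list_client_async` helper instead of the blocking one. A hedged sketch of a caller taking advantage of that; the model_type and query are placeholders, and the top-level `swift.llm` re-exports plus the `choices[0].message.content` response shape are assumptions, not part of this diff.

```python
import asyncio

from swift.llm import get_model_list_client_async, inference_client_async


async def main() -> None:
    # New async helper: list the models the deployed server exposes.
    model_list = await get_model_list_client_async(host='127.0.0.1', port='8000')
    print([model.id for model in model_list.data])

    # Passing both flags explicitly lets _pre_inference_client skip the
    # per-request /v1/models lookup entirely.
    resp = await inference_client_async(
        'qwen2-7b-instruct',   # placeholder model_type
        'Who are you?',        # placeholder query
        is_chat_request=True,
        is_multimodal=False,
    )
    print(resp.choices[0].message.content)


asyncio.run(main())
```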
