Skip to content

Commit

Permalink
support max_batch_size (#1599)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang authored Aug 5, 2024
1 parent fe994a5 commit 2f0dceb
Show file tree
Hide file tree
Showing 13 changed files with 93 additions and 29 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
<p align="center">
<img src="https://img.shields.io/badge/python-%E2%89%A53.8-5be.svg">
<img src="https://img.shields.io/badge/pytorch-%E2%89%A51.12%20%7C%20%E2%89%A52.0-orange.svg">
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.9.5-5D91D4.svg"></a>
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.17-5D91D4.svg"></a>
<a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
<a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
<a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
Expand Down
2 changes: 1 addition & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<p align="center">
<img src="https://img.shields.io/badge/python-%E2%89%A53.8-5be.svg">
<img src="https://img.shields.io/badge/pytorch-%E2%89%A51.12%20%7C%20%E2%89%A52.0-orange.svg">
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.9.5-5D91D4.svg"></a>
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.17-5D91D4.svg"></a>
<a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
<a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
<a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
Expand Down
2 changes: 1 addition & 1 deletion docs/source/LLM/命令行参数.md
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ deploy参数继承了infer参数, 除此之外增加了以下参数:
- `--ssl_keyfile`: 默认为`None`.
- `--ssl_certfile`: 默认为`None`.
- `--verbose`: 是否对请求内容进行打印, 默认为`True`.
- `--log_interval`: 对统计信息进行打印的间隔, 单位为秒. 默认为`0`, 表示不打印统计信息.
- `--log_interval`: 对统计信息进行打印的间隔, 单位为秒. 默认为`10`. 如果设置为`0`, 表示不打印统计信息.

## web-ui 参数

Expand Down
2 changes: 1 addition & 1 deletion docs/source_en/LLM/Command-line-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ deploy parameters inherit from infer parameters, with the following added parame
- `--ssl_keyfile`: Default is `None`.
- `--ssl_certfile`: Default is `None`.
- `--verbose`: Whether to print the request content. Defaults to `True`.
- `--log_interval`: Interval for printing statistical information, in seconds. Defaults to 0, meaning no statistical information will be printed.
- `--log_interval`: The interval for printing statistics, in seconds. Default is `10`. If set to `0`, statistics are not printed.

## web-ui Parameters

Expand Down
4 changes: 1 addition & 3 deletions requirements/framework.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@ aiohttp
attrdict
binpacking
dacite
datasets<2.19
einops
huggingface_hub<0.24
importlib_metadata
jieba
matplotlib
modelscope>=1.14
modelscope[datasets]>=1.17
nltk
numpy<2.0
oss2
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/utils/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,7 @@ class DeployArguments(InferArguments):

owned_by: str = 'swift'
verbose: bool = True # Whether to log request_info
log_interval: int = 0 # Interval for printing global statistics
log_interval: int = 10 # Interval for printing global statistics

def __post_init__(self):
super().__post_init__()
Expand Down
44 changes: 38 additions & 6 deletions swift/llm/utils/lmdeploy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def _prepare_lmdeploy_request(lmdeploy_engine: Union[AsyncEngine, VLAsyncEngine]
use_tqdm: bool = False,
**kwargs):
for key in ['num_prompt_tokens', 'num_generated_tokens', 'num_samples']:
generation_info[key] = 0
if key not in generation_info:
generation_info[key] = 0

if hasattr(lmdeploy_engine, 'vl_encoder'):
lmdeploy_engine.vl_encoder._loop_task = None
Expand Down Expand Up @@ -301,6 +302,7 @@ def inference_lmdeploy(lmdeploy_engine: Union[AsyncEngine, VLAsyncEngine],
*,
generation_config: Optional[LmdeployGenerationConfig] = None,
generation_info: Optional[Dict[str, Any]] = None,
max_batch_size: Optional[int] = None,
use_tqdm: bool = False,
verbose: bool = False,
prompt_prefix: str = '[PROMPT]',
Expand All @@ -309,15 +311,45 @@ def inference_lmdeploy(lmdeploy_engine: Union[AsyncEngine, VLAsyncEngine],
if len(request_list) == 0:
return []
runtime = time.perf_counter()

is_multimodal = getattr(lmdeploy_engine, 'is_multimodal', False)
if is_multimodal and max_batch_size is None:
max_batch_size = 512

_inner_call = kwargs.get('_inner_call', False)
if generation_info is None:
generation_info = {}
elif not _inner_call:
generation_info.clear()
if max_batch_size is not None and len(request_list) > max_batch_size:
i = 0
resp_list = []
kwargs['_inner_call'] = True
while i < len(request_list):
resp_list += inference_lmdeploy(
lmdeploy_engine,
template,
request_list[i:i + max_batch_size],
generation_config=generation_config,
generation_info=generation_info,
max_batch_size=max_batch_size,
use_tqdm=use_tqdm,
verbose=verbose,
prompt_prefix=prompt_prefix,
output_prefix=output_prefix,
**kwargs)
i += max_batch_size
runtime = time.perf_counter() - runtime
generation_info['runtime'] = runtime
generation_info['samples/s'] = generation_info['num_samples'] / runtime
generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
return resp_list

if generation_config is None:
generation_config = getattr(lmdeploy_engine, 'generation_config', LmdeployGenerationConfig())
assert isinstance(generation_config, LmdeployGenerationConfig)
request_list = deepcopy(request_list)
generation_config = deepcopy(generation_config)
if generation_info is None:
generation_info = {}
else:
generation_info.clear()

resp_list, generators = _prepare_lmdeploy_request(
lmdeploy_engine,
Expand Down Expand Up @@ -366,7 +398,7 @@ async def _batch_infer() -> None:
prog_bar.close()
runtime = time.perf_counter() - runtime
generation_info['runtime'] = runtime
generation_info['samples/s'] = len(generators) / runtime
generation_info['samples/s'] = generation_info['num_samples'] / runtime
generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
return resp_list

Expand Down
48 changes: 41 additions & 7 deletions swift/llm/utils/vllm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ def _prepare_vllm_request(llm_engine: LLMEngine,
use_tqdm: bool = False,
**kwargs) -> Tuple[List[Optional[Dict[str, Any]]], List[Tuple[bool, int]]]:
for key in ['num_prompt_tokens', 'num_generated_tokens', 'num_samples']:
generation_info[key] = 0
if key not in generation_info:
generation_info[key] = 0

template.model = llm_engine
tokenizer = template.tokenizer
Expand Down Expand Up @@ -457,6 +458,7 @@ def inference_vllm(llm_engine: LLMEngine,
*,
generation_config: Optional[VllmGenerationConfig] = None,
generation_info: Optional[Dict[str, Any]] = None,
max_batch_size: Optional[int] = None,
lora_request: Optional['LoRARequest'] = None,
use_tqdm: bool = False,
verbose: bool = False,
Expand All @@ -473,16 +475,48 @@ def inference_vllm(llm_engine: LLMEngine,
if len(request_list) == 0:
return []
runtime = time.perf_counter()

is_multimodal = getattr(llm_engine, 'is_multimodal', False)
if is_multimodal and max_batch_size is None:
max_batch_size = 512

_inner_call = kwargs.get('_inner_call', False)
if generation_info is None:
generation_info = {}
elif not _inner_call:
generation_info.clear()
if max_batch_size is not None and len(request_list) > max_batch_size:
i = 0
resp_list = []
kwargs['_inner_call'] = True
while i < len(request_list):
resp_list += inference_vllm(
llm_engine,
template,
request_list[i:i + max_batch_size],
generation_config=generation_config,
generation_info=generation_info,
max_batch_size=max_batch_size,
lora_request=lora_request,
use_tqdm=use_tqdm,
verbose=verbose,
prompt_prefix=prompt_prefix,
output_prefix=output_prefix,
**kwargs)
i += max_batch_size
runtime = time.perf_counter() - runtime
generation_info['runtime'] = runtime
generation_info['samples/s'] = generation_info['num_samples'] / runtime
generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
return resp_list

if generation_config is None:
generation_config = getattr(llm_engine, 'generation_config', VllmGenerationConfig())
assert isinstance(generation_config, VllmGenerationConfig)
request_list = deepcopy(request_list)
generation_config = deepcopy(generation_config)
if generation_info is None:
generation_info = {}
else:
generation_info.clear()

old_num_samples = generation_info.get('num_samples', 0)
resp_list, agent_state = _prepare_vllm_request(
llm_engine,
template,
Expand All @@ -496,7 +530,7 @@ def inference_vllm(llm_engine: LLMEngine,
tokenizer = template.tokenizer
if use_tqdm:
assert verbose is False
prog_bar = tqdm(total=generation_info['num_samples'], dynamic_ncols=True, disable=not use_tqdm)
prog_bar = tqdm(total=generation_info['num_samples'] - old_num_samples, dynamic_ncols=True, disable=not use_tqdm)
outputs = []
while llm_engine.has_unfinished_requests():
step_outputs = llm_engine.step()
Expand Down Expand Up @@ -525,7 +559,7 @@ def inference_vllm(llm_engine: LLMEngine,
print(tokenizer.decode(output.outputs[0].token_ids, False))
runtime = time.perf_counter() - runtime
generation_info['runtime'] = runtime
generation_info['samples/s'] = len(outputs) / runtime
generation_info['samples/s'] = generation_info['num_samples'] / runtime
generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
return resp_list

Expand Down
4 changes: 2 additions & 2 deletions tests/custom/test_lmdeploy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
def test_lmdeploy():
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from swift.llm import (ModelType, get_lmdeploy_engine, get_default_template_type, get_template, inference_lmdeploy,
inference_stream_lmdeploy)
Expand Down Expand Up @@ -40,7 +40,7 @@ def test_lmdeploy():
print(generation_info)

# batched
n_batched = 100
n_batched = 1000
request_list = [{'query': '晚上睡不着觉怎么办?'} for i in range(n_batched)]
resp_list = inference_lmdeploy(
lmdeploy_engine, template, request_list, generation_info=generation_info, use_tqdm=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/custom/test_lmdeploy_vlm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
def test_lmdeploy_vlm():
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from swift.llm import (ModelType, get_lmdeploy_engine, get_default_template_type, get_template, inference_lmdeploy,
inference_stream_lmdeploy)
Expand Down Expand Up @@ -47,7 +47,7 @@ def test_lmdeploy_vlm():
print(generation_info)

# batched
n_batched = 100
n_batched = 1000
request_list = [{
'query':
'这两张图片有什么区别:'
Expand Down
2 changes: 1 addition & 1 deletion tests/custom/test_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_vllm():
print(generation_info)

# batched
n_batched = 100
n_batched = 1000
request_list = [{'query': '晚上睡不着觉怎么办?'} for i in range(n_batched)]
resp_list = inference_vllm(llm_engine, template, request_list, generation_info=generation_info, use_tqdm=True)
assert len(resp_list) == n_batched
Expand Down
2 changes: 1 addition & 1 deletion tests/custom/test_vllm_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_vllm_vlm():
print(generation_info)

# batched
n_batched = 100
n_batched = 1000
images = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png']
request_list = [{'query': 'Describe this image.', 'images': images} for i in range(n_batched)]
resp_list = inference_vllm(llm_engine, template, request_list, generation_info=generation_info, use_tqdm=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/llm/test_run2.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def test_glm4v_9b_chat(self):
output = sft_main(
SftArguments(
model_type=ModelType.glm4v_9b_chat,
# dataset=DatasetName.capcha_images,
lora_target_modules='ALL',
# dataset=DatasetName.capcha_images,
# lora_target_modules='ALL',
train_dataset_sample=100,
eval_steps=5,
custom_train_dataset_path=[os.path.join(folder, 'multi_modal_3.jsonl')],
Expand Down

0 comments on commit 2f0dceb

Please sign in to comment.