From 2f0dceb62ef6ffb46cc8165e9cc8d06c9f18c5c4 Mon Sep 17 00:00:00 2001
From: Jintao
Date: Mon, 5 Aug 2024 16:06:18 +0800
Subject: [PATCH] support max_batch_size (#1599)

---
 README.md                                     |  2 +-
 README_CN.md                                  |  2 +-
 ...44\350\241\214\345\217\202\346\225\260.md" |  2 +-
 docs/source_en/LLM/Command-line-parameters.md |  2 +-
 requirements/framework.txt                    |  4 +-
 swift/llm/utils/argument.py                   |  2 +-
 swift/llm/utils/lmdeploy_utils.py             | 44 ++++++++++++++---
 swift/llm/utils/vllm_utils.py                 | 48 ++++++++++++++++---
 tests/custom/test_lmdeploy.py                 |  4 +-
 tests/custom/test_lmdeploy_vlm.py             |  4 +-
 tests/custom/test_vllm.py                     |  2 +-
 tests/custom/test_vllm_vlm.py                 |  2 +-
 tests/llm/test_run2.py                        |  4 +-
 13 files changed, 93 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index fcd2eb94a..437dc88a7 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@
 
-
+
diff --git a/README_CN.md b/README_CN.md
index ee399f500..7a5952fdb 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -15,7 +15,7 @@
 
- + diff --git "a/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 32b458e50..2be463f94 100644 --- "a/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -405,7 +405,7 @@ deploy参数继承了infer参数, 除此之外增加了以下参数: - `--ssl_keyfile`: 默认为`None`. - `--ssl_certfile`: 默认为`None`. - `--verbose`: 是否对请求内容进行打印, 默认为`True`. -- `--log_interval`: 对统计信息进行打印的间隔, 单位为秒. 默认为`0`, 表示不打印统计信息. +- `--log_interval`: 对统计信息进行打印的间隔, 单位为秒. 默认为`10`. 如果设置为`0`, 表示不打印统计信息. ## web-ui 参数 diff --git a/docs/source_en/LLM/Command-line-parameters.md b/docs/source_en/LLM/Command-line-parameters.md index a73e346f3..64ada4859 100644 --- a/docs/source_en/LLM/Command-line-parameters.md +++ b/docs/source_en/LLM/Command-line-parameters.md @@ -406,7 +406,7 @@ deploy parameters inherit from infer parameters, with the following added parame - `--ssl_keyfile`: Default is `None`. - `--ssl_certfile`: Default is `None`. - `--verbose`: Whether to print the request content. Defaults to `True`. -- `--log_interval`: Interval for printing statistical information, in seconds. Defaults to 0, meaning no statistical information will be printed. +- `--log_interval`: The interval for printing statistics, in seconds. Default is `10`. If set to `0`, it means statistics will not be printed. ## web-ui Parameters diff --git a/requirements/framework.txt b/requirements/framework.txt index 19c54feb9..07290e989 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -4,13 +4,11 @@ aiohttp attrdict binpacking dacite -datasets<2.19 einops -huggingface_hub<0.24 importlib_metadata jieba matplotlib -modelscope>=1.14 +modelscope[datasets]>=1.17 nltk numpy<2.0 oss2 diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py index be5283439..535ca8db7 100644 --- a/swift/llm/utils/argument.py +++ b/swift/llm/utils/argument.py @@ -1433,7 +1433,7 @@ class DeployArguments(InferArguments): owned_by: str = 'swift' verbose: bool = True # Whether to log request_info - log_interval: int = 0 # Interval for printing global statistics + log_interval: int = 10 # Interval for printing global statistics def __post_init__(self): super().__post_init__() diff --git a/swift/llm/utils/lmdeploy_utils.py b/swift/llm/utils/lmdeploy_utils.py index 9f9dd5ea5..b3753f94d 100644 --- a/swift/llm/utils/lmdeploy_utils.py +++ b/swift/llm/utils/lmdeploy_utils.py @@ -160,7 +160,8 @@ def _prepare_lmdeploy_request(lmdeploy_engine: Union[AsyncEngine, VLAsyncEngine] use_tqdm: bool = False, **kwargs): for key in ['num_prompt_tokens', 'num_generated_tokens', 'num_samples']: - generation_info[key] = 0 + if key not in generation_info: + generation_info[key] = 0 if hasattr(lmdeploy_engine, 'vl_encoder'): lmdeploy_engine.vl_encoder._loop_task = None @@ -301,6 +302,7 @@ def inference_lmdeploy(lmdeploy_engine: Union[AsyncEngine, VLAsyncEngine], *, generation_config: Optional[LmdeployGenerationConfig] = None, generation_info: Optional[Dict[str, Any]] = None, + max_batch_size: Optional[int] = None, use_tqdm: bool = False, verbose: bool = False, prompt_prefix: str = '[PROMPT]', @@ -309,15 +311,45 @@ def inference_lmdeploy(lmdeploy_engine: Union[AsyncEngine, VLAsyncEngine], if len(request_list) == 0: return [] runtime = time.perf_counter() + + is_multimodal = getattr(lmdeploy_engine, 'is_multimodal', False) + if is_multimodal and max_batch_size is None: + 
max_batch_size = 512 + + _inner_call = kwargs.get('_inner_call', False) + if generation_info is None: + generation_info = {} + elif not _inner_call: + generation_info.clear() + if max_batch_size is not None and len(request_list) > max_batch_size: + i = 0 + resp_list = [] + kwargs['_inner_call'] = True + while i < len(request_list): + resp_list += inference_lmdeploy( + lmdeploy_engine, + template, + request_list[i:i + max_batch_size], + generation_config=generation_config, + generation_info=generation_info, + max_batch_size=max_batch_size, + use_tqdm=use_tqdm, + verbose=verbose, + prompt_prefix=prompt_prefix, + output_prefix=output_prefix, + **kwargs) + i += max_batch_size + runtime = time.perf_counter() - runtime + generation_info['runtime'] = runtime + generation_info['samples/s'] = generation_info['num_samples'] / runtime + generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime + return resp_list + if generation_config is None: generation_config = getattr(lmdeploy_engine, 'generation_config', LmdeployGenerationConfig()) assert isinstance(generation_config, LmdeployGenerationConfig) request_list = deepcopy(request_list) generation_config = deepcopy(generation_config) - if generation_info is None: - generation_info = {} - else: - generation_info.clear() resp_list, generators = _prepare_lmdeploy_request( lmdeploy_engine, @@ -366,7 +398,7 @@ async def _batch_infer() -> None: prog_bar.close() runtime = time.perf_counter() - runtime generation_info['runtime'] = runtime - generation_info['samples/s'] = len(generators) / runtime + generation_info['samples/s'] = generation_info['num_samples'] / runtime generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime return resp_list diff --git a/swift/llm/utils/vllm_utils.py b/swift/llm/utils/vllm_utils.py index c2d6e0832..b9effcd83 100644 --- a/swift/llm/utils/vllm_utils.py +++ b/swift/llm/utils/vllm_utils.py @@ -290,7 +290,8 @@ def _prepare_vllm_request(llm_engine: LLMEngine, use_tqdm: bool = False, **kwargs) -> Tuple[List[Optional[Dict[str, Any]]], List[Tuple[bool, int]]]: for key in ['num_prompt_tokens', 'num_generated_tokens', 'num_samples']: - generation_info[key] = 0 + if key not in generation_info: + generation_info[key] = 0 template.model = llm_engine tokenizer = template.tokenizer @@ -457,6 +458,7 @@ def inference_vllm(llm_engine: LLMEngine, *, generation_config: Optional[VllmGenerationConfig] = None, generation_info: Optional[Dict[str, Any]] = None, + max_batch_size: Optional[int] = None, lora_request: Optional['LoRARequest'] = None, use_tqdm: bool = False, verbose: bool = False, @@ -473,16 +475,48 @@ def inference_vllm(llm_engine: LLMEngine, if len(request_list) == 0: return [] runtime = time.perf_counter() + + is_multimodal = getattr(llm_engine, 'is_multimodal', False) + if is_multimodal and max_batch_size is None: + max_batch_size = 512 + + _inner_call = kwargs.get('_inner_call', False) + if generation_info is None: + generation_info = {} + elif not _inner_call: + generation_info.clear() + if max_batch_size is not None and len(request_list) > max_batch_size: + i = 0 + resp_list = [] + kwargs['_inner_call'] = True + while i < len(request_list): + resp_list += inference_vllm( + llm_engine, + template, + request_list[i:i + max_batch_size], + generation_config=generation_config, + generation_info=generation_info, + max_batch_size=max_batch_size, + lora_request=lora_request, + use_tqdm=use_tqdm, + verbose=verbose, + prompt_prefix=prompt_prefix, + output_prefix=output_prefix, + **kwargs) + i += 
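For reference, a minimal usage sketch of the new `max_batch_size` argument on the lmdeploy path (illustrative only; the model choice and the `hf_tokenizer` attribute access are assumptions, not part of this patch). Requests beyond the threshold are processed in chunks, and `generation_info` is cleared once in the outer call and accumulated across the inner calls:

    # Sketch: batched lmdeploy inference with the new max_batch_size argument.
    from swift.llm import (ModelType, get_default_template_type, get_lmdeploy_engine, get_template,
                           inference_lmdeploy)

    model_type = ModelType.qwen_7b_chat  # assumption: any lmdeploy-supported model works
    lmdeploy_engine = get_lmdeploy_engine(model_type)
    template = get_template(get_default_template_type(model_type), lmdeploy_engine.hf_tokenizer)

    request_list = [{'query': '晚上睡不着觉怎么办?'} for _ in range(2000)]
    generation_info = {}
    # 2000 requests with max_batch_size=512 -> 4 inner calls; the counters in generation_info
    # ('num_samples', 'num_generated_tokens', ...) accumulate over all chunks, so the final
    # samples/s and tokens/s are global rather than per-chunk.
    resp_list = inference_lmdeploy(
        lmdeploy_engine, template, request_list, generation_info=generation_info, max_batch_size=512, use_tqdm=True)
    assert len(resp_list) == 2000
    print(generation_info)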
diff --git a/swift/llm/utils/vllm_utils.py b/swift/llm/utils/vllm_utils.py
index c2d6e0832..b9effcd83 100644
--- a/swift/llm/utils/vllm_utils.py
+++ b/swift/llm/utils/vllm_utils.py
@@ -290,7 +290,8 @@ def _prepare_vllm_request(llm_engine: LLMEngine,
                           use_tqdm: bool = False,
                           **kwargs) -> Tuple[List[Optional[Dict[str, Any]]], List[Tuple[bool, int]]]:
     for key in ['num_prompt_tokens', 'num_generated_tokens', 'num_samples']:
-        generation_info[key] = 0
+        if key not in generation_info:
+            generation_info[key] = 0
 
     template.model = llm_engine
     tokenizer = template.tokenizer
@@ -457,6 +458,7 @@ def inference_vllm(llm_engine: LLMEngine,
                    *,
                    generation_config: Optional[VllmGenerationConfig] = None,
                    generation_info: Optional[Dict[str, Any]] = None,
+                   max_batch_size: Optional[int] = None,
                    lora_request: Optional['LoRARequest'] = None,
                    use_tqdm: bool = False,
                    verbose: bool = False,
@@ -473,16 +475,48 @@ def inference_vllm(llm_engine: LLMEngine,
     if len(request_list) == 0:
         return []
     runtime = time.perf_counter()
+
+    is_multimodal = getattr(llm_engine, 'is_multimodal', False)
+    if is_multimodal and max_batch_size is None:
+        max_batch_size = 512
+
+    _inner_call = kwargs.get('_inner_call', False)
+    if generation_info is None:
+        generation_info = {}
+    elif not _inner_call:
+        generation_info.clear()
+    if max_batch_size is not None and len(request_list) > max_batch_size:
+        i = 0
+        resp_list = []
+        kwargs['_inner_call'] = True
+        while i < len(request_list):
+            resp_list += inference_vllm(
+                llm_engine,
+                template,
+                request_list[i:i + max_batch_size],
+                generation_config=generation_config,
+                generation_info=generation_info,
+                max_batch_size=max_batch_size,
+                lora_request=lora_request,
+                use_tqdm=use_tqdm,
+                verbose=verbose,
+                prompt_prefix=prompt_prefix,
+                output_prefix=output_prefix,
+                **kwargs)
+            i += max_batch_size
+        runtime = time.perf_counter() - runtime
+        generation_info['runtime'] = runtime
+        generation_info['samples/s'] = generation_info['num_samples'] / runtime
+        generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
+        return resp_list
+
     if generation_config is None:
         generation_config = getattr(llm_engine, 'generation_config', VllmGenerationConfig())
     assert isinstance(generation_config, VllmGenerationConfig)
     request_list = deepcopy(request_list)
     generation_config = deepcopy(generation_config)
-    if generation_info is None:
-        generation_info = {}
-    else:
-        generation_info.clear()
+    old_num_samples = generation_info.get('num_samples', 0)
 
     resp_list, agent_state = _prepare_vllm_request(
         llm_engine,
         template,
@@ -496,7 +530,7 @@ def inference_vllm(llm_engine: LLMEngine,
     tokenizer = template.tokenizer
     if use_tqdm:
         assert verbose is False
-    prog_bar = tqdm(total=generation_info['num_samples'], dynamic_ncols=True, disable=not use_tqdm)
+    prog_bar = tqdm(total=generation_info['num_samples'] - old_num_samples, dynamic_ncols=True, disable=not use_tqdm)
     outputs = []
     while llm_engine.has_unfinished_requests():
         step_outputs = llm_engine.step()
@@ -525,7 +559,7 @@ def inference_vllm(llm_engine: LLMEngine,
             print(tokenizer.decode(output.outputs[0].token_ids, False))
     runtime = time.perf_counter() - runtime
     generation_info['runtime'] = runtime
-    generation_info['samples/s'] = len(outputs) / runtime
+    generation_info['samples/s'] = generation_info['num_samples'] / runtime
     generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
     return resp_list
 
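The vLLM path mirrors the lmdeploy change. A self-contained toy sketch of the chunk-and-accumulate pattern both functions now share (illustrative names only, no engine required):

    from typing import Any, Dict, List

    def fake_infer(batch: List[str], stats: Dict[str, Any]) -> List[str]:
        # Stand-in for one inner inference_* call: counters are incremented, never reset,
        # mirroring the 'if key not in generation_info' change in _prepare_*_request.
        stats['num_samples'] = stats.get('num_samples', 0) + len(batch)
        return [q.upper() for q in batch]

    def infer_chunked(requests: List[str], stats: Dict[str, Any], max_batch_size: int) -> List[str]:
        # Outer loop: split the request list into chunks of at most max_batch_size.
        resp: List[str] = []
        for i in range(0, len(requests), max_batch_size):
            resp += fake_infer(requests[i:i + max_batch_size], stats)
        return resp

    stats: Dict[str, Any] = {}
    out = infer_chunked([f'q{i}' for i in range(1000)], stats, max_batch_size=512)
    assert len(out) == 1000 and stats['num_samples'] == 1000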
diff --git a/tests/custom/test_lmdeploy.py b/tests/custom/test_lmdeploy.py
index 72aab134e..0bab8e17a 100644
--- a/tests/custom/test_lmdeploy.py
+++ b/tests/custom/test_lmdeploy.py
@@ -1,6 +1,6 @@
 def test_lmdeploy():
     import os
-    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
     from swift.llm import (ModelType, get_lmdeploy_engine, get_default_template_type, get_template, inference_lmdeploy,
                            inference_stream_lmdeploy)
@@ -40,7 +40,7 @@ def test_lmdeploy():
     print(generation_info)
 
     # batched
-    n_batched = 100
+    n_batched = 1000
     request_list = [{'query': '晚上睡不着觉怎么办?'} for i in range(n_batched)]
     resp_list = inference_lmdeploy(
         lmdeploy_engine, template, request_list, generation_info=generation_info, use_tqdm=True)
diff --git a/tests/custom/test_lmdeploy_vlm.py b/tests/custom/test_lmdeploy_vlm.py
index 35ed03863..6d31203b3 100644
--- a/tests/custom/test_lmdeploy_vlm.py
+++ b/tests/custom/test_lmdeploy_vlm.py
@@ -1,6 +1,6 @@
 def test_lmdeploy_vlm():
     import os
-    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
     from swift.llm import (ModelType, get_lmdeploy_engine, get_default_template_type, get_template, inference_lmdeploy,
                            inference_stream_lmdeploy)
@@ -47,7 +47,7 @@ def test_lmdeploy_vlm():
     print(generation_info)
 
     # batched
-    n_batched = 100
+    n_batched = 1000
     request_list = [{
         'query':
         '这两张图片有什么区别:'
diff --git a/tests/custom/test_vllm.py b/tests/custom/test_vllm.py
index f07267a9b..9999375e2 100644
--- a/tests/custom/test_vllm.py
+++ b/tests/custom/test_vllm.py
@@ -40,7 +40,7 @@ def test_vllm():
     print(generation_info)
 
     # batched
-    n_batched = 100
+    n_batched = 1000
     request_list = [{'query': '晚上睡不着觉怎么办?'} for i in range(n_batched)]
     resp_list = inference_vllm(llm_engine, template, request_list, generation_info=generation_info, use_tqdm=True)
     assert len(resp_list) == n_batched
diff --git a/tests/custom/test_vllm_vlm.py b/tests/custom/test_vllm_vlm.py
index baf3b61d5..aeddb99e0 100644
--- a/tests/custom/test_vllm_vlm.py
+++ b/tests/custom/test_vllm_vlm.py
@@ -43,7 +43,7 @@ def test_vllm_vlm():
     print(generation_info)
 
     # batched
-    n_batched = 100
+    n_batched = 1000
     images = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png']
     request_list = [{'query': 'Describe this image.', 'images': images} for i in range(n_batched)]
     resp_list = inference_vllm(llm_engine, template, request_list, generation_info=generation_info, use_tqdm=True)
diff --git a/tests/llm/test_run2.py b/tests/llm/test_run2.py
index 2a34805c6..b0a1ab73c 100644
--- a/tests/llm/test_run2.py
+++ b/tests/llm/test_run2.py
@@ -97,8 +97,8 @@ def test_glm4v_9b_chat(self):
         output = sft_main(
             SftArguments(
                 model_type=ModelType.glm4v_9b_chat,
-                # dataset=DatasetName.capcha_images,
-                lora_target_modules='ALL',
+                # dataset=DatasetName.capcha_images,
+                # lora_target_modules='ALL',
                 train_dataset_sample=100,
                 eval_steps=5,
                 custom_train_dataset_path=[os.path.join(folder, 'multi_modal_3.jsonl')],
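A hypothetical extra test, not included in this patch, that exercises `max_batch_size` explicitly on the vLLM path; the model type and the `hf_tokenizer` attribute are assumptions:

    def test_vllm_max_batch_size():
        import os
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
        from swift.llm import ModelType, get_default_template_type, get_template, get_vllm_engine, inference_vllm

        model_type = ModelType.qwen2_7b_instruct  # assumption: any vLLM-supported model works
        llm_engine = get_vllm_engine(model_type)
        template = get_template(get_default_template_type(model_type), llm_engine.hf_tokenizer)

        n_batched = 1000
        request_list = [{'query': '晚上睡不着觉怎么办?'} for _ in range(n_batched)]
        generation_info = {}
        # Force chunking: 1000 requests / 256 per chunk -> 4 inner calls.
        resp_list = inference_vllm(
            llm_engine, template, request_list, generation_info=generation_info, max_batch_size=256, use_tqdm=True)
        assert len(resp_list) == n_batched
        assert generation_info['num_samples'] == n_batched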