From 8c586509dcf1dc9a6111d86449c3634ea0066c81 Mon Sep 17 00:00:00 2001 From: sufubao <47234901+sufubao@users.noreply.github.com> Date: Wed, 30 Oct 2024 10:21:18 +0800 Subject: [PATCH 1/6] [docs] Update docs. (#597) --- .../source/getting_started/installation.rst | 11 +- docs/CN/source/getting_started/quickstart.rst | 18 +- docs/CN/source/models/test.rst | 57 ++-- docs/CN/source/server/api_server_args_zh.rst | 245 +----------------- .../source/getting_started/installation.rst | 8 +- docs/EN/source/getting_started/quickstart.rst | 18 +- docs/EN/source/models/test.rst | 37 +-- docs/EN/source/server/api_server_args.rst | 243 +---------------- 8 files changed, 61 insertions(+), 576 deletions(-) diff --git a/docs/CN/source/getting_started/installation.rst b/docs/CN/source/getting_started/installation.rst index 74fd06539..45e967a24 100755 --- a/docs/CN/source/getting_started/installation.rst +++ b/docs/CN/source/getting_started/installation.rst @@ -67,8 +67,11 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton $ git clone https://github.com/ModelTC/lightllm.git $ cd lightllm $ - $ # 安装lightllm的依赖 - $ pip install -r requirements.txt + $ # 安装lightllm的依赖 (cuda 11.8) + $ pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 + $ + $ # 这个版本的 nccl 可以支持 torch cuda graph + $ pip install nvidia-nccl-cu12==2.20.5 $ $ # 安装lightllm $ python setup.py install @@ -76,11 +79,11 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton .. note:: Lightllm 的代码在多种GPU上都进行了测试,包括 V100, A100, A800, 4090, 和 H800。 - 如果你使用 A100 、A800 等显卡,那么推荐你安装 triton==2.1.0 : + 如果你使用 A100 、A800 等显卡,那么推荐你安装 triton==3.0.0 : .. code-block:: console - $ pip install triton==2.1.0 --no-deps + $ pip install triton==3.0.0 --no-deps 如果你使用 H800、V100 等显卡,那么推荐你安装 triton-nightly: diff --git a/docs/CN/source/getting_started/quickstart.rst b/docs/CN/source/getting_started/quickstart.rst index 3eaf34e84..eaf664748 100755 --- a/docs/CN/source/getting_started/quickstart.rst +++ b/docs/CN/source/getting_started/quickstart.rst @@ -17,7 +17,7 @@ 1. 准备模型文件 ------------------------- -下面的内容将会以 `Qwen2-0.5B `_ 演示lightllm对大语言模型的支持。 +下面的内容将会以 `Llama-2-7b-chat `_ 演示lightllm对大语言模型的支持。 下载模型的方法可以参考文章:`如何快速下载huggingface模型——全方法总结 `_ 下面是下载模型的实例代码: @@ -38,7 +38,7 @@ .. code-block:: console - $ huggingface-cli download Qwen/Qwen2-0.5B --local-dir Qwen2-0.5 + $ huggingface-cli download meta-llama/Llama-2-7b-chat --local-dir Llama-2-7b-chat .. tip:: 上面的下载模型的代码需要科学上网,并且需要花费一定的时间,你可以使用其它下载方式或者其它支持的模型作为替代。最新的支持的模型的列表请查看 `项目主页 `_ 。 @@ -47,20 +47,14 @@ 2. 启动模型服务 ------------------------- -下载完Qwen2-0.5B模型以后,在终端使用下面的代码部署API服务: +下载完Llama-2-7b-chat模型以后,在终端使用下面的代码部署API服务: .. code-block:: console - $ python -m lightllm.server.api_server --model_dir ~/models/Qwen2-0.5B \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 1 \ - $ --max_total_token_num 120000 \ - $ --trust_remote_code \ - $ --eos_id 151643 + $ python -m lightllm.server.api_server --model_dir ~/models/Llama-2-7b-chat .. note:: - 上面代码中的 ``--model_dir`` 参数需要修改为你本机实际的模型路径。 ``--eos_id 151643`` 是Qwen模型专属,其它模型请删除这个参数。 + 上面代码中的 ``--model_dir`` 参数需要修改为你本机实际的模型路径。 3. (可选)测试模型服务 @@ -70,7 +64,7 @@ .. code-block:: console - $ curl http://localhost:8080/generate \ + $ curl http://localhost:8000/generate \ $ -H "Content-Type: application/json" \ $ -d '{ $ "inputs": "What is AI?", diff --git a/docs/CN/source/models/test.rst b/docs/CN/source/models/test.rst index 424a2532c..f094010f2 100755 --- a/docs/CN/source/models/test.rst +++ b/docs/CN/source/models/test.rst @@ -8,20 +8,14 @@ Qwen2-0.5B .. 
code-block:: console - $ python -m lightllm.server.api_server --model_dir ~/models/Qwen2-0.5B \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 1 \ - $ --max_total_token_num 120000 \ - $ --trust_remote_code \ - $ --eos_id 151643 + $ python -m lightllm.server.api_server --model_dir ~/models/Qwen2-0.5B --trust_remote_code **测试服务** .. code-block:: console - $ curl http://localhost:8080/generate \ + $ curl http://localhost:8000/generate \ $ -H "Content-Type: application/json" \ $ -d '{ $ "inputs": "What is AI?", @@ -39,13 +33,10 @@ Qwen-VL-Chat .. code-block:: console - $ python -m lightllm.server.api_server --model_dir ~/models/Qwen-VL-Chat \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 1 \ - $ --max_total_token_num 120000 \ - $ --trust_remote_code \ - $ --enable_multimodal + $ python -m lightllm.server.api_server + $ --model_dir ~/models/Qwen-VL-Chat \ + $ --trust_remote_code \ + $ --enable_multimodal **测试服务** @@ -79,7 +70,7 @@ Qwen-VL-Chat } } - url = "http://127.0.0.1:8080/generate" + url = "http://127.0.0.1:8000/generate" headers = {'Content-Type': 'application/json'} response = requests.post(url, headers=headers, data=json.dumps(data)) return response @@ -114,11 +105,7 @@ llama2-70b-chat .. code-block:: console - $ python -m lightllm.server.api_server --model_dir ~/models/llama2-70b-chat \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 4 \ - $ --max_total_token_num 120000 + $ python -m lightllm.server.api_server --model_dir ~/models/llama2-70b-chat --tp 4 .. tip:: @@ -128,7 +115,7 @@ llama2-70b-chat .. code-block:: console - $ curl http://localhost:8080/generate \ + $ curl http://localhost:8000/generate \ $ -H "Content-Type: application/json" \ $ -d '{ $ "inputs": "What is LLM?", @@ -146,13 +133,10 @@ internlm2-1_8b .. code-block:: console - $ python -m lightllm.server.api_server --model_dir ~/models/internlm2-1_8b \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 1 \ - $ --max_total_token_num 120000 \ - $ --splitfuse_mode \ - $ --trust_remote_code + $ python -m lightllm.server.api_server + $ --model_dir ~/models/internlm2-1_8b \ + $ --splitfuse_mode \ + $ --trust_remote_code .. tip:: @@ -163,7 +147,7 @@ internlm2-1_8b .. code-block:: console - $ curl http://localhost:8080/generate \ + $ curl http://localhost:8000/generate \ $ -H "Content-Type: application/json" \ $ -d '{ $ "inputs": "What is LLM?", @@ -181,13 +165,10 @@ internlm2-1_8b-reward .. code-block:: console - $ python -m lightllm.server.api_server --model_dir ~/models/internlm2-1_8b-reward \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 1 \ - $ --max_total_token_num 120000 \ - $ --use_reward_model \ - $ --trust_remote_code + $ python -m lightllm.server.api_server + $ --model_dir ~/models/internlm2-1_8b-reward \ + $ --use_reward_model \ + $ --trust_remote_code .. tip:: @@ -203,7 +184,7 @@ internlm2-1_8b-reward query = "<|im_start|>user\nHello! What's your name?<|im_end|>\n<|im_start|>assistant\nMy name is InternLM2! A helpful AI assistant. What can I do for you?<|im_end|>\n<|reward|>" - url = "http://127.0.0.1:8080/get_score" + url = "http://127.0.0.1:8000/get_score" headers = {'Content-Type': 'application/json'} data = { diff --git a/docs/CN/source/server/api_server_args_zh.rst b/docs/CN/source/server/api_server_args_zh.rst index cac768b60..af68bddef 100755 --- a/docs/CN/source/server/api_server_args_zh.rst +++ b/docs/CN/source/server/api_server_args_zh.rst @@ -5,243 +5,8 @@ APIServer 参数详解 使用方法 ++++++++++++ -.. 
code-block:: console - - python -m lightllm.server.api_server [-h] [--host HOST] [--port PORT] [--model_dir MODEL_DIR] - [--tokenizer_mode TOKENIZER_MODE] [--load_way LOAD_WAY] - [--max_total_token_num MAX_TOTAL_TOKEN_NUM] - [--batch_max_tokens BATCH_MAX_TOKENS] [--eos_id EOS_ID [EOS_ID ...]] - [--running_max_req_size RUNNING_MAX_REQ_SIZE] [--tp TP] - [--max_req_input_len MAX_REQ_INPUT_LEN] - [--max_req_total_len MAX_REQ_TOTAL_LEN] [--nccl_port NCCL_PORT] - [--mode MODE [MODE ...]] [--trust_remote_code] [--disable_log_stats] - [--log_stats_interval LOG_STATS_INTERVAL] - [--router_token_ratio ROUTER_TOKEN_RATIO] - [--router_max_new_token_len ROUTER_MAX_NEW_TOKEN_LEN] - [--router_max_wait_tokens ROUTER_MAX_WAIT_TOKENS] - [--use_dynamic_prompt_cache] - [--splitfuse_block_size SPLITFUSE_BLOCK_SIZE] [--splitfuse_mode] - [--beam_mode] [--diverse_mode] [--token_healing_mode] - [--enable_multimodal] [--cache_capacity CACHE_CAPACITY] - [--cache_reserved_ratio CACHE_RESERVED_RATIO] - [--data_type {fp16,float16,bf16,bfloat16,fp32,float32}] - [--return_all_prompt_logprobs] [--use_reward_model] - [--long_truncation_mode {None,head,center}] [--use_tgi_api] - [--health_monitor] [--metric_gateway METRIC_GATEWAY] - [--job_name JOB_NAME] [--grouping_key GROUPING_KEY] - [--push_interval PUSH_INTERVAL] [--enable_monitor_auth] - -参数说明 -++++++++ - -:code:`--host` - 服务IP地址 - - 默认值:127.0.0.1 - -:code:`--port` - 服务端口 - - 默认值:8000 - -:code:`--model_dir` - 模型权重目录路径,将从该目录加载配置、权重和分词器 - -:code:`--tokenizer_mode` - tokenizer加载模式,可以是 ``slow`` 、 ``fast`` 或 ``auto`` ,慢速模式加载快但运行慢,慢速模式有利于调试和测试,快速模式可以获得最佳性能,自动模式将尝试使用快速模式,如果失败将使用慢速模式 - - 默认值:slow - -:code:`--load_way` - 加载权重的方式,默认为 ``HF`` (huggingface 格式),llama模型也支持 ``DS`` (Deepspeed) - - 默认值:HF - -:code:`--max_total_token_num` - GPU能支持的总token数,等于 max_batch * (input_len + output_len) - - 默认值:6000 - -:code:`--eos_id` - 模型终止输出的 token id - - 默认值:[2] - -:code:`--running_max_req_size` - 同一时间内进行推理的最大请求数 - - 默认值:1000 - -:code:`--tp` - 模型进行张量并行的尺寸 - - 默认值:1 - -:code:`--max_req_input_len` - 单个请求最大的输入token量 - - 默认值:2048 - -:code:`--max_req_total_len` - 单个请求最大的输入token量+输出token量 - - 默认值:3072 - -:code:`--nccl_port` - 创建pytorch分布式环境使用的nccl端口 - - 默认值:28765 - -:code:`--mode` - 一个列表,用来对某些适配的模型开启某些算子从而进行加速,可选的方案包括: - - :code:`[triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding | triton_gqa_attention | triton_gqa_flashdecoding | triton_w4a16 | triton_w8a16 | triton_w8a8 | lmdeploy_w4a16 | ppl_w4a16 | ppl_w8a8 | ppl_w8a8_mixdown]` - - 其中, - - * triton_flashdecoding : 适用于长文本,当前支持的模型包括 llama/llama2/qwen - * triton_gqa_attention 和 triton_gqa_flashdecoding :使用于使用GQA的模型 - * triton_int8kv :使用int8来存储 kv cache,可以提升可容纳的token总数 - * ppl_int8kv :使用int8来存储 kv cache,并且使用 ppl 核函数进行加速 - * ppl_fp16 :使用 ppl 的 fp16精度的 decode attention 核函数 - * triton_int8weight:使用int8来存储参数 - * triton_int4weight 和 lmdeploy_int4weight 和 ppl_int4weight:使用int4来存储参数 - - .. 
tip:: - 在你使用某种模式之前,你需要先阅读源码确保该模式支持你想要使用的模型。 - - 默认值:[] - -:code:`--trust_remote_code` - 是否允许使用Hub仓库中上传者在上传文件中自定义的模型 - - 默认值:False - -:code:`--disable_log_stats` - 禁用日志系统记录吞吐量统计信息。 - - 默认值:False - -:code:`--log_stats_interval` - 以秒为单位的日志统计间隔。 - - 默认值:0.0 - -:code:`--router_token_ratio` - 控制router调度的token的比例 - - 默认值:0.0 - -:code:`--router_max_new_token_len` - 对于Router的请求最大的新Token量 - - 默认值:1024 - - -:code:`--router_max_wait_tokens` - 每次在进行或者等待 router_max_wait_tokens 轮次以后,router 就会调度新请求 - - 默认值:10 - -:code:`--use_dynamic_prompt_cache` - 是否使用 ``use_dynamic_prompt_cache`` - - 默认值:False - -:code:`--splitfuse_block_size` - splitfuse 块大小 - - 默认值:256 - -:code:`--splitfuse_mode` - 是否使用 ``splitfuse`` 模式 - - 默认值:False - -:code:`--beam_mode` - 是否使用 ``beamsearch`` 模式 - - 默认值:False - -:code:`--diverse_mode` - 是否使用 ``diversity generation`` 模式 - - 默认值:False - -:code:`--token_healing_mode` - 是否使用 ``code model infer`` 模式 - - 默认值:False - -:code:`--enable_multimodal` - 是否使用多模态模型 - - 默认值:False - -:code:`--cache_capacity` - 多模态资源缓存服务器的最大缓存量 - - 默认值:200 - -:code:`--cache_reserved_ratio` - 清除后资源后,缓存服务器预留容量的比例 - - 默认值:0.5 - -:code:`--data_type` - 模型权重的数据格式,可能的选择:fp16, float16, bf16, bfloat16, fp32, float32 - - 默认值:“float16” - -:code:`--return_all_prompt_logprobs` - 是否返回每个提示中所有标记的对数概率 - - 默认值:False - -:code:`--use_reward_model` - 是否使用 reward 类模型 - - 默认值:False - -:code:`--long_truncation_mode` - 用于选择对于过长的输入的处理方式,有如下的选择; - - * None : 返回异常 - * head :移除起始的一些token - * center:移除中间的某些token - - 默认值:None - -:code:`--use_tgi_api` - 使用 tgi 的输入和输出格式 - - 默认值:False - -:code:`--health_monitor` - 是否开启健康检查,健康检查会不断检查服务器的健康状况,并在出错时进行重启 - - 默认值:False - -:code:`--metric_gateway` - 对指标进行监控的IP地址 - - -:code:`--job_name` - 监视器的作业名称 - - 默认值:“lightllm” - -:code:`--grouping_key` - 监视器的 grouping_key,格式为 key=value - - 默认值:[] - -:code:`--push_interval` - 以秒为单位的推送监控指标的时间间隔 - - 默认值:10 - -:code:`--enable_monitor_auth` - 是否开启push_gateway的认证 - - 默认值:False \ No newline at end of file +.. argparse:: + :module: lightllm.server.api_server + :func: make_argument_parser + :prog: python -m lightllm.server.api_server + :nodefaultconst: diff --git a/docs/EN/source/getting_started/installation.rst b/docs/EN/source/getting_started/installation.rst index 03af4ac98..41efcd5a2 100755 --- a/docs/EN/source/getting_started/installation.rst +++ b/docs/EN/source/getting_started/installation.rst @@ -71,7 +71,9 @@ You can also install Lightllm from source: $ cd lightllm $ $ # Install Lightllm's dependencies - $ pip install -r requirements.txt + $ pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 + $ + $ pip install nvidia-nccl-cu12==2.20.5 $ $ # Install Lightllm $ python setup.py install @@ -79,11 +81,11 @@ You can also install Lightllm from source: .. note:: The Lightllm code has been tested on various GPUs, including V100, A100, A800, 4090, and H800. - If you are using A100, A800, or similar GPUs, it is recommended to install triton==2.1.0: + If you are using A100, A800, or similar GPUs, it is recommended to install triton==3.0.0: .. 
code-block:: console - $ pip install triton==2.1.0 --no-deps + $ pip install triton==3.0.0 --no-deps If you are using H800, V100, or similar GPUs, it is recommended to install triton-nightly: diff --git a/docs/EN/source/getting_started/quickstart.rst b/docs/EN/source/getting_started/quickstart.rst index 4233d4287..78410fd31 100755 --- a/docs/EN/source/getting_started/quickstart.rst +++ b/docs/EN/source/getting_started/quickstart.rst @@ -15,7 +15,7 @@ Deploying a model with Lightllm is very straightforward and requires only two st 1. Prepare the Model File ------------------------- -The following content will demonstrate Lightllm's support for large language models using `Qwen2-0.5B `_. You can refer to the article: `How to Quickly Download Hugging Face Models — A Summary of Methods `_ for methods to download models. +The following content will demonstrate Lightllm's support for large language models using `Llama-2-7b-chat `_. You can refer to the article: `How to Quickly Download Hugging Face Models — A Summary of Methods `_ for methods to download models. Here is an example of how to download the model: @@ -35,7 +35,7 @@ Here is an example of how to download the model: .. code-block:: console - $ huggingface-cli download Qwen/Qwen2-0.5B --local-dir Qwen2-0.5 + $ huggingface-cli download meta-llama/Llama-2-7b-chat --local-dir Llama-2-7b-chat .. tip:: The above code for downloading the model requires a stable internet connection and may take some time. You can use alternative download methods or other supported models as substitutes. For the latest list of supported models, please refer to the `project homepage `_. @@ -44,20 +44,14 @@ Here is an example of how to download the model: 2. Start the Model Service --------------------------- -After downloading the Qwen2-0.5B model, use the following command in the terminal to deploy the API service: +After downloading the Llama-2-7b-chat model, use the following command in the terminal to deploy the API service: .. code-block:: console - $ python -m lightllm.server.api_server --model_dir ~/models/Qwen2-0.5B \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 1 \ - $ --max_total_token_num 120000 \ - $ --trust_remote_code \ - $ --eos_id 151643 + $ python -m lightllm.server.api_server --model_dir ~/models/Llama-2-7b-chat .. note:: - The ``--model_dir`` parameter in the above command should be changed to the actual path of your model on your machine. The ``--eos_id 151643`` parameter is specific to the Qwen model; remove this parameter for other models. + The ``--model_dir`` parameter in the above command should be changed to the actual path of your model on your machine. 3. (Optional) Test the Model Service -------------------------------------- @@ -66,7 +60,7 @@ In a new terminal, use the following command to test the model service: .. code-block:: console - $ curl http://localhost:8080/generate \ + $ curl http://localhost:8000/generate \ $ -H "Content-Type: application/json" \ $ -d '{ $ "inputs": "What is AI?", diff --git a/docs/EN/source/models/test.rst b/docs/EN/source/models/test.rst index 661518367..8bae9901b 100755 --- a/docs/EN/source/models/test.rst +++ b/docs/EN/source/models/test.rst @@ -9,19 +9,14 @@ Qwen2-0.5B .. code-block:: console $ python -m lightllm.server.api_server --model_dir ~/models/Qwen2-0.5B \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 1 \ - $ --max_total_token_num 120000 \ - $ --trust_remote_code \ - $ --eos_id 151643 + $ --trust_remote_code **Test Server** .. 
code-block:: console - $ curl http://localhost:8080/generate \ + $ curl http://localhost:8000/generate \ $ -H "Content-Type: application/json" \ $ -d '{ $ "inputs": "What is AI?", @@ -40,12 +35,8 @@ Qwen-VL-Chat .. code-block:: console $ python -m lightllm.server.api_server --model_dir ~/models/Qwen-VL-Chat \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 1 \ - $ --max_total_token_num 120000 \ $ --trust_remote_code \ - $ --enable_multimodal + $ --enable_multimodal **Test Server** @@ -79,7 +70,7 @@ Qwen-VL-Chat } } - url = "http://127.0.0.1:8080/generate" + url = "http://127.0.0.1:8000/generate" headers = {'Content-Type': 'application/json'} response = requests.post(url, headers=headers, data=json.dumps(data)) return response @@ -114,11 +105,7 @@ llama2-70b-chat .. code-block:: console - $ python -m lightllm.server.api_server --model_dir ~/models/llama2-70b-chat \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 4 \ - $ --max_total_token_num 120000 + $ python -m lightllm.server.api_server --model_dir ~/models/llama2-70b-chat --tp 4 .. tip:: @@ -128,7 +115,7 @@ llama2-70b-chat .. code-block:: console - $ curl http://localhost:8080/generate \ + $ curl http://localhost:8000/generate \ $ -H "Content-Type: application/json" \ $ -d '{ $ "inputs": "What is LLM?", @@ -147,10 +134,6 @@ internlm2-1_8b .. code-block:: console $ python -m lightllm.server.api_server --model_dir ~/models/internlm2-1_8b \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 1 \ - $ --max_total_token_num 120000 \ $ --splitfuse_mode \ $ --trust_remote_code @@ -163,7 +146,7 @@ internlm2-1_8b .. code-block:: console - $ curl http://localhost:8080/generate \ + $ curl http://localhost:8000/generate \ $ -H "Content-Type: application/json" \ $ -d '{ $ "inputs": "What is LLM?", @@ -182,10 +165,6 @@ internlm2-1_8b-reward .. code-block:: console $ python -m lightllm.server.api_server --model_dir ~/models/internlm2-1_8b-reward \ - $ --host 0.0.0.0 \ - $ --port 8080 \ - $ --tp 1 \ - $ --max_total_token_num 120000 \ $ --use_reward_model \ $ --trust_remote_code @@ -203,7 +182,7 @@ internlm2-1_8b-reward query = "<|im_start|>user\nHello! What's your name?<|im_end|>\n<|im_start|>assistant\nMy name is InternLM2! A helpful AI assistant. What can I do for you?<|im_end|>\n<|reward|>" - url = "http://127.0.0.1:8080/get_score" + url = "http://127.0.0.1:8000/get_score" headers = {'Content-Type': 'application/json'} data = { diff --git a/docs/EN/source/server/api_server_args.rst b/docs/EN/source/server/api_server_args.rst index f96f729b3..0fbe35b99 100755 --- a/docs/EN/source/server/api_server_args.rst +++ b/docs/EN/source/server/api_server_args.rst @@ -5,241 +5,8 @@ APIServer Args Usage ++++++++++++ -.. 
code-block:: console - - python -m lightllm.server.api_server [-h] [--host HOST] [--port PORT] [--model_dir MODEL_DIR] - [--tokenizer_mode TOKENIZER_MODE] [--load_way LOAD_WAY] - [--max_total_token_num MAX_TOTAL_TOKEN_NUM] - [--batch_max_tokens BATCH_MAX_TOKENS] [--eos_id EOS_ID [EOS_ID ...]] - [--running_max_req_size RUNNING_MAX_REQ_SIZE] [--tp TP] - [--max_req_input_len MAX_REQ_INPUT_LEN] - [--max_req_total_len MAX_REQ_TOTAL_LEN] [--nccl_port NCCL_PORT] - [--mode MODE [MODE ...]] [--trust_remote_code] [--disable_log_stats] - [--log_stats_interval LOG_STATS_INTERVAL] - [--router_token_ratio ROUTER_TOKEN_RATIO] - [--router_max_new_token_len ROUTER_MAX_NEW_TOKEN_LEN] - [--router_max_wait_tokens ROUTER_MAX_WAIT_TOKENS] - [--use_dynamic_prompt_cache] - [--splitfuse_block_size SPLITFUSE_BLOCK_SIZE] [--splitfuse_mode] - [--beam_mode] [--diverse_mode] [--token_healing_mode] - [--enable_multimodal] [--cache_capacity CACHE_CAPACITY] - [--cache_reserved_ratio CACHE_RESERVED_RATIO] - [--data_type {fp16,float16,bf16,bfloat16,fp32,float32}] - [--return_all_prompt_logprobs] [--use_reward_model] - [--long_truncation_mode {None,head,center}] [--use_tgi_api] - [--health_monitor] [--metric_gateway METRIC_GATEWAY] - [--job_name JOB_NAME] [--grouping_key GROUPING_KEY] - [--push_interval PUSH_INTERVAL] [--enable_monitor_auth] - -Arguments -++++++++++++++ - -:code:`--host` - Service IP address. - - Default : 127.0.0.1 - -:code:`--port` - Service port. - - Default : 8000 - -:code:`--model_dir` - The model weight dir path, the app will load config, weights and tokenizer from this dir. - -:code:`--tokenizer_mode` - Tokenizer load mode, can be slow, fast or auto, slow mode load fast but run slow, slow mode is good for debug and test, fast mode get best performance, auto mode will try to use fast mode, if failed will use slow mode. 
- - Default : slow - -:code:`--load_way` - the way of loading model weights, the default is HF(Huggingface format), llama also supports DS(Deepspeed) - - Default : HF - -:code:`--max_total_token_num` - the total token nums the gpu and model can support, equals = max_batch * (input_len + output_len) - - Default : 6000 - -:code:`--batch_max_tokens` - max tokens num for new cat batch, it control prefill batch size to Preventing OOM - - Default : None - -:code:`--eos_id` - eos stop token id - - Default : [2] - -:code:`--running_max_req_size` - the max size for forward requests in the same time - - Default : 1000 - -:code:`--tp` - model tp parral size, the default is 1 - - Default : 1 - -:code:`--max_req_input_len` - the max value for req input tokens num - - Default : 2048 - -:code:`--max_req_total_len` - the max value for req_input_len + req_output_len - - Default : 2048 + 1024 - -:code:`--nccl_port` - the nccl_port to build a distributed environment for PyTorch - - Default : 28765 - -:code:`--mode` - Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding| triton_gqa_attention | triton_gqa_flashdecoding| triton_w4a16 | triton_w8a16 | triton_w8a8 | lmdeploy_w4a16| ppl_w4a16 | ppl_w8a8 | ppl_w8a8_mixdown], - - * triton_flashdecoding mode is for long context, current support llama llama2 qwen; - * triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; - * triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; - * ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; - * ppl_fp16 mode use ppl fast fp16 decode attention kernel; - * triton_int8weight and triton_int4weight and lmdeploy_int4weight or ppl_int4weight mode - * use int8 and int4 to store weights; - * you need to read source code to make sure the supported detail mode for all models - - Default : [] - -:code:`--trust_remote_code` - Whether or not to allow for custom models defined on the Hub in their own modeling files. - - Default : False - -:code:`--disable_log_stats` - disable logging throughput stats. - - Default : False - -:code:`--log_stats_interval` - log stats interval in second. - - Default : 10 - -:code:`--router_token_ratio` - token ratio to control router dispatch - - Default : 0.0 - -:code:`--router_max_new_token_len` - the request max new token len for router - - Default : 1024 - -:code:`--router_max_wait_tokens` - schedule new requests after every router_max_wait_tokens decode steps. - - Default : 10 - -:code:`--use_dynamic_prompt_cache` - use_dynamic_prompt_cache test - - Default : False - -:code:`--splitfuse_block_size` - splitfuse block size - - Default : 256 - -:code:`--splitfuse_mode` - use ``splitfuse`` mode - - Default : False - -:code:`--beam_mode` - use ``beamsearch`` mode - - Default : False - -:code:`--diverse_mode` - use ``diversity generation`` mode - - Default : False - -:code:`--token_healing_mode` - use ``code model infer`` mode - - Default : False - -:code:`--enable_multimodal` - Whether or not to allow to load additional multimodal models. 
- - Default : False - -:code:`--cache_capacity` - cache server capacity for multimodal resources - - Default : 200 - -:code:`--cache_reserved_ratio` - cache server reserved capacity ratio after clear - - Default : 0.5 - -:code:`--data_type` - the data type of the model weight, choices : fp16, float16, bf16, bfloat16, fp32, float32 - - Default : “float16” - -:code:`--return_all_prompt_logprobs` - return_all_prompt_logprobs - - Default : False - -:code:`--use_reward_model` - use reward model. - - Default : False - -:code:`--long_truncation_mode` - use to select the handle way when input token len > max_req_input_len. - - * None : raise Exception - * head : remove some head tokens to make input token len <= max_req_input_len - * center : remove some tokens in center loc to make input token len <= max_req_input_len - - Default : None - -:code:`--use_tgi_api` - use tgi input and ouput format - - Default : False - -:code:`--health_monitor` - check the health of service and restart when error - - Default : False - -:code:`--metric_gateway` - address for collecting monitoring metrics - - -:code:`--job_name` - job name for monitor - - Default : “lightllm” - -:code:`--grouping_key` - grouping_key for the monitor in the form key=value - - Default : [] - -:code:`--push_interval` - interval of pushing monitoring metrics - - Default : 10 - -:code:`--enable_monitor_auth` - Whether to open authentication for push_gateway - - Default : False \ No newline at end of file +.. argparse:: + :module: lightllm.server.api_server + :func: make_argument_parser + :prog: python -m lightllm.server.api_server + :nodefaultconst: \ No newline at end of file From e7184fc3ed5b74717ff9cbfc129b8dc56edb484b Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Wed, 30 Oct 2024 10:24:32 +0800 Subject: [PATCH 2/6] Allowed token ids (#598) --- .../server/router/model_infer/infer_batch.py | 10 ++++++++++ .../impl_for_simple_constraint_mode.py | 4 +++- lightllm/server/sampling_params.py | 18 ++++++++++++++++++ test/test_constraint_server.py | 10 ++++++++++ 4 files changed, 41 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index c45e14019..1dd954fb7 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -39,6 +39,7 @@ def __init__( stop_sequences: List[List[int]] = [], input_penalty: bool = False, regular_constraint: Optional[str] = None, + allowed_token_ids: Optional[List[int]] = None, ) -> None: self.best_of = best_of self.do_sample = do_sample @@ -60,8 +61,17 @@ def __init__( self.regular_constraint = regular_constraint self.regex_guide = None self.fsm_current_state: int = 0 + self.allowed_token_ids = allowed_token_ids + # this check is not very good to placed here. to do... 
+ if self.allowed_token_ids is not None: + if not all(e < vocab_size for e in self.allowed_token_ids): + logger.error("allowed_token_ids contain tokenid >= vobsize, we remove these token ids") + self.allowed_token_ids = [e for e in self.allowed_token_ids if e < vocab_size] return + def has_constraint_setting(self) -> bool: + return self.regular_constraint is not None or self.allowed_token_ids is not None + class InferReq: def __init__( diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_simple_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_simple_constraint_mode.py index 660f0bf75..dfa6d9295 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_simple_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_simple_constraint_mode.py @@ -92,7 +92,7 @@ def decode_batch(self, batch_id): logits = self.model.forward(**kwargs) - all_has_no_constraint = all([e.sampling_param.regular_constraint is None for e in run_reqs]) + all_has_no_constraint = all([not e.sampling_param.has_constraint_setting() for e in run_reqs]) if not all_has_no_constraint: mask = torch.ones_like(logits, dtype=torch.bool) for i, run_obj in enumerate(run_reqs): @@ -146,6 +146,8 @@ def _mask_req_out_token(self, i, run_obj: InferReq, mask): regex_guide: RegexGuide = sample_params.regex_guide ok_token_id_list = regex_guide.get_next_instruction(sample_params.fsm_current_state).tokens mask[i, ok_token_id_list] = False + elif sample_params.allowed_token_ids is not None: + mask[i, sample_params.allowed_token_ids] = False else: mask[i, :] = False return diff --git a/lightllm/server/sampling_params.py b/lightllm/server/sampling_params.py index 35a856b6e..be7c538f5 100644 --- a/lightllm/server/sampling_params.py +++ b/lightllm/server/sampling_params.py @@ -32,6 +32,10 @@ def __init__( # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty input_penalty: bool = DEFAULT_INPUT_PENALTY, regular_constraint: Optional[str] = None, # Regular expressions constrain the output. + # If provided, the engine will construct a logits, + # processor which only retains scores for the given token ids. Defaults to None. + # allowed_token_ids only can be used in "--simple_constraint_mode" started server. 
+ allowed_token_ids: Optional[List[int]] = None, ) -> None: self.best_of = best_of self.n = n @@ -51,6 +55,7 @@ def __init__( self.add_spaces_between_special_tokens = add_spaces_between_special_tokens self.print_eos_token = print_eos_token self.regular_constraint = regular_constraint + self.allowed_token_ids = allowed_token_ids if self.do_sample is False: self.temperature = 1.0 self.top_p = 1.0 @@ -131,6 +136,18 @@ def verify(self): self._verify_stop_sentences() + self._verify_allowed_token_ids() + + return + + def _verify_allowed_token_ids(self): + if self.allowed_token_ids is not None: + if (not isinstance(self.allowed_token_ids, list)) or ( + not all(isinstance(token_id, int) for token_id in self.allowed_token_ids) + ): + raise ValueError(f"allowed_token_ids need format List[int], but get {self.allowed_token_ids}") + if self.regular_constraint is not None: + raise ValueError("allowed_token_ids and regular_constraint can not be used in same time") return def _verify_stop_sentences(self): @@ -187,4 +204,5 @@ def to_dict(self): ret["best_of"] = self.best_of ret["input_penalty"] = self.input_penalty ret["regular_constraint"] = self.regular_constraint + ret["allowed_token_ids"] = self.allowed_token_ids return ret diff --git a/test/test_constraint_server.py b/test/test_constraint_server.py index 3e4037a0c..62b622031 100644 --- a/test/test_constraint_server.py +++ b/test/test_constraint_server.py @@ -55,3 +55,13 @@ def run(self): } thread = RequestThread(url, headers, data) thread.start() + +time.sleep(10) + +for i in range(20): + data = { + "inputs": "Are dog a man? ", + "parameters": {"do_sample": False, "ignore_eos": True, "max_new_tokens": 200, "allowed_token_ids": [2, 3]}, + } + thread = RequestThread(url, headers, data) + thread.start() From e58aa74809922d3d52ae9365acb4628951ffbd9c Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Thu, 31 Oct 2024 18:41:57 +0800 Subject: [PATCH 3/6] add first_token_constraint_mode (#599) Co-authored-by: wangzaijun --- lightllm/server/api_server.py | 8 ++- lightllm/server/router/manager.py | 1 + .../model_infer/mode_backend/__init__.py | 1 + .../impl_for_first_token_constraint_mode.py | 72 +++++++++++++++++++ .../server/router/model_infer/model_rpc.py | 4 ++ lightllm/server/router/req_queue/__init__.py | 2 + 6 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_first_token_constraint_mode.py diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 00bfd0478..f72ef1dc3 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -435,7 +435,12 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode") - + parser.add_argument( + "--first_token_constraint_mode", + action="store_true", + help="""constraint the first token allowed range, + use env FIRST_ALLOWED_TOKENS to set the range, like FIRST_ALLOWED_TOKENS=1,2 ..""", + ) parser.add_argument( "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models." 
) @@ -546,6 +551,7 @@ def main(): args.token_healing_mode, args.use_reward_model, args.return_all_prompt_logprobs, + args.first_token_constraint_mode, ].count(True) <= 1 # 部分模式目前还无法与dynamic_prompt_cache一起跑,to do。 if args.use_dynamic_prompt_cache: diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 58546b27f..c7ef6dea1 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -91,6 +91,7 @@ async def wait_to_model_ready(self): "max_req_num": self.args.running_max_req_size + 8, "max_seq_length": self.args.max_req_total_len + 8, # 留一点余量 "nccl_port": self.args.nccl_port, + "is_first_token_constraint_mode": self.args.first_token_constraint_mode, "is_splitfuse_mode": self.is_splitfuse_mode, "splitfuse_block_size": self.splitfuse_block_size, "is_token_healing": self.args.token_healing_mode, diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index f54a333cf..d35c459e6 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -6,3 +6,4 @@ from .diverse_backend.impl import DiversehBackend from .continues_batch.impl_for_token_healing import TokenHealingBackend from .continues_batch.impl_for_simple_constraint_mode import SimpleConstraintBackend +from .continues_batch.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_first_token_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_first_token_constraint_mode.py new file mode 100644 index 000000000..8de334530 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_first_token_constraint_mode.py @@ -0,0 +1,72 @@ +import os +import shutil +import torch +from .impl import ContinuesBatchBackend +from lightllm.server.io_struct import FinishStatus +from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams +from .pre_process import prepare_prefill_inputs, prepare_decode_inputs +from .post_process import sample +from typing import List +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class FirstTokenConstraintBackend(ContinuesBatchBackend): + def __init__(self) -> None: + super().__init__() + + def init_custom(self): + first_allowed_tokens_strs: str = os.environ.get("FIRST_ALLOWED_TOKENS", None) + logger.info(f"first_allowed_tokens_strs : {first_allowed_tokens_strs}") + # 使用该模式需要设置FIRST_ALLOWED_TOKENS 环境变量,格式为 "1,2" 或 "1,2,3" 等数字字符串 + assert first_allowed_tokens_strs is not None + first_allowed_tokens_strs.split(",") + self.first_allowed_tokens = [int(e.strip()) for e in first_allowed_tokens_strs.split(",") if len(e.strip()) > 0] + logger.info(f"first_allowed_tokens : {self.first_allowed_tokens}") + # check token_id < vocab_size + assert all(e < self.model.vocab_size for e in self.first_allowed_tokens) + return + + def forward(self, batch_id, is_prefill): + output_dict = {} + batch: InferBatch = self.cache.pop(batch_id) + if is_prefill: + kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.is_multimodal) + else: + kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache) + + logits = self.model.forward(**kwargs) + # first token constraint + if is_prefill: + mask = torch.ones_like(logits, dtype=torch.bool) + mask[:, 
self.first_allowed_tokens] = False + logits[mask] = -1000000.0 + + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + + for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): + # prefill and decode is same + req_obj: InferReq = req_obj + req_obj.cur_kv_len = len(req_obj.input_token_ids) + req_obj.input_token_ids.append(next_token_id) + req_obj.out_token_id_count[next_token_id] += 1 + req_obj.update_finish_status(self.eos_id) + + metadata = { + "id": int(next_token_id), + "logprob": float(next_token_logprob), + } + output_dict[req_obj.r_id] = ( + req_obj.req_status, + req_obj.cur_kv_len, + req_obj.get_output_len(), + [(int(next_token_id), metadata)], + req_obj.finish_status.value, # 转化为整数,避免传送大对象, + None, + ) # 请求状态, 当前占用的kv的长度, 当前输出token的数量, 输出的token的id和元信息列表, 是否推理结束的状态, 额外保留参数 + + self.cache[batch.batch_id] = batch + return output_dict diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 2e56b9317..2f6cc55c2 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -12,6 +12,7 @@ RewardModelBackend, TokenHealingBackend, SimpleConstraintBackend, + FirstTokenConstraintBackend, ) from lightllm.utils.log_utils import init_logger @@ -31,6 +32,7 @@ def exposed_init_model(self, kvargs): beam_mode = kvargs.get("beam_mode", False) diverse_mode = kvargs.get("diverse_mode", False) is_token_healing = kvargs.get("is_token_healing", False) + is_first_token_constraint_mode = kvargs.get("is_first_token_constraint_mode", False) if kvargs.get("args", None) is not None: is_simple_constraint_mode = kvargs.get("args", None).simple_constraint_mode else: @@ -51,6 +53,8 @@ def exposed_init_model(self, kvargs): self.backend = TokenHealingBackend() elif is_simple_constraint_mode: self.backend = SimpleConstraintBackend() + elif is_first_token_constraint_mode: + self.backend = FirstTokenConstraintBackend() else: self.backend = ContinuesBatchBackend() diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index 87b301da2..dd2e1e3a8 100644 --- a/lightllm/server/router/req_queue/__init__.py +++ b/lightllm/server/router/req_queue/__init__.py @@ -14,5 +14,7 @@ def build_req_queue(args, router): return ContinuesBatchQueue(args, router) if args.simple_constraint_mode: return ContinuesBatchQueue(args, router) + if args.first_token_constraint_mode: + return ContinuesBatchQueue(args, router) return ContinuesBatchQueue(args, router) From e5ad711c6e496449729e3fdb5144d133d9b895ef Mon Sep 17 00:00:00 2001 From: sufubao <47234901+sufubao@users.noreply.github.com> Date: Fri, 1 Nov 2024 17:47:45 +0800 Subject: [PATCH 4/6] [docs] Update docs. 
(#602) --- docs/CN/README.md | 10 + docs/CN/requirements-docs.txt | 5 +- docs/CN/source/server/api_server_args_zh.rst | 2 +- docs/EN/README.md | 10 + docs/EN/requirements-docs.txt | 4 +- docs/EN/source/server/api_server_args.rst | 2 +- lightllm/server/api_cli.py | 196 ++++++++++++++++++ lightllm/server/api_server.py | 197 +------------------ 8 files changed, 221 insertions(+), 205 deletions(-) create mode 100644 lightllm/server/api_cli.py diff --git a/docs/CN/README.md b/docs/CN/README.md index a513813ee..a4375a7df 100755 --- a/docs/CN/README.md +++ b/docs/CN/README.md @@ -1,7 +1,17 @@ ## Build the docs +```bash +# Install lightllm + +# git clone https://github.com/ModelTC/lightllm.git +# cd lightllm +pip install --no-deps . +``` + ```bash # Install dependencies. + +# cd docs/CN pip install -r requirements-docs.txt # Build the docs. diff --git a/docs/CN/requirements-docs.txt b/docs/CN/requirements-docs.txt index fd353450f..f123d2ffd 100755 --- a/docs/CN/requirements-docs.txt +++ b/docs/CN/requirements-docs.txt @@ -8,8 +8,5 @@ sphinxcontrib.openapi # packages to install to build the documentation pydantic --f https://download.pytorch.org/whl/cpu -torch -py-cpuinfo -transformers openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args +numpy \ No newline at end of file diff --git a/docs/CN/source/server/api_server_args_zh.rst b/docs/CN/source/server/api_server_args_zh.rst index af68bddef..41a9910ad 100755 --- a/docs/CN/source/server/api_server_args_zh.rst +++ b/docs/CN/source/server/api_server_args_zh.rst @@ -6,7 +6,7 @@ APIServer 参数详解 ++++++++++++ .. argparse:: - :module: lightllm.server.api_server + :module: lightllm.server.api_cli :func: make_argument_parser :prog: python -m lightllm.server.api_server :nodefaultconst: diff --git a/docs/EN/README.md b/docs/EN/README.md index a513813ee..de708d72a 100755 --- a/docs/EN/README.md +++ b/docs/EN/README.md @@ -1,7 +1,17 @@ ## Build the docs +```bash +# Install lightllm + +# git clone https://github.com/ModelTC/lightllm.git +# cd lightllm +pip install --no-deps . +``` + ```bash # Install dependencies. + +# cd docs/EN pip install -r requirements-docs.txt # Build the docs. diff --git a/docs/EN/requirements-docs.txt b/docs/EN/requirements-docs.txt index fd353450f..713b9e716 100755 --- a/docs/EN/requirements-docs.txt +++ b/docs/EN/requirements-docs.txt @@ -9,7 +9,5 @@ sphinxcontrib.openapi # packages to install to build the documentation pydantic -f https://download.pytorch.org/whl/cpu -torch -py-cpuinfo -transformers openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args +numpy \ No newline at end of file diff --git a/docs/EN/source/server/api_server_args.rst b/docs/EN/source/server/api_server_args.rst index 0fbe35b99..98c1c61bb 100755 --- a/docs/EN/source/server/api_server_args.rst +++ b/docs/EN/source/server/api_server_args.rst @@ -6,7 +6,7 @@ Usage ++++++++++++ .. 
argparse:: - :module: lightllm.server.api_server + :module: lightllm.server.api_cli :func: make_argument_parser :prog: python -m lightllm.server.api_server :nodefaultconst: \ No newline at end of file diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py new file mode 100644 index 000000000..bf0daa12a --- /dev/null +++ b/lightllm/server/api_cli.py @@ -0,0 +1,196 @@ +import argparse + + +def make_argument_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + + parser.add_argument( + "--model_name", + type=str, + default="default_model_name", + help="just help to distinguish internal model name, use 'host:port/get_model_name' to get", + ) + + parser.add_argument( + "--model_dir", + type=str, + default=None, + help="the model weight dir path, the app will load config, weights and tokenizer from this dir", + ) + parser.add_argument( + "--tokenizer_mode", + type=str, + default="slow", + help="""tokenizer load mode, can be slow, fast or auto, slow mode load fast but run slow, + slow mode is good for debug and test, fast mode get best performance, auto mode will + try to use fast mode, if failed will use slow mode""", + ) + parser.add_argument( + "--load_way", + type=str, + default="HF", + help="""the way of loading model weights, the default is HF(Huggingface format), llama also supports + DS(Deepspeed)""", + ) + parser.add_argument( + "--max_total_token_num", + type=int, + default=None, + help="the total token nums the gpu and model can support, equals = max_batch * (input_len + output_len)", + ) + parser.add_argument( + "--mem_fraction", + type=float, + default=0.9, + help="""Memory usage ratio, default is 0.9, you can specify a smaller value if OOM occurs at runtime. 
+ If max_total_token_num is not specified, it will be calculated automatically based on this value.""", + ) + parser.add_argument( + "--batch_max_tokens", + type=int, + default=None, + help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM", + ) + parser.add_argument( + "--eos_id", nargs="+", type=int, default=None, help="eos stop token id, if None, will load from config.json" + ) + parser.add_argument( + "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time" + ) + parser.add_argument("--tp", type=int, default=1, help="model tp parral size, the default is 1") + parser.add_argument("--max_req_input_len", type=int, default=2048, help="the max value for req input tokens num") + parser.add_argument( + "--max_req_total_len", type=int, default=2048 + 1024, help="the max value for req_input_len + req_output_len" + ) + parser.add_argument( + "--nccl_port", type=int, default=28765, help="the nccl_port to build a distributed environment for PyTorch" + ) + parser.add_argument( + "--mode", + type=str, + default=[], + nargs="+", + help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding + | triton_gqa_attention | triton_gqa_flashdecoding + | triton_w4a16 | triton_w8a16 | triton_w8a8 | lmdeploy_w4a16 + | ppl_w4a16 | ppl_w8a8 | ppl_w8a8_mixdown], + triton_flashdecoding mode is for long context, current support llama llama2 qwen; + triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; + triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; + ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; + ppl_fp16 mode use ppl fast fp16 decode attention kernel; + triton_int8weight and triton_int4weight and lmdeploy_int4weight or ppl_int4weight mode + use int8 and int4 to store weights; + you need to read source code to make sure the supported detail mode for all models""", + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", + ) + parser.add_argument("--disable_log_stats", action="store_true", help="disable logging throughput stats.") + parser.add_argument("--log_stats_interval", type=int, default=10, help="log stats interval in second.") + + parser.add_argument("--router_token_ratio", type=float, default=0.0, help="token ratio to control router dispatch") + parser.add_argument( + "--router_max_new_token_len", type=int, default=1024, help="the request max new token len for router" + ) + + parser.add_argument( + "--router_max_wait_tokens", + type=int, + default=10, + help="schedule new requests after every router_max_wait_tokens decode steps.", + ) + + parser.add_argument("--use_dynamic_prompt_cache", action="store_true", help="use_dynamic_prompt_cache test") + + parser.add_argument("--splitfuse_block_size", type=int, default=256, help="splitfuse block size") + + parser.add_argument("--splitfuse_mode", action="store_true", help="use splitfuse mode") + parser.add_argument("--beam_mode", action="store_true", help="use beamsearch mode") + parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") + parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") + parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode") + parser.add_argument( + "--first_token_constraint_mode", + 
action="store_true", + help="""constraint the first token allowed range, + use env FIRST_ALLOWED_TOKENS to set the range, like FIRST_ALLOWED_TOKENS=1,2 ..""", + ) + parser.add_argument( + "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models." + ) + parser.add_argument( + "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources" + ) + parser.add_argument( + "--cache_reserved_ratio", type=float, default=0.5, help="cache server reserved capacity ratio after clear" + ) + parser.add_argument( + "--data_type", + type=str, + choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"], + default=None, + help="the data type of the model weight", + ) + parser.add_argument("--return_all_prompt_logprobs", action="store_true", help="return all prompt tokens logprobs") + + parser.add_argument("--use_reward_model", action="store_true", help="use reward model") + + parser.add_argument( + "--long_truncation_mode", + type=str, + choices=[None, "head", "center"], + default=None, + help="""use to select the handle way when input token len > max_req_input_len. + None : raise Exception + head : remove some head tokens to make input token len <= max_req_input_len + center : remove some tokens in center loc to make input token len <= max_req_input_len""", + ) + parser.add_argument("--use_tgi_api", action="store_true", help="use tgi input and ouput format") + parser.add_argument( + "--health_monitor", action="store_true", help="check the health of service and restart when error" + ) + parser.add_argument("--metric_gateway", type=str, default=None, help="address for collecting monitoring metrics") + parser.add_argument("--job_name", type=str, default="lightllm", help="job name for monitor") + parser.add_argument( + "--grouping_key", action="append", default=[], help="grouping_key for the monitor in the form key=value" + ) + parser.add_argument("--push_interval", type=int, default=10, help="interval of pushing monitoring metrics") + parser.add_argument( + "--visual_infer_batch_size", type=int, default=4, help="number of images to process in each inference batch" + ) + parser.add_argument( + "--visual_gpu_ids", nargs="+", type=int, default=[0], help="List of GPU IDs to use, e.g., 0 1 2" + ) + parser.add_argument("--visual_tp", type=int, default=1, help="number of tensort parallel instances for ViT") + parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT") + parser.add_argument( + "--visual_nccl_ports", + nargs="+", + type=int, + default=[29500], + help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502", + ) + parser.add_argument( + "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway" + ) + parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage") + parser.add_argument( + "--graph_max_batch_size", + type=int, + default=16, + help="""Maximum batch size that can be captured by the cuda graph for decodign stage. + The default value is 8. It will turn into eagar mode if encounters a larger value.""", + ) + parser.add_argument( + "--graph_max_len_in_batch", + type=int, + default=8192, + help="""Maximum sequence length that can be captured by the cuda graph for decodign stage. + The default value is 8192. It will turn into eagar mode if encounters a larger value. 
""", + ) + return parser diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index f72ef1dc3..3018d0799 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -27,7 +27,6 @@ from .build_prompt import build_prompt asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -import argparse import json from http import HTTPStatus import uuid @@ -37,6 +36,7 @@ from fastapi import BackgroundTasks, FastAPI, Request from fastapi.responses import Response, StreamingResponse, JSONResponse import uvicorn +from .api_cli import make_argument_parser from .sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager @@ -322,201 +322,6 @@ async def shutdown(): return -def make_argument_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="127.0.0.1") - parser.add_argument("--port", type=int, default=8000) - - parser.add_argument( - "--model_name", - type=str, - default="default_model_name", - help="just help to distinguish internal model name, use 'host:port/get_model_name' to get", - ) - - parser.add_argument( - "--model_dir", - type=str, - default=None, - help="the model weight dir path, the app will load config, weights and tokenizer from this dir", - ) - parser.add_argument( - "--tokenizer_mode", - type=str, - default="slow", - help="""tokenizer load mode, can be slow, fast or auto, slow mode load fast but run slow, - slow mode is good for debug and test, fast mode get best performance, auto mode will - try to use fast mode, if failed will use slow mode""", - ) - parser.add_argument( - "--load_way", - type=str, - default="HF", - help="""the way of loading model weights, the default is HF(Huggingface format), llama also supports - DS(Deepspeed)""", - ) - parser.add_argument( - "--max_total_token_num", - type=int, - default=None, - help="the total token nums the gpu and model can support, equals = max_batch * (input_len + output_len)", - ) - parser.add_argument( - "--mem_fraction", - type=float, - default=0.9, - help="""Memory usage ratio, default is 0.9, you can specify a smaller value if OOM occurs at runtime. 
- If max_total_token_num is not specified, it will be calculated automatically based on this value.""", - ) - parser.add_argument( - "--batch_max_tokens", - type=int, - default=None, - help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM", - ) - parser.add_argument( - "--eos_id", nargs="+", type=int, default=None, help="eos stop token id, if None, will load from config.json" - ) - parser.add_argument( - "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time" - ) - parser.add_argument("--tp", type=int, default=1, help="model tp parral size, the default is 1") - parser.add_argument("--max_req_input_len", type=int, default=2048, help="the max value for req input tokens num") - parser.add_argument( - "--max_req_total_len", type=int, default=2048 + 1024, help="the max value for req_input_len + req_output_len" - ) - parser.add_argument( - "--nccl_port", type=int, default=28765, help="the nccl_port to build a distributed environment for PyTorch" - ) - parser.add_argument( - "--mode", - type=str, - default=[], - nargs="+", - help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding - | triton_gqa_attention | triton_gqa_flashdecoding - | triton_w4a16 | triton_w8a16 | triton_w8a8 | lmdeploy_w4a16 - | ppl_w4a16 | ppl_w8a8 | ppl_w8a8_mixdown], - triton_flashdecoding mode is for long context, current support llama llama2 qwen; - triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; - triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; - ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; - ppl_fp16 mode use ppl fast fp16 decode attention kernel; - triton_int8weight and triton_int4weight and lmdeploy_int4weight or ppl_int4weight mode - use int8 and int4 to store weights; - you need to read source code to make sure the supported detail mode for all models""", - ) - parser.add_argument( - "--trust_remote_code", - action="store_true", - help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", - ) - parser.add_argument("--disable_log_stats", action="store_true", help="disable logging throughput stats.") - parser.add_argument("--log_stats_interval", type=int, default=10, help="log stats interval in second.") - - parser.add_argument("--router_token_ratio", type=float, default=0.0, help="token ratio to control router dispatch") - parser.add_argument( - "--router_max_new_token_len", type=int, default=1024, help="the request max new token len for router" - ) - - parser.add_argument( - "--router_max_wait_tokens", - type=int, - default=10, - help="schedule new requests after every router_max_wait_tokens decode steps.", - ) - - parser.add_argument("--use_dynamic_prompt_cache", action="store_true", help="use_dynamic_prompt_cache test") - - parser.add_argument("--splitfuse_block_size", type=int, default=256, help="splitfuse block size") - - parser.add_argument("--splitfuse_mode", action="store_true", help="use splitfuse mode") - parser.add_argument("--beam_mode", action="store_true", help="use beamsearch mode") - parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") - parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") - parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode") - parser.add_argument( - "--first_token_constraint_mode", - 
action="store_true", - help="""constraint the first token allowed range, - use env FIRST_ALLOWED_TOKENS to set the range, like FIRST_ALLOWED_TOKENS=1,2 ..""", - ) - parser.add_argument( - "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models." - ) - parser.add_argument( - "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources" - ) - parser.add_argument( - "--cache_reserved_ratio", type=float, default=0.5, help="cache server reserved capacity ratio after clear" - ) - parser.add_argument( - "--data_type", - type=str, - choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"], - default=None, - help="the data type of the model weight", - ) - parser.add_argument("--return_all_prompt_logprobs", action="store_true", help="return all prompt tokens logprobs") - - parser.add_argument("--use_reward_model", action="store_true", help="use reward model") - - parser.add_argument( - "--long_truncation_mode", - type=str, - choices=[None, "head", "center"], - default=None, - help="""use to select the handle way when input token len > max_req_input_len. - None : raise Exception - head : remove some head tokens to make input token len <= max_req_input_len - center : remove some tokens in center loc to make input token len <= max_req_input_len""", - ) - parser.add_argument("--use_tgi_api", action="store_true", help="use tgi input and ouput format") - parser.add_argument( - "--health_monitor", action="store_true", help="check the health of service and restart when error" - ) - parser.add_argument("--metric_gateway", type=str, default=None, help="address for collecting monitoring metrics") - parser.add_argument("--job_name", type=str, default="lightllm", help="job name for monitor") - parser.add_argument( - "--grouping_key", action="append", default=[], help="grouping_key for the monitor in the form key=value" - ) - parser.add_argument("--push_interval", type=int, default=10, help="interval of pushing monitoring metrics") - parser.add_argument( - "--visual_infer_batch_size", type=int, default=4, help="number of images to process in each inference batch" - ) - parser.add_argument( - "--visual_gpu_ids", nargs="+", type=int, default=[0], help="List of GPU IDs to use, e.g., 0 1 2" - ) - parser.add_argument("--visual_tp", type=int, default=1, help="number of tensort parallel instances for ViT") - parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT") - parser.add_argument( - "--visual_nccl_ports", - nargs="+", - type=int, - default=[29500], - help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502", - ) - parser.add_argument( - "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway" - ) - parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage") - parser.add_argument( - "--graph_max_batch_size", - type=int, - default=16, - help="""Maximum batch size that can be captured by the cuda graph for decodign stage. - The default value is 8. It will turn into eagar mode if encounters a larger value.""", - ) - parser.add_argument( - "--graph_max_len_in_batch", - type=int, - default=8192, - help="""Maximum sequence length that can be captured by the cuda graph for decodign stage. - The default value is 8192. It will turn into eagar mode if encounters a larger value. 
""", - ) - return parser - - def main(): parser = make_argument_parser() global args From d9e3ba2ec41e4830f85026301e8b0dd1f8a16f75 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Fri, 1 Nov 2024 17:55:35 +0800 Subject: [PATCH 5/6] update lightllm version to 3.0.0 (#603) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4c83af043..2a25c6cf4 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ package_data = {"lightllm": ["models/deepseek2/layer_infer/vllm_configs/*.json"]} setup( name="lightllm", - version="2.0.0", + version="3.0.0", packages=find_packages(exclude=("build", "include", "test", "dist", "docs", "benchmarks", "lightllm.egg-info")), author="model toolchain", author_email="", From 06afb4a635707f009fea9e8f2fa4df9a209edc6b Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Mon, 18 Nov 2024 17:26:13 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=E3=80=90Feature=E3=80=91PD=20Mode=20Suppor?= =?UTF-8?q?t=20(#607)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: wangzaijun --- lightllm/common/basemodel/basemodel.py | 86 ++--- lightllm/common/basemodel/cuda_graph.py | 6 + lightllm/common/basemodel/infer_lock.py | 128 +++++++ lightllm/common/mem_manager.py | 32 +- lightllm/common/req_manager.py | 18 +- lightllm/server/api_cli.py | 36 +- lightllm/server/api_lightllm.py | 56 +-- lightllm/server/api_server.py | 359 ++++++------------ lightllm/server/api_start.py | 260 +++++++++++++ lightllm/server/api_tgi.py | 27 +- .../server/detokenization/decode_mode_fix.py | 30 ++ lightllm/server/detokenization/manager.py | 8 + lightllm/server/httpserver/manager.py | 314 +++++++++++---- .../httpserver_for_pd_master/__init__.py | 0 .../httpserver_for_pd_master/manager.py | 312 +++++++++++++++ lightllm/server/io_struct.py | 4 + lightllm/server/metrics/metrics.py | 2 +- lightllm/server/multimodal_params.py | 19 + lightllm/server/pd_io_struct.py | 92 +++++ .../router/dynamic_prompt/radix_cache.py | 17 +- lightllm/server/router/manager.py | 57 ++- .../server/router/model_infer/infer_batch.py | 8 +- .../model_infer/mode_backend/__init__.py | 2 + .../model_infer/mode_backend/base_backend.py | 33 +- .../mode_backend/beamsearch/pre_process.py | 28 +- .../decode_node_impl/__init__.py | 2 + .../decode_node_impl/decode_impl.py | 124 ++++++ .../decode_node_impl/decode_infer_rpyc.py | 154 ++++++++ .../decode_kv_move_manager.py | 258 +++++++++++++ .../decode_node_impl/decode_task_cache.py | 10 + .../decode_node_impl/decode_trans_process.py | 98 +++++ .../decode_node_impl/up_status.py | 67 ++++ .../continues_batch/pre_process.py | 27 +- .../prefill_node_impl/__init__.py | 2 + .../prefill_node_impl/prefill_impl.py | 130 +++++++ .../prefill_node_impl/prefill_infer_rpyc.py | 45 +++ .../prefill_kv_move_manager.py | 217 +++++++++++ .../prefill_node_impl/prefill_task_cache.py | 8 + .../prefill_trans_process.py | 101 +++++ .../mode_backend/splitfuse/pre_process.py | 10 +- .../server/router/model_infer/model_rpc.py | 39 +- lightllm/server/router/req_queue/__init__.py | 3 + .../server/router/req_queue/base_queue.py | 22 +- .../req_queue/continues_batch/beam_impl.py | 18 +- .../router/req_queue/continues_batch/impl.py | 46 ++- .../continues_batch/pd_decode_impl.py | 65 ++++ .../server/router/req_queue/splitfuse/impl.py | 20 +- lightllm/server/router/token_load.py | 81 +++- lightllm/server/sampling_params.py | 18 + 
lightllm/utils/health_check.py | 16 +- lightllm/utils/net_utils.py | 38 +- lightllm/utils/retry_utils.py | 31 ++ lightllm/utils/statics_utils.py | 13 + requirements.txt | 1 + 54 files changed, 3069 insertions(+), 529 deletions(-) create mode 100644 lightllm/common/basemodel/infer_lock.py create mode 100644 lightllm/server/api_start.py create mode 100644 lightllm/server/detokenization/decode_mode_fix.py create mode 100644 lightllm/server/httpserver_for_pd_master/__init__.py create mode 100644 lightllm/server/httpserver_for_pd_master/manager.py create mode 100644 lightllm/server/pd_io_struct.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/__init__.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_impl.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_infer_rpyc.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_kv_move_manager.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_task_cache.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_trans_process.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/up_status.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/__init__.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_impl.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_infer_rpyc.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_kv_move_manager.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_task_cache.py create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_trans_process.py create mode 100644 lightllm/server/router/req_queue/continues_batch/pd_decode_impl.py create mode 100644 lightllm/utils/retry_utils.py create mode 100644 lightllm/utils/statics_utils.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 5c604699d..4b3f38934 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -17,6 +17,7 @@ from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel.infer_lock import g_infer_state_lock logger = init_logger(__name__) @@ -38,6 +39,7 @@ class TpPartBaseModel: splitfuse_infer_state_class = SplitFuseInferStateInfo def __init__(self, kvargs): + self.run_mode = kvargs["run_mode"] self.tp_rank_ = kvargs["tp_rank"] self.world_size_ = kvargs["world_size"] self.weight_dir_ = kvargs["weight_dir"] @@ -67,6 +69,7 @@ def __init__(self, kvargs): self._verify_params() self._init_weights() self._init_mem_manager() + self._init_kv_move_buffer() self._check_mem_size() self._init_req_manager() self._init_infer_layer() @@ -131,6 +134,11 @@ def _init_mem_manager(self): ) return + def _init_kv_move_buffer(self): + # p d 分离的推理模式下才需要做这一步初始化 + if self.run_mode in ["prefill", "decode"]: + 
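+            # Under PD (prefill/decode) separation the prefill node computes the KV cache
+            # and ships it to the decode node, so both roles pre-allocate a contiguous
+            # staging buffer for that transfer (see MemoryManager.alloc_kv_move_buffer);
+            # normal single-node serving never allocates it.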
self.mem_manager.alloc_kv_move_buffer(self.max_seq_length) + def _check_mem_size(self): self.max_total_token_num = self.mem_manager.size assert self.max_seq_length < self.max_total_token_num @@ -192,6 +200,7 @@ def forward( total_token_num, max_len_in_batch, input_ids: torch.Tensor, + mem_indexes: torch.Tensor, b_req_idx: torch.Tensor, b_start_loc: torch.Tensor, b_seq_len: torch.Tensor, @@ -205,6 +214,7 @@ def forward( total_token_num, max_len_in_batch, input_ids, + mem_indexes, b_req_idx, b_start_loc, b_seq_len, @@ -217,6 +227,7 @@ def forward( total_token_num, max_len_in_batch, input_ids, + mem_indexes, b_req_idx, b_start_loc, b_seq_len, @@ -229,6 +240,7 @@ def _prefill( total_token_num, max_len_in_batch, input_ids, + mem_indexes, b_req_idx, b_start_loc, b_seq_len, @@ -256,22 +268,13 @@ def _prefill( infer_state.mem_manager = self.mem_manager infer_state.req_manager = self.req_manager - alloc_mem = self.mem_manager.alloc_contiguous(input_ids.shape[0]) - if alloc_mem is not None: - infer_state.mem_is_contiguous = True - infer_state.mem_index = alloc_mem[0] - infer_state.mem_start = alloc_mem[1] - infer_state.mem_end = alloc_mem[2] - - else: - infer_state.mem_is_contiguous = False - alloc_mem = self.mem_manager.alloc(input_ids.shape[0]) - infer_state.mem_index = alloc_mem - infer_state.kv_buffer = torch.empty( - (input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - dtype=self.data_type, - device="cuda", - ) + infer_state.mem_is_contiguous = False + infer_state.mem_index = mem_indexes + infer_state.kv_buffer = torch.empty( + (input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), + dtype=self.data_type, + device="cuda", + ) init_req_to_token_indexes( self.req_manager.req_to_token_indexs, @@ -292,6 +295,7 @@ def _decode( total_token_num, max_len_in_batch, input_ids, + mem_indexes, b_req_idx, b_start_loc, b_seq_len, @@ -314,23 +318,14 @@ def _decode( # 在使用 cuda graph 特性的时候,必须保证每次推理的流程一致 # 所以不再使用分配连续的mem带来的优化,保证推理流程的一致 - alloc_mem = None if self.graph is not None else self.mem_manager.alloc_contiguous(batch_size) - if alloc_mem is not None: - infer_state.mem_is_contiguous = True - infer_state.mem_index = alloc_mem[0] - infer_state.mem_start = alloc_mem[1] - infer_state.mem_end = alloc_mem[2] - copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index) - else: - infer_state.mem_is_contiguous = False - alloc_mem = self.mem_manager.alloc(batch_size) - infer_state.mem_index = alloc_mem - infer_state.kv_buffer = torch.empty( - (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - dtype=self.data_type, - device="cuda", - ) - copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index) + infer_state.mem_is_contiguous = False + infer_state.mem_index = mem_indexes + infer_state.kv_buffer = torch.empty( + (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), + dtype=self.data_type, + device="cuda", + ) + copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index) infer_state.init_some_extra_state(self, input_ids) if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch): @@ -347,6 +342,7 @@ def _decode( def splitfuse_forward( self, input_ids, + mem_indexes, decode_req_num, decode_total_token_num, decode_b_req_idx: torch.Tensor, @@ -384,21 +380,13 @@ def splitfuse_forward( infer_state.req_manager = self.req_manager alloc_size = len(input_ids) - alloc_mem = 
self.mem_manager.alloc_contiguous(alloc_size) - if alloc_mem is not None: - infer_state.mem_is_contiguous = True - infer_state.mem_index = alloc_mem[0] - infer_state.mem_start = alloc_mem[1] - infer_state.mem_end = alloc_mem[2] - else: - infer_state.mem_is_contiguous = False - alloc_mem = self.mem_manager.alloc(alloc_size) - infer_state.mem_index = alloc_mem - infer_state.kv_buffer = torch.empty( - (alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - dtype=self.data_type, - device="cuda", - ) + infer_state.mem_is_contiguous = False + infer_state.mem_index = mem_indexes + infer_state.kv_buffer = torch.empty( + (alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), + dtype=self.data_type, + device="cuda", + ) # decode 部分 if decode_req_num != 0: @@ -474,6 +462,7 @@ def _check_max_len_infer(self): logger.info("begin check max_len infer") dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda") b_req_idx = self.req_manager.alloc(1).int() + mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)) b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") b_seq_len[:] = self.batch_max_tokens b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") @@ -484,6 +473,7 @@ def _check_max_len_infer(self): total_token_num, self.batch_max_tokens, dummy_input_ids, + mem_indexes, b_req_idx, b_start_loc, b_seq_len, diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 8220e3c83..5a413424d 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -52,6 +52,7 @@ def warmup(self, model): prefill_input_len = 1 dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") b_req_idx = model.req_manager.alloc(batch_size).int() + mem_indexes = model.mem_manager.alloc(len(dummy_input_ids)) b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") b_start_loc = torch.arange(0, batch_size, dtype=torch.int32, device="cuda") @@ -61,6 +62,7 @@ def warmup(self, model): total_token_num, prefill_input_len, dummy_input_ids, + mem_indexes, b_req_idx, b_start_loc, b_seq_len, @@ -68,6 +70,7 @@ def warmup(self, model): is_prefill=True, multimodal_params=[], ) + mem_indexes = None prob_out = torch.softmax(logics, dim=-1) logics = None predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) @@ -79,16 +82,19 @@ def warmup(self, model): b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") total_token_num += batch_size b_seq_len += 1 + mem_indexes = model.mem_manager.alloc(len(predict_ids)) logics = model.forward( batch_size, total_token_num, prefill_input_len + 1, torch.from_numpy(predict_ids).cuda().reshape(-1), + mem_indexes, b_req_idx, b_start_loc, b_seq_len, is_prefill=False, ) + mem_indexes = None model.mem_manager.free_all() model.req_manager.free_all() # release local tensors diff --git a/lightllm/common/basemodel/infer_lock.py b/lightllm/common/basemodel/infer_lock.py new file mode 100644 index 000000000..4ea994e81 --- /dev/null +++ b/lightllm/common/basemodel/infer_lock.py @@ -0,0 +1,128 @@ +# 这不是一个很好的设计但是不是很好找到更好更简单对架构入侵更小的实现方法。 +# 这个地方声明的锁和计数,主要是用来解决在 PD 分离模式下,kv_move_manager 进程中会出现 +# 通过rpyc调用操作 radix cache 和 mem_manager 中的数据的问题,这可能导致严重的数据同步 +# 问题,主要原因是各个tp的推理进程运行到的位置节点并没有严格的保证,导致radix cache 和 +# mem manager 中的数据出现各个进程间不一致的问题。 +# 下面的实现中,通过一个锁和计数对象, 配合使用的方式,来解决这个问题。 +from dataclasses import dataclass +import 
numpy as np +import threading +from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray +import torch.distributed as dist +import time +import torch.multiprocessing as mp +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class InferStateLock: + def __init__(self, name): + self.infer_lock = threading.Lock() + # 默认开 128 tp 的空间, 现在应该没什么卡能开这么大的tp 吧 + self.lock_tp_infos = SharedArray(f"{name}_lock_tp_infos", shape=(129,), dtype=np.int64) + self.lock_tp_infos.arr[:] = 0 + self.rank_id = dist.get_rank() + self.world_size = dist.get_world_size() + + def add_cur_mark(self): + self.lock_tp_infos.arr[self.rank_id] += 1 + + def get_cur_mark(self): + return self.lock_tp_infos.arr[self.rank_id] + + def get_max_mark_in_group(self): + return np.max(self.lock_tp_infos.arr[0 : self.world_size]) + + def judge_cur_mark_equal_max_mark_in_group(self): + return self.get_cur_mark() == self.get_max_mark_in_group() + + def judge_mark_in_group_all_same(self): + marks = self.lock_tp_infos.arr[0 : self.world_size] + return bool(np.all(marks == marks[0])) + + def acquire_lock_and_update_cur_mark(self): + self.infer_lock.acquire() + self.add_cur_mark() + + def release_lock(self): + self.infer_lock.release() + + def set_group_wait_mark(self): + if self.rank_id == 0: + self.lock_tp_infos.arr[-1] = 1 + + def unset_group_wait_mark(self): + if self.rank_id == 0: + self.lock_tp_infos.arr[-1] = 0 + + def get_group_wait_mark(self): + return self.lock_tp_infos.arr[-1] + + +@dataclass +class G_Infer_Lock: + obj: InferStateLock = None + + def acquire(self): + if self.obj is not None: + # 当遇到有同步请求的时候,同时自己的mark已经是最大的mark的时候,就在这里休眠, + # 不去竞争锁, 因为 wait_mark == 1 的时候, 说明wait_get_locks被调用,有人 + # 在申请同步点操作 + while self.obj.get_group_wait_mark() == 1 and self.obj.judge_cur_mark_equal_max_mark_in_group(): + time.sleep(0) + + self.obj.acquire_lock_and_update_cur_mark() + + def release(self): + if self.obj is not None: + self.obj.release_lock() + + +# 后续由 backend 对象来对obj进行初始化赋值,方便进行全局调用 +g_infer_state_lock = G_Infer_Lock() + + +# 下面两个函数需要配对使用 +def acquire_lock_until_ready(nccl_group): + g_infer_state_lock.obj.set_group_wait_mark() + while True: + g_infer_state_lock.obj.infer_lock.acquire() + dist.barrier(nccl_group) + judge_ans = g_infer_state_lock.obj.judge_mark_in_group_all_same() + dist.barrier(nccl_group) + + if judge_ans is not True: + # 释放锁进行重试 + g_infer_state_lock.obj.infer_lock.release() + time.sleep(0.001) + logger.info("wait get locks sleep 1ms") + else: + break + + g_infer_state_lock.obj.unset_group_wait_mark() + return + + +def release_acquired_lock(): + g_infer_state_lock.obj.infer_lock.release() + + +@dataclass +class G_Router_Lock: + """ + 保护pd分离模式下, 一些数据的操作。 + """ + + obj = None # 进程锁对象 + + def acquire(self): + if self.obj is not None: + self.obj.acquire() + + def release(self): + if self.obj is not None: + self.obj.release() + + +g_router_lock = G_Router_Lock() diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index b77ff1945..46b344455 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -1,6 +1,7 @@ import re import os import torch +from typing import List from lightllm.utils.log_utils import init_logger from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory @@ -63,9 +64,34 @@ def profile_size(self, mem_fraction): return def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer 
= [ - torch.empty((size, 2 * head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num) - ] + self.kv_buffer = torch.empty((layer_num, size, 2 * head_num, head_dim), dtype=dtype, device="cuda") + + def alloc_kv_move_buffer(self, max_req_total_len): + """ + pd 分离模式使用的特殊接口 + """ + if isinstance(self, MemoryManager) and type(self) != MemoryManager: + raise NotImplementedError("subclass need reimpl this method") + self.kv_move_buffer = torch.empty( + (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda" + ) + return + + def read_from_layer_buffer(self, token_indexes: List[int], layer_index: int): + move_size = self.kv_buffer.numel() // self.layer_num // self.size * len(token_indexes) + move_buffer = self.kv_move_buffer.view(-1)[0:move_size].view( + 1, len(token_indexes), 2 * self.head_num, self.head_dim + ) + move_buffer[:, :, :, :] = self.kv_buffer[layer_index, token_indexes, :, :] + return move_buffer + + def get_layer_buffer_by_token_num(self, token_num): + move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num + return self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim) + + def write_to_layer_buffer(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index): + self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor + return def _free_buffers(self): self.kv_buffer = None diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index a9b3a3088..a6a927c17 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,10 +1,12 @@ import torch from lightllm.utils.log_utils import init_logger +from .mem_manager import MemoryManager logger = init_logger(__name__) - + + class ReqManager: - def __init__(self, max_request_num, max_sequence_length, mem_manager): + def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryManager): self.req_state = torch.zeros((max_request_num,), dtype=torch.bool, device="cuda") self.req_to_token_indexs = torch.zeros((max_request_num, max_sequence_length), dtype=torch.int32, device="cuda") self.can_use_req_size = max_request_num @@ -12,25 +14,25 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager): def alloc(self, need_size): if need_size > self.can_use_req_size: - logger.error(f'Insufficient requested capacity, remaining {self.can_use_req_size}') + logger.error(f"Insufficient requested capacity, remaining {self.can_use_req_size}") return None - select_index = torch.nonzero(self.req_state==0).reshape(-1)[:need_size] + select_index = torch.nonzero(self.req_state == 0).reshape(-1)[:need_size] self.req_state[select_index] = 1 self.can_use_req_size -= len(select_index) return select_index - + def free(self, free_req_index, free_token_index): self.can_use_req_size += len(free_req_index) self.req_state[free_req_index] = 0 if self.can_use_req_size == len(self.req_state): logger.debug(f"freed all request size {self.can_use_req_size}") self.mem_manager.free(free_token_index) - + def free_req(self, free_req_index): - self.can_use_req_size +=1 + self.can_use_req_size += 1 self.req_state[free_req_index] = 0 return - + def free_token(self, free_token_index): self.mem_manager.free(free_token_index) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index bf0daa12a..f9901bb7f 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -3,9 +3,36 @@ def make_argument_parser() -> argparse.ArgumentParser: 
parser = argparse.ArgumentParser() + + parser.add_argument( + "--run_mode", + type=str, + choices=["normal", "prefill", "decode", "pd_master"], + default="normal", + help="set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode", + ) parser.add_argument("--host", type=str, default="127.0.0.1") parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--pd_master_ip", + type=str, + default="127.0.0.1", + help="when run_mode set to prefill or decode, you need set this pd_mater_ip", + ) + parser.add_argument( + "--pd_master_port", + type=int, + default=1212, + help="when run_mode set to prefill or decode, you need set this pd_mater_port", + ) + parser.add_argument( + "--pd_decode_rpyc_port", + type=int, + default=42000, + help="p d mode, decode node used for kv move manager rpyc server port", + ) + parser.add_argument( "--model_name", type=str, @@ -60,7 +87,6 @@ def make_argument_parser() -> argparse.ArgumentParser: "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time" ) parser.add_argument("--tp", type=int, default=1, help="model tp parral size, the default is 1") - parser.add_argument("--max_req_input_len", type=int, default=2048, help="the max value for req input tokens num") parser.add_argument( "--max_req_total_len", type=int, default=2048 + 1024, help="the max value for req_input_len + req_output_len" ) @@ -145,10 +171,10 @@ def make_argument_parser() -> argparse.ArgumentParser: type=str, choices=[None, "head", "center"], default=None, - help="""use to select the handle way when input token len > max_req_input_len. - None : raise Exception - head : remove some head tokens to make input token len <= max_req_input_len - center : remove some tokens in center loc to make input token len <= max_req_input_len""", + help="""use to select the handle way when input_token_len + max_new_tokens > max_req_total_len. 
+ None : raise Exception + head : remove some head tokens to make input_token_len + max_new_tokens <= max_req_total_len + center : remove some tokens in center loc to make input_token_len + max_new_tokens <= max_req_total_len""", ) parser.add_argument("--use_tgi_api", action="store_true", help="use tgi input and ouput format") parser.add_argument( diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index 4ed217115..56c48ef59 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -4,10 +4,11 @@ from fastapi.responses import Response, StreamingResponse from .sampling_params import SamplingParams from .multimodal_params import MultimodalParams -import json +from .httpserver.manager import HttpServerManager +import ujson as json -async def lightllm_get_score(request: Request, g_id_gen, httpserver_manager) -> Response: +async def lightllm_get_score(request: Request, httpserver_manager: HttpServerManager) -> Response: request_dict = await request.json() prompt = request_dict.pop("chat") sample_params_dict = {"max_new_tokens": 1} @@ -15,11 +16,7 @@ async def lightllm_get_score(request: Request, g_id_gen, httpserver_manager) -> sampling_params.verify() multimodal_params_dict = request_dict.get("multimodal_params", {}) multimodal_params = MultimodalParams(**multimodal_params_dict) - multimodal_params.verify_and_preload() - group_request_id = g_id_gen.generate_id() - results_generator = httpserver_manager.generate( - prompt, sampling_params, group_request_id, multimodal_params, request=request - ) + results_generator = httpserver_manager.generate(prompt, sampling_params, multimodal_params, request=request) ret = {} # n === 1 @@ -31,7 +28,7 @@ async def lightllm_get_score(request: Request, g_id_gen, httpserver_manager) -> return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) -async def lightllm_generate(request: Request, g_id_gen, httpserver_manager) -> Response: +async def lightllm_generate(request: Request, httpserver_manager: HttpServerManager) -> Response: request_dict = await request.json() prompt = request_dict.pop("inputs") @@ -41,12 +38,8 @@ async def lightllm_generate(request: Request, g_id_gen, httpserver_manager) -> R sampling_params.verify() multimodal_params_dict = request_dict.get("multimodal_params", {}) multimodal_params = MultimodalParams(**multimodal_params_dict) - multimodal_params.verify_and_preload() - group_request_id = g_id_gen.generate_id() - results_generator = httpserver_manager.generate( - prompt, sampling_params, group_request_id, multimodal_params, request=request - ) + results_generator = httpserver_manager.generate(prompt, sampling_params, multimodal_params, request=request) # Non-streaming case final_output_dict = collections.defaultdict(list) @@ -103,7 +96,7 @@ async def lightllm_generate(request: Request, g_id_gen, httpserver_manager) -> R return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) -async def lightllm_generate_stream(request: Request, g_id_gen, httpserver_manager) -> Response: +async def lightllm_generate_stream(request: Request, httpserver_manager: HttpServerManager) -> Response: request_dict = await request.json() prompt = request_dict.pop("inputs") @@ -116,12 +109,7 @@ async def lightllm_generate_stream(request: Request, g_id_gen, httpserver_manage multimodal_params_dict = request_dict.get("multimodal_params", {}) multimodal_params = MultimodalParams(**multimodal_params_dict) - multimodal_params.verify_and_preload() - - group_request_id = 
g_id_gen.generate_id() - results_generator = httpserver_manager.generate( - prompt, sampling_params, group_request_id, multimodal_params, request=request - ) + results_generator = httpserver_manager.generate(prompt, sampling_params, multimodal_params, request=request) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: @@ -143,10 +131,30 @@ async def stream_results() -> AsyncGenerator[bytes, None]: yield ("data:" + json.dumps(ret, ensure_ascii=False) + "\n\n").encode("utf-8") - async def abort_request() -> None: - await httpserver_manager.abort(group_request_id) + background_tasks = BackgroundTasks() + return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + + +async def lightllm_pd_generate_stream(request: Request, httpserver_manager: HttpServerManager) -> Response: + + request_dict = await request.json() + prompt = request_dict.pop("inputs") + sample_params_dict = request_dict["parameters"] + _ = sample_params_dict.pop("return_details", False) + sampling_params = SamplingParams(**sample_params_dict) + sampling_params.verify() + if sampling_params.best_of != 1: + raise Exception("stream api only support best_of == 1") + + multimodal_params_dict = request_dict.get("multimodal_params", {}) + multimodal_params = MultimodalParams(**multimodal_params_dict) + results_generator = httpserver_manager.generate(prompt, sampling_params, multimodal_params, request=request) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for sub_req_id, request_output, metadata, finish_status in results_generator: + ret = [sub_req_id, request_output, metadata, finish_status.value] + yield ("data:" + json.dumps(ret, ensure_ascii=False) + "\n\n").encode("utf-8") background_tasks = BackgroundTasks() - # Abort the request if the client disconnects. 
- background_tasks.add_task(abort_request) return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 3018d0799..721d9e44f 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -27,29 +27,22 @@ from .build_prompt import build_prompt asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -import json +import ujson as json from http import HTTPStatus import uuid import multiprocessing as mp -from typing import AsyncGenerator - -from fastapi import BackgroundTasks, FastAPI, Request +from typing import AsyncGenerator, Union +from typing import Callable +from lightllm.server import TokenLoad +from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect from fastapi.responses import Response, StreamingResponse, JSONResponse import uvicorn from .api_cli import make_argument_parser from .sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager -from .detokenization.manager import start_detokenization_process -from .router.manager import start_router_process -from .embed_cache.manager import start_cache_manager -from .metrics.manager import start_metric_manager -from .visualserver.manager import start_visual_process -from .req_id_generator import ReqIDGenerator -from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl -from .api_lightllm import lightllm_generate, lightllm_generate_stream, lightllm_get_score -from lightllm.utils.net_utils import alloc_can_use_network_port -from lightllm.utils.start_utils import start_submodule_processes +from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster +from .api_lightllm import lightllm_get_score, lightllm_pd_generate_stream from .api_models import ( ChatCompletionRequest, @@ -63,34 +56,33 @@ ) from lightllm.utils.log_utils import init_logger -from prometheus_client import generate_latest from lightllm.server.metrics.manager import MetricClient +from dataclasses import dataclass logger = init_logger(__name__) -TIMEOUT_KEEP_ALIVE = 5 # seconds. 
-g_id_gen = ReqIDGenerator() -app = FastAPI() -server = uvicorn.Server(uvicorn.Config(app)) +@dataclass +class G_Objs: + app: FastAPI = None + server: uvicorn.Server = None + metric_client: MetricClient = None + args: object = None + g_generate_func: Callable = None + g_generate_stream_func: Callable = None + httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None + shared_token_load: TokenLoad = None -isFirst = True -metric_client = None -global args -args = None +g_objs = G_Objs() -def first_set_handle_loop(): - global isFirst - if isFirst: - loop = asyncio.get_event_loop() - loop.create_task(httpserver_manager.handle_loop()) - isFirst = False - return +app = FastAPI() +g_objs.app = app +g_objs.server = uvicorn.Server(uvicorn.Config(app)) def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse: - metric_client.counter_inc("lightllm_request_failure") + g_objs.metric_client.counter_inc("lightllm_request_failure") return JSONResponse({"message": message}, status_code=status_code.value) @@ -109,21 +101,19 @@ def readiness(): @app.get("/get_model_name") @app.post("/get_model_name") def get_model_name(): - global args - return {"model_name": args.model_name} + return {"model_name": g_objs.args.model_name} @app.get("/healthz", summary="Check server health") @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") async def healthcheck(request: Request): - first_set_handle_loop() if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true": return JSONResponse({"message": "Error"}, status_code=404) from lightllm.utils.health_check import health_check - if await health_check(httpserver_manager, g_id_gen, request): + if await health_check(g_objs.args, g_objs.httpserver_manager, request): return JSONResponse({"message": "Ok"}, status_code=200) else: return JSONResponse({"message": "Error"}, status_code=404) @@ -134,11 +124,11 @@ async def token_load(request: Request): return JSONResponse( { # 当前使用token量,估计的负载 - "current_load": float(shared_token_load.get_current_load()), + "current_load": float(g_objs.shared_token_load.get_current_load()), # 朴素估计的负载,简单将当前请求的输入和输出长度想加得到,目前已未使用,其值与dynamic_max_load一样。 - "logical_max_load": float(shared_token_load.get_logical_max_load()), + "logical_max_load": float(g_objs.shared_token_load.get_logical_max_load()), # 动态估计的最大负载,考虑请求中途退出的情况的负载 - "dynamic_max_load": float(shared_token_load.get_dynamic_max_load()), + "dynamic_max_load": float(g_objs.shared_token_load.get_dynamic_max_load()), }, status_code=200, ) @@ -146,9 +136,8 @@ async def token_load(request: Request): @app.post("/generate") async def generate(request: Request) -> Response: - first_set_handle_loop() try: - return await g_generate_func(request, g_id_gen, httpserver_manager) + return await g_objs.g_generate_func(request, g_objs.httpserver_manager) except Exception as e: logger.error("An error occurred: %s", str(e), exc_info=True) return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) @@ -156,9 +145,17 @@ async def generate(request: Request) -> Response: @app.post("/generate_stream") async def generate_stream(request: Request) -> Response: - first_set_handle_loop() try: - return await g_generate_stream_func(request, g_id_gen, httpserver_manager) + return await g_objs.g_generate_stream_func(request, g_objs.httpserver_manager) + except Exception as e: + logger.error("An error occurred: %s", str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) + + 
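As a quick illustration of how the /token_load endpoint above can be consumed, here is a
minimal client-side sketch, assuming the server's default host/port and the third-party
requests package (everything besides the endpoint path and the returned field names is an
assumption for illustration):

    import time
    import requests

    def watch_token_load(base_url: str = "http://127.0.0.1:8000", interval: float = 2.0) -> None:
        # Poll the router load estimates exposed by /token_load and print them.
        while True:
            payload = requests.get(f"{base_url}/token_load", timeout=5).json()
            # current_load: estimated number of tokens currently in use by the router
            # dynamic_max_load: worst-case estimate that also accounts for requests aborting mid-flight
            print(
                f"current={payload['current_load']:.3f} "
                f"dynamic_max={payload['dynamic_max_load']:.3f}"
            )
            time.sleep(interval)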
+@app.post("/pd_generate_stream") +async def pd_generate_stream(request: Request) -> Response: + try: + return await lightllm_pd_generate_stream(request, g_objs.httpserver_manager) except Exception as e: logger.error("An error occurred: %s", str(e), exc_info=True) return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) @@ -166,9 +163,8 @@ async def generate_stream(request: Request) -> Response: @app.post("/get_score") async def get_score(request: Request) -> Response: - first_set_handle_loop() try: - return await lightllm_get_score(request, g_id_gen, httpserver_manager) + return await lightllm_get_score(request, g_objs.httpserver_manager) except Exception as e: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) @@ -185,7 +181,6 @@ async def compat_generate(request: Request) -> Response: @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response: - first_set_handle_loop() if request.logit_bias is not None: return create_error_response( @@ -213,11 +208,9 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request) ) sampling_params.verify() multimodal_params = MultimodalParams(images=[]) - multimodal_params.verify_and_preload() - group_request_id = g_id_gen.generate_id() - results_generator = httpserver_manager.generate( - prompt, sampling_params, group_request_id, multimodal_params, request=raw_request + results_generator = g_objs.httpserver_manager.generate( + prompt, sampling_params, multimodal_params, request=raw_request ) # Non-streaming case @@ -228,6 +221,9 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request) prompt_tokens_dict = {} completion_tokens = 0 async for sub_req_id, request_output, metadata, finish_status in results_generator: + from .req_id_generator import convert_sub_id_to_group_id + + group_request_id = convert_sub_id_to_group_id(sub_req_id) count_output_tokens_dict[sub_req_id] += 1 final_output_dict[sub_req_id].append(request_output) if finish_status.is_finished(): @@ -260,7 +256,11 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: finish_reason = None + from .req_id_generator import convert_sub_id_to_group_id + async for sub_req_id, request_output, metadata, finish_status in results_generator: + group_request_id = convert_sub_id_to_group_id(sub_req_id) + delta_message = DeltaMessage(role="assistant", content=request_output) if finish_status.is_finished(): finish_reason = finish_status.get_finish_reason() @@ -275,12 +275,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: ) yield ("data: " + stream_resp.json(ensure_ascii=False) + "\n\n").encode("utf-8") - async def abort_request() -> None: - await httpserver_manager.abort(group_request_id) - background_tasks = BackgroundTasks() - # Abort the request if the client disconnects. 
- background_tasks.add_task(abort_request) return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) @@ -290,19 +285,71 @@ async def tokens(request: Request): try: request_dict = await request.json() prompt = request_dict.pop("text") - return JSONResponse({"ntokens": httpserver_manager.tokens(prompt)}, status_code=200) + return JSONResponse({"ntokens": g_objs.httpserver_manager.tokens(prompt)}, status_code=200) except Exception as e: return create_error_response(HTTPStatus.EXPECTATION_FAILED, f"error: {str(e)}") @app.get("/metrics") async def metrics() -> Response: - data = await metric_client.generate_latest() + data = await g_objs.metric_client.generate_latest() response = Response(data) response.mimetype = "text/plain" return response +@app.websocket("/register_and_keep_alive") +async def register_and_keep_alive(websocket: WebSocket): + await websocket.accept() + client_ip, client_port = websocket.client + logger.info(f"Client connected from IP: {client_ip}, Port: {client_port}") + regist_json = json.loads(await websocket.receive_text()) + logger.info(f"recieved regist_json {regist_json}") + await g_objs.httpserver_manager.register_pd(regist_json) + + try: + while True: + try: + # 等待接收消息,设置超时为10秒 + data = await asyncio.wait_for(websocket.receive_text(), timeout=10) + json_data = json.loads(data) + if json_data.get("type") != "heartbeat": + logger.warning(f"recive error messesage {json_data}") + break + + except asyncio.TimeoutError: + logger.error(f"client {regist_json} heartbeat timeout") + break + + except (WebSocketDisconnect, Exception, RuntimeError) as e: + logger.error(f"client {regist_json} has error {str(e)}") + logger.exception(str(e)) + finally: + logger.error(f"client {regist_json} removed") + await g_objs.httpserver_manager.remove_pd(regist_json) + return + + +@app.websocket("/kv_move_status") +async def kv_move_status(websocket: WebSocket): + await websocket.accept() + client_ip, client_port = websocket.client + logger.info(f"kv_move_status Client connected from IP: {client_ip}, Port: {client_port}") + try: + while True: + # 等待接收消息,设置超时为10秒 + data = await websocket.receive_text() + json_data = json.loads(data) + from .pd_io_struct import UpKVStatus + + upkv_status = UpKVStatus(**json_data) + await g_objs.httpserver_manager.update_req_status(upkv_status) + except (WebSocketDisconnect, Exception, RuntimeError) as e: + logger.error(f"kv_move_status client {(client_ip, client_port)} has error {str(e)}") + logger.exception(str(e)) + return + + @app.on_event("shutdown") async def shutdown(): logger.info("Received signal to shutdown. 
Performing graceful shutdown...") @@ -318,193 +365,27 @@ async def shutdown(): for child in children: os.kill(child.pid, signal.SIGKILL) - server.should_exit = True + g_objs.server.should_exit = True return -def main(): - parser = make_argument_parser() - global args - args = parser.parse_args() - - global g_generate_func - global g_generate_stream_func - if args.use_tgi_api: - g_generate_func = tgi_generate_impl - g_generate_stream_func = tgi_generate_stream_impl - else: - g_generate_func = lightllm_generate - g_generate_stream_func = lightllm_generate_stream - - logger.info(f"use tgi api: {args.use_tgi_api}") - - assert args.max_req_input_len < args.max_req_total_len - assert not (args.beam_mode and args.use_dynamic_prompt_cache), "Beam mode incompatible with dynamic prompt cache" - assert ( - args.mem_fraction > 0 and args.mem_fraction < 1 - ), f"Invalid mem_fraction {args.mem_fraction}, The expected value is between 0 and 1." - - # splitfuse_mode 和 cuda_graph 不能同时开启 - if args.splitfuse_mode: - assert args.disable_cudagraph - - # 这些模式不能同时设置。 - assert [ - args.splitfuse_mode, - args.beam_mode, - args.diverse_mode, - args.token_healing_mode, - args.use_reward_model, - args.return_all_prompt_logprobs, - args.first_token_constraint_mode, - ].count(True) <= 1 - # 部分模式目前还无法与dynamic_prompt_cache一起跑,to do。 - if args.use_dynamic_prompt_cache: - assert args.beam_mode is False - assert args.token_healing_mode is False - - # 部分模式还不能支持与高级动态调度算法协同,to do. - if args.beam_mode or args.diverse_mode: - assert args.router_token_ratio == 0.0 - - # 检查GPU数量是否足够 - total_required_gpus = args.visual_dp * args.visual_tp - if len(args.visual_gpu_ids) < total_required_gpus: - raise ValueError( - f"Not enough GPUs specified. You need at least {total_required_gpus}, but got {len(args.visual_gpu_ids)}." - ) - else: - args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus] - - # 检查visual_nccl_port数量是否足够 - if len(args.visual_nccl_ports) < args.visual_dp: - raise ValueError( - f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " - f"but got ({len(args.visual_nccl_ports)})." 
- ) - else: - args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] - - if not args.splitfuse_mode: - # 普通模式下 - if args.batch_max_tokens is None: - args.batch_max_tokens = args.max_req_total_len - else: - assert args.batch_max_tokens >= args.max_req_total_len, "batch_max_tokens must >= max_req_total_len" - else: - # splitfuse 模式下 - # assert args.batch_max_tokens is not None, "need to set by yourself" - if args.batch_max_tokens is None: - args.batch_max_tokens = min(args.max_req_total_len, 16 * args.splitfuse_block_size) - - assert ( - args.batch_max_tokens > args.splitfuse_block_size - ), "splitfuse_mode, batch_max_tokens must >= splitfuse_block_size" - - # help to manage data stored on Ceph - if "s3://" in args.model_dir: - from lightllm.utils.petrel_helper import s3_model_prepare - - s3_model_prepare(args.model_dir) - - # 如果args.eos_id 是 None, 从 config.json 中读取 eos_token_id 相关的信息,赋值给 args - if args.eos_id is None: - from lightllm.utils.config_utils import get_eos_token_ids - - args.eos_id = get_eos_token_ids(args.model_dir) - - if args.data_type is None: - from lightllm.utils.config_utils import get_dtype - - args.data_type = get_dtype(args.model_dir) - assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] - - logger.info(f"all start args:{args}") - - already_uesd_ports = args.visual_nccl_ports + [args.nccl_port] - can_use_ports = alloc_can_use_network_port( - num=6 + args.tp + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports - ) - router_port, detokenization_port, httpserver_port, visual_port, cache_port, metric_port = can_use_ports[0:6] - model_rpc_ports = can_use_ports[6 : 6 + args.tp] - can_use_ports = can_use_ports[6 + args.tp :] - - visual_model_tp_ports = [] - for _ in range(args.visual_dp): - tp_ports_for_dp = can_use_ports[0 : args.visual_tp] - can_use_ports = can_use_ports[args.visual_tp :] - visual_model_tp_ports.append(tp_ports_for_dp) - - if args.enable_multimodal: - start_submodule_processes( - start_funcs=[ - start_cache_manager, - ], - start_args=[(cache_port, args)], - ) - start_submodule_processes( - start_funcs=[ - start_visual_process, - ], - start_args=[ - (args, router_port, visual_port, cache_port, visual_model_tp_ports), - ], - ) - - start_submodule_processes( - start_funcs=[ - start_metric_manager, - ], - start_args=[(metric_port, args)], - ) - global metric_client - metric_client = MetricClient(metric_port) - - global httpserver_manager - httpserver_manager = HttpServerManager( - args, - router_port=router_port, - cache_port=cache_port, - visual_port=visual_port, - httpserver_port=httpserver_port, - enable_multimodal=args.enable_multimodal, - metric_port=metric_port, - ) - - start_submodule_processes( - start_funcs=[start_router_process, start_detokenization_process], - start_args=[ - (args, router_port, detokenization_port, model_rpc_ports, metric_port), - (args, detokenization_port, httpserver_port), - ], - ) - if "s3://" in args.model_dir: - from lightllm.utils.petrel_helper import s3_model_clear - - s3_model_clear(args.model_dir) - - if args.health_monitor: - from lightllm.server.health_monitor.manager import start_health_check_process - - start_submodule_processes(start_funcs=[start_health_check_process], start_args=[(args,)]) - - # 共享变量,用于获取router端调度分析得到的机器负载信息 - from lightllm.server import TokenLoad - - global shared_token_load - shared_token_load = TokenLoad(f"{str(args.nccl_port)}_shared_token_load") - - server.install_signal_handlers() - uvicorn.run( - app, - host=args.host, - port=args.port, 
- log_level="debug", - timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - loop="uvloop", - ) +@app.on_event("startup") +async def startup_event(): + logger.info("server start up") + loop = asyncio.get_event_loop() + loop.create_task(g_objs.httpserver_manager.handle_loop()) + logger.info("server start up ok") + return if __name__ == "__main__": torch.multiprocessing.set_start_method("spawn"), # this code will not be ok for settings to fork to subprocess - main() + parser = make_argument_parser() + args = parser.parse_args() + g_objs.args = args + from .api_start import normal_or_p_d_start, pd_master_start + + if args.run_mode == "pd_master": + pd_master_start(g_objs) + else: + normal_or_p_d_start(g_objs) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py new file mode 100644 index 000000000..1cb9de4cc --- /dev/null +++ b/lightllm/server/api_start.py @@ -0,0 +1,260 @@ +import uvicorn +import uuid +from lightllm.server.metrics.manager import MetricClient +from lightllm.server import TokenLoad +from .api_lightllm import lightllm_generate, lightllm_generate_stream +from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl +from lightllm.utils.net_utils import alloc_can_use_network_port +from lightllm.utils.start_utils import start_submodule_processes +from .metrics.manager import start_metric_manager +from .embed_cache.manager import start_cache_manager +from .visualserver.manager import start_visual_process +from lightllm.utils.log_utils import init_logger +from .detokenization.manager import start_detokenization_process +from .router.manager import start_router_process +from .httpserver.manager import HttpServerManager +from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster + +logger = init_logger(__name__) + + +def normal_or_p_d_start(g_objs): + from .api_server import G_Objs + + g_objs: G_Objs = g_objs + args = g_objs.args + + if args.run_mode not in ["normal", "prefill", "decode"]: + return + + if args.use_tgi_api: + g_objs.g_generate_func = tgi_generate_impl + g_objs.g_generate_stream_func = tgi_generate_stream_impl + else: + g_objs.g_generate_func = lightllm_generate + g_objs.g_generate_stream_func = lightllm_generate_stream + + logger.info(f"use tgi api: {args.use_tgi_api}") + + assert not (args.beam_mode and args.use_dynamic_prompt_cache), "Beam mode incompatible with dynamic prompt cache" + assert ( + args.mem_fraction > 0 and args.mem_fraction < 1 + ), f"Invalid mem_fraction {args.mem_fraction}, The expected value is between 0 and 1." + + # splitfuse_mode 和 cuda_graph 不能同时开启 + if args.splitfuse_mode: + assert args.disable_cudagraph + + # 这些模式不能同时设置。 + assert [ + args.splitfuse_mode, + args.beam_mode, + args.diverse_mode, + args.token_healing_mode, + args.use_reward_model, + args.return_all_prompt_logprobs, + args.first_token_constraint_mode, + ].count(True) <= 1 + # 部分模式目前还无法与dynamic_prompt_cache一起跑,to do。 + if args.use_dynamic_prompt_cache: + assert args.beam_mode is False + assert args.token_healing_mode is False + + # 部分模式还不能支持与高级动态调度算法协同,to do. + if args.beam_mode or args.diverse_mode: + assert args.router_token_ratio == 0.0 + + # 检查GPU数量是否足够 + total_required_gpus = args.visual_dp * args.visual_tp + if len(args.visual_gpu_ids) < total_required_gpus: + raise ValueError( + f"Not enough GPUs specified. You need at least {total_required_gpus}, but got {len(args.visual_gpu_ids)}." 
+ ) + else: + args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus] + + # 检查visual_nccl_port数量是否足够 + if len(args.visual_nccl_ports) < args.visual_dp: + raise ValueError( + f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " + f"but got ({len(args.visual_nccl_ports)})." + ) + else: + args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] + + if not args.splitfuse_mode: + # 普通模式下 + if args.batch_max_tokens is None: + args.batch_max_tokens = args.max_req_total_len + else: + assert args.batch_max_tokens >= args.max_req_total_len, "batch_max_tokens must >= max_req_total_len" + else: + # splitfuse 模式下 + # assert args.batch_max_tokens is not None, "need to set by yourself" + if args.batch_max_tokens is None: + args.batch_max_tokens = min(args.max_req_total_len, 16 * args.splitfuse_block_size) + + assert ( + args.batch_max_tokens > args.splitfuse_block_size + ), "splitfuse_mode, batch_max_tokens must >= splitfuse_block_size" + + # help to manage data stored on Ceph + if "s3://" in args.model_dir: + from lightllm.utils.petrel_helper import s3_model_prepare + + s3_model_prepare(args.model_dir) + + # 如果args.eos_id 是 None, 从 config.json 中读取 eos_token_id 相关的信息,赋值给 args + if args.eos_id is None: + from lightllm.utils.config_utils import get_eos_token_ids + + args.eos_id = get_eos_token_ids(args.model_dir) + + if args.data_type is None: + from lightllm.utils.config_utils import get_dtype + + args.data_type = get_dtype(args.model_dir) + assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] + + already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port] + can_use_ports = alloc_can_use_network_port( + num=6 + args.tp + args.tp + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + ) + router_port, detokenization_port, httpserver_port, visual_port, cache_port, metric_port = can_use_ports[0:6] + model_rpc_ports = can_use_ports[6 : 6 + args.tp] + can_use_ports = can_use_ports[6 + args.tp :] + + visual_model_tp_ports = [] + for _ in range(args.visual_dp): + tp_ports_for_dp = can_use_ports[0 : args.visual_tp] + can_use_ports = can_use_ports[args.visual_tp :] + visual_model_tp_ports.append(tp_ports_for_dp) + + # 申请在 p d 分离模式下,会用的端口 + args.pd_tp_infer_rpyc_ports = can_use_ports[0 : args.tp] + # p d 分离模式下用于标识节点的id + args.pd_node_id = str(uuid.uuid4()) + # p 节点用来建立torch kv 传输分布组的可用端口范围 + args.pd_p_allowed_port_min = 20000 + args.pd_p_allowed_port_max = 30000 + + # p d 分离模式下,decode节点的调度间隙是0 + if args.run_mode == "decode": + args.router_max_wait_tokens = 0 + + logger.info(f"all start args:{args}") + + if args.enable_multimodal: + start_submodule_processes( + start_funcs=[ + start_cache_manager, + ], + start_args=[(cache_port, args)], + ) + start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args, router_port, visual_port, cache_port, visual_model_tp_ports), + ], + ) + + start_submodule_processes( + start_funcs=[ + start_metric_manager, + ], + start_args=[(metric_port, args)], + ) + + g_objs.metric_client = MetricClient(metric_port) + + g_objs.httpserver_manager = HttpServerManager( + args, + router_port=router_port, + cache_port=cache_port, + visual_port=visual_port, + httpserver_port=httpserver_port, + enable_multimodal=args.enable_multimodal, + metric_port=metric_port, + ) + + start_submodule_processes( + start_funcs=[start_router_process, start_detokenization_process], + start_args=[ + (args, router_port, detokenization_port, model_rpc_ports, metric_port), + (args, 
detokenization_port, httpserver_port), + ], + ) + if "s3://" in args.model_dir: + from lightllm.utils.petrel_helper import s3_model_clear + + s3_model_clear(args.model_dir) + + if args.health_monitor: + from lightllm.server.health_monitor.manager import start_health_check_process + + start_submodule_processes(start_funcs=[start_health_check_process], start_args=[(args,)]) + + g_objs.shared_token_load = TokenLoad(f"{str(args.nccl_port)}_shared_token_load", 1) + + g_objs.server.install_signal_handlers() + uvicorn.run( + g_objs.app, + host=args.host, + port=args.port, + log_level="info", + timeout_keep_alive=5, + loop="uvloop", + ) + + +def pd_master_start(g_objs): + from .api_server import G_Objs + + g_objs: G_Objs = g_objs + args = g_objs.args + + if args.run_mode != "pd_master": + return + + if args.use_tgi_api: + g_objs.g_generate_func = tgi_generate_impl + g_objs.g_generate_stream_func = tgi_generate_stream_impl + else: + g_objs.g_generate_func = lightllm_generate + g_objs.g_generate_stream_func = lightllm_generate_stream + + logger.info(f"use tgi api: {args.use_tgi_api}") + logger.info(f"all start args:{args}") + + can_use_ports = alloc_can_use_network_port(num=1, used_nccl_ports=[args.nccl_port, args.port]) + metric_port = can_use_ports[0] + + start_submodule_processes( + start_funcs=[ + start_metric_manager, + ], + start_args=[(metric_port, args)], + ) + + g_objs.metric_client = MetricClient(metric_port) + g_objs.httpserver_manager = HttpServerManagerForPDMaster( + args, + metric_port=metric_port, + ) + + if args.health_monitor: + from lightllm.server.health_monitor.manager import start_health_check_process + + start_submodule_processes(start_funcs=[start_health_check_process], start_args=[(args,)]) + + g_objs.server.install_signal_handlers() + uvicorn.run( + g_objs.app, + host=args.host, + port=args.port, + log_level="info", + timeout_keep_alive=5, + loop="uvloop", + ) diff --git a/lightllm/server/api_tgi.py b/lightllm/server/api_tgi.py index 8bf9ad7ec..23bae4349 100755 --- a/lightllm/server/api_tgi.py +++ b/lightllm/server/api_tgi.py @@ -5,6 +5,7 @@ from fastapi.encoders import jsonable_encoder from .sampling_params import SamplingParams from .multimodal_params import MultimodalParams +from .httpserver.manager import HttpServerManager import json @@ -51,7 +52,7 @@ def format_tgi_params(params): return params -async def tgi_generate_impl(request: Request, g_id_gen, httpserver_manager) -> Response: +async def tgi_generate_impl(request: Request, httpserver_manager: HttpServerManager) -> Response: request_dict = await request.json() prompt = request_dict.pop("inputs") @@ -61,12 +62,8 @@ async def tgi_generate_impl(request: Request, g_id_gen, httpserver_manager) -> R sampling_params.verify() multimodal_params_dict = request_dict.get("multimodal_params", {}) multimodal_params = MultimodalParams(**multimodal_params_dict) - multimodal_params.verify_and_preload() - group_request_id = g_id_gen.generate_id() - results_generator = httpserver_manager.generate( - prompt, sampling_params, group_request_id, multimodal_params, request=request - ) + results_generator = httpserver_manager.generate(prompt, sampling_params, multimodal_params, request=request) # Non-streaming case final_output_dict = collections.defaultdict(list) @@ -77,11 +74,6 @@ async def tgi_generate_impl(request: Request, g_id_gen, httpserver_manager) -> R prompt_token_ids = None is_first_metadata = True async for sub_req_id, request_output, metadata, finish_status in results_generator: - if await request.is_disconnected(): - # 
Abort the request if the client disconnects. - await httpserver_manager.abort(group_request_id) - return Response(status_code=499) - # when set "--return_all_prompt_logprobs", the first token metadata will contains # prompt_logprobs and prompt_token_ids if is_first_metadata: @@ -124,7 +116,7 @@ async def tgi_generate_impl(request: Request, g_id_gen, httpserver_manager) -> R return JSONResponse(content=json_compatible_item_data) -async def tgi_generate_stream_impl(request: Request, g_id_gen, httpserver_manager) -> Response: +async def tgi_generate_stream_impl(request: Request, httpserver_manager: HttpServerManager) -> Response: request_dict = await request.json() prompt = request_dict.pop("inputs") @@ -136,12 +128,8 @@ async def tgi_generate_stream_impl(request: Request, g_id_gen, httpserver_manage raise Exception("stream api only support best_of == 1") multimodal_params_dict = request_dict.get("multimodal_params", {}) multimodal_params = MultimodalParams(**multimodal_params_dict) - multimodal_params.verify_and_preload() - group_request_id = g_id_gen.generate_id() - results_generator = httpserver_manager.generate( - prompt, sampling_params, group_request_id, multimodal_params, request=request - ) + results_generator = httpserver_manager.generate(prompt, sampling_params, multimodal_params, request=request) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: @@ -171,10 +159,5 @@ async def stream_results() -> AsyncGenerator[bytes, None]: yield ("data:" + json.dumps(ret, ensure_ascii=False) + "\n\n").encode("utf-8") - async def abort_request() -> None: - await httpserver_manager.abort(group_request_id) - background_tasks = BackgroundTasks() - # Abort the request if the client disconnects. - background_tasks.add_task(abort_request) return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) diff --git a/lightllm/server/detokenization/decode_mode_fix.py b/lightllm/server/detokenization/decode_mode_fix.py new file mode 100644 index 000000000..5f1265f07 --- /dev/null +++ b/lightllm/server/detokenization/decode_mode_fix.py @@ -0,0 +1,30 @@ +""" +p d 分离模式下, 对于到达的请求,需要将输入的prompt_ids 中的最后一个id,提前处理,然后移入到outputs中 +这是 p d 分离模式下,decode 节点的特殊处理点。 +""" +from ..io_struct import ReqDetokenizationState +from .decode import decode_token + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def decode_mode_fix(req_out: ReqDetokenizationState, tokenizer, eos_id): + new_token_id = req_out.prompt_ids[-1] + req_out.prompt_ids = req_out.prompt_ids[0:-1] + req_out.output_ids.append(new_token_id) + + out_text = decode_token( + tokenizer, + req_out, + new_token_id, + eos_id, + ) + + if out_text.endswith("\ufffd"): + pass + else: + req_out.output_str = out_text + + return req_out diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 1607a1895..a7e09650e 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -9,6 +9,7 @@ from ..req_id_generator import convert_sub_id_to_group_id from typing import Union from .decode import decode_token +from .decode_mode_fix import decode_mode_fix from ..tokenizer import get_tokenizer import traceback @@ -21,6 +22,7 @@ class DeTokenizationManager: def __init__( self, + args, eos_id, model_weightdir, tokenizor_mode, @@ -28,6 +30,7 @@ def __init__( httpserver_port, trust_remote_code, ): + self.args = args context = zmq.asyncio.Context(2) self.recv_from_router = context.socket(zmq.PULL) 
self.recv_from_router.bind(f"tcp://127.0.0.1:{detokenization_port}") @@ -40,6 +43,7 @@ def __init__( self.req_id_to_out = {} self.eos_id = eos_id self._init_get_token_id_to_token_str() + self.is_decode_mode = self.args.run_mode == "decode" def _init_get_token_id_to_token_str(self): self.token_id_to_token = {token_id: token for token, token_id in self.tokenizer.get_vocab().items()} @@ -55,6 +59,9 @@ async def handle_loop(self): recv_obj, (BatchTokenIdOut, ReqDetokenizationState, AbortReq) ), f"type is not right {type(recv_obj)}" if isinstance(recv_obj, ReqDetokenizationState): + if self.is_decode_mode: + recv_obj = decode_mode_fix(recv_obj, self.tokenizer, self.eos_id) + # 将解序列对象复制 best_of 份, 并为其生成请求id for delta_id in range(recv_obj.best_of): recv_obj.request_id = recv_obj.group_req_id + delta_id @@ -129,6 +136,7 @@ def start_detokenization_process(args, detokenization_port, httpserver_port, pip try: router = DeTokenizationManager( + args, args.eos_id, args.model_dir, args.tokenizer_mode, diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 5bbab6ff2..dc1a8f3eb 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -7,15 +7,24 @@ import time import hashlib import datetime +import websockets +import ujson as json asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +from typing import Union, List, Tuple, Dict from ..tokenizer import get_tokenizer from ..io_struct import BatchStrOut, AbortReq, FinishStatus +from ..pd_io_struct import NodeRole from ..embed_cache.utils import get_shm_name_data, create_shm from ..req_id_generator import convert_sub_id_to_group_id from ..sampling_params import SamplingParams +from ..multimodal_params import MultimodalParams +from ..req_id_generator import ReqIDGenerator +from fastapi import Request from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient +from lightllm.utils.statics_utils import MovingAverage +from lightllm.utils.net_utils import get_hostname_ip logger = init_logger(__name__) @@ -47,11 +56,15 @@ def __init__( self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) - self.req_id_to_out_inf = {} # value type (out_str, metadata, finished, event) + self.req_id_to_out_inf: Dict[int, ReqStatus] = {} # value type (out_str, metadata, finished, event) - self.max_req_input_len = args.max_req_input_len self.max_req_total_len = args.max_req_total_len self.metric_client = MetricClient(metric_port) + + self.pd_mode: NodeRole = NodeRole(self.args.run_mode) + assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL] + self.id_gen = ReqIDGenerator() + self.first_time_costs = MovingAverage() return # connect cache server, calculate md5, alloc resource, return uuid @@ -72,101 +85,207 @@ async def _alloc_resource(self, data, num): await asyncio.sleep(wait_time) wait_time = min(wait_time + 2, 9) - async def _alloc_multimodal_resources(self, multimodal_params): - for img in multimodal_params.images: - record = await self._alloc_resource(img.read(), self.tokenizer.get_image_token_length(img)) - img.uuid = record["id"] - img.token_id = record["token_id"] - img.token_num = record["token_num"] - - async def _release_multimodal_resources(self, multimodal_params): - if multimodal_params is not None: + async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams): + # 只有 P 和 NORMAL 节点需要真的管理多模态资源 + if self.pd_mode.is_P_or_NORMAL(): for img in 
multimodal_params.images: - if img.uuid is not None: - self.cache_client.root.release(img.uuid) - # 将 uuid 等 赋值为 None, 防止因为abort等异常情况造成重复释放异常 - img.uuid = None - img.token_id = None - img.token_num = None + record = await self._alloc_resource(img.read(), self.tokenizer.get_image_token_length(img)) + img.uuid = record["id"] + img.token_id = record["token_id"] + img.token_num = record["token_num"] + return + + async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): + # 只有 P 和 NORMAL 节点需要真的管理多模态资源 + if self.pd_mode.is_P_or_NORMAL(): + if multimodal_params is not None: + for img in multimodal_params.images: + if img.uuid is not None: + self.cache_client.root.release(img.uuid) + # 将 uuid 等 赋值为 None, 防止因为abort等异常情况造成重复释放异常 + img.uuid = None + img.token_id = None + img.token_num = None + return def tokens(self, prompt): prompt_ids = self.tokenizer.encode(prompt) return len(prompt_ids) async def generate( - self, prompt, sampling_params: SamplingParams, group_request_id, multimodal_params, request=None - ): - # 记录请求到达的相关信息 - if request is not None: - x_request_id = request.headers.get("X-Request-Id", "") - x_session_id = request.headers.get("X-Session-Id", "") - format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") - logger.info( - f"recieved req X-Request-Id:{x_request_id} " - f"X-Session-Id:{x_session_id} start_time:{format_in_time} " - f"lightllm_req_id:{group_request_id} " + self, + prompt: Union[str, List[int]], + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + ) -> Tuple[int, str, dict, FinishStatus]: + start_time = time.time() + # 请求的 id 可以由外部传入,也可以由内部生成,但是由外部传入的时候,要自己保证全局唯一性 + # 否则会造成异常问题。目前限制 NORMAL 模式都使用内部id替换, P 和 D 模式按需设置 + if self.pd_mode == NodeRole.NORMAL: + group_request_id = self.id_gen.generate_id() + sampling_params.group_request_id = group_request_id + elif self.pd_mode == NodeRole.P or self.pd_mode == NodeRole.D: + assert sampling_params.group_request_id is not None, "p d mode, group_request_id must be setting" + group_request_id = sampling_params.group_request_id + else: + assert False, "dead code path" + + try: + if self.pd_mode.is_P_or_NORMAL(): + multimodal_params.verify_and_preload() + + # 记录请求到达的相关信息 + await self._log_req_header(request, group_request_id) + # 监控 + self.metric_client.counter_inc("lightllm_request_count") + + sampling_params.stop_sentences_to_token_ids(self.tokenizer) + prompt_ids = await self._encode(prompt, multimodal_params) + prompt_tokens = len(prompt_ids) + # 监控 + self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens) + self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens) + verify_time_begin = time.time() + prompt_ids = await self._check_and_repair_length(prompt_ids, sampling_params) + verify_time_end = time.time() + self.metric_client.histogram_observe( + "lightllm_request_validation_duration", verify_time_end - verify_time_begin ) - # 监控 - self.metric_client.counter_inc("lightllm_request_count") + req_status = ReqStatus(group_request_id, multimodal_params) + self.req_id_to_out_inf[group_request_id] = req_status - sampling_params.stop_sentences_to_token_ids(self.tokenizer) + # 将请求转发给其他节点 + await self.transfer_to_next_module( + prompt_ids, sampling_params, multimodal_params, group_request_id, start_time + ) - # 统计信息变量 - start_time = time.time() - out_token_counter = 0 - first_token_cost_ms = sys.float_info.max - is_first_token = True + results_generator = 
self._wait_to_token_package( + start_time, prompt_ids, group_request_id, sampling_params, req_status, request + ) + async for sub_req_id, request_output, metadata, finish_status in results_generator: + yield sub_req_id, request_output, metadata, finish_status - if self.enable_multimodal: - assert len(multimodal_params.images) <= self.args.cache_capacity, "too many images!" - await self._alloc_multimodal_resources(multimodal_params) - prompt_ids = self.tokenizer.encode(prompt, multimodal_params) + except Exception as e: + logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + await self.abort(group_request_id) + raise e + return + + async def _log_req_header(self, request: Request, group_request_id: int): + x_request_id = request.headers.get("X-Request-Id", "") + x_session_id = request.headers.get("X-Session-Id", "") + format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") + logger.info( + f"recieved req X-Request-Id:{x_request_id} " + f"X-Session-Id:{x_session_id} start_time:{format_in_time} " + f"lightllm_req_id:{group_request_id} " + ) + return + + async def _encode(self, prompt: Union[str, List[int]], multimodal_params: MultimodalParams): + if isinstance(prompt, str): + if self.enable_multimodal: + assert len(multimodal_params.images) <= self.args.cache_capacity, "too many images!" + await self._alloc_multimodal_resources(multimodal_params) + prompt_ids = self.tokenizer.encode(prompt, multimodal_params) + else: + prompt_ids = self.tokenizer.encode(prompt) + return prompt_ids + + # 这里的校验对多模态不是很充分, to do + if all(isinstance(e, int) for e in prompt): + if not self.enable_multimodal: + if all(e < self.tokenizer.vocab_size for e in prompt): + return prompt + else: + raise ValueError("prompt List[int] format contain id > vocab_size") + else: + return prompt else: - prompt_ids = self.tokenizer.encode(prompt) + raise ValueError(f"prompt format error, get type{type(prompt)}") + return + + async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params: SamplingParams): prompt_tokens = len(prompt_ids) - # 监控 - self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens) - self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens) - verify_time_begin = time.time() - if prompt_tokens > self.max_req_input_len: + if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len: # use long_truncation_mode to truncate long input len req. 
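            # --- hedged illustration of the two truncation policies handled below (head / center),
            # assuming max_req_total_len=16 and max_new_tokens=4, i.e. at most 12 input ids may
            # survive. This is a standalone toy, not the HttpServerManager method itself. ---
            # def truncate(prompt_ids, max_req_total_len, max_new_tokens, mode):
            #     budget = max_req_total_len - max_new_tokens
            #     if len(prompt_ids) + max_new_tokens <= max_req_total_len:
            #         return prompt_ids
            #     if mode == "head":
            #         return prompt_ids[-budget:]  # keep the tail of the prompt
            #     if mode == "center":
            #         return prompt_ids[: budget // 2] + prompt_ids[-(budget - budget // 2):]
            #     raise ValueError(f"input too long and no truncation mode set: {len(prompt_ids)}")
            #
            # ids = list(range(20))
            # print(truncate(ids, 16, 4, "head"))    # [8, 9, ..., 19]
            # print(truncate(ids, 16, 4, "center"))  # [0..5] + [14..19]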
if self.args.long_truncation_mode is None: - raise ValueError(f"the input prompt token len {prompt_tokens} is too long > {self.max_req_input_len}") + raise ValueError( + f"the input prompt token len {prompt_tokens} + max_new_tokens \ + {sampling_params.max_new_tokens} > {self.max_req_total_len}" + ) elif self.args.long_truncation_mode == "head": - prompt_ids = prompt_ids[-self.max_req_input_len :] - prompt_tokens = len(prompt_ids) + prompt_ids = prompt_ids[-(self.max_req_total_len - sampling_params.max_new_tokens) :] elif self.args.long_truncation_mode == "center": - prompt_ids = ( - prompt_ids[0 : self.max_req_input_len // 2] - + prompt_ids[-(self.max_req_input_len - self.max_req_input_len // 2) :] - ) + req_input_len = self.max_req_total_len - sampling_params.max_new_tokens + prompt_ids = prompt_ids[0 : req_input_len // 2] + prompt_ids[-(req_input_len - req_input_len // 2) :] prompt_tokens = len(prompt_ids) - assert prompt_tokens == self.max_req_input_len + assert prompt_tokens == req_input_len else: assert False, "error args" - req_total_len = prompt_tokens + sampling_params.max_new_tokens + # last repaired + req_total_len = len(prompt_ids) + sampling_params.max_new_tokens if req_total_len > self.max_req_total_len: raise ValueError( f"the req total len (input len + output len) is too long > max_req_total_len:{self.max_req_total_len}" ) - verify_time_end = time.time() - req_status = ReqStatus(group_request_id, multimodal_params) - event = req_status.event - self.req_id_to_out_inf[group_request_id] = req_status + return prompt_ids - if self.enable_multimodal: - self.send_to_visual.send_pyobj( - (prompt_ids, sampling_params, multimodal_params, group_request_id, start_time) - ) - else: + async def transfer_to_next_module( + self, prompt_ids, sampling_params, multimodal_params, group_request_id, start_time + ): + if self.pd_mode == NodeRole.P: + if self.enable_multimodal: + self.send_to_visual.send_pyobj( + (prompt_ids, sampling_params, multimodal_params, group_request_id, start_time) + ) + else: + self.send_to_router.send_pyobj( + (prompt_ids, sampling_params, multimodal_params, group_request_id, start_time) + ) + return + + if self.pd_mode == NodeRole.D: + # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了, 传输一个空的即可 self.send_to_router.send_pyobj( - (prompt_ids, sampling_params, multimodal_params, group_request_id, start_time) + (prompt_ids, sampling_params, MultimodalParams(), group_request_id, start_time) ) + return + + if self.pd_mode == NodeRole.NORMAL: + if self.enable_multimodal: + self.send_to_visual.send_pyobj( + (prompt_ids, sampling_params, multimodal_params, group_request_id, start_time) + ) + else: + self.send_to_router.send_pyobj( + (prompt_ids, sampling_params, multimodal_params, group_request_id, start_time) + ) + return + + assert False, "dead code path" + return + + async def _wait_to_token_package( + self, + start_time, + prompt_ids: List[int], + group_request_id: int, + sampling_params: SamplingParams, + req_status: "ReqStatus", + request: Request, + ): + event = req_status.event unfinished_count = sampling_params.best_of + out_token_counter = 0 + first_token_cost_ms = sys.float_info.max + prompt_tokens = len(prompt_ids) + is_first_token = True while True: try: @@ -174,7 +293,7 @@ async def generate( except asyncio.TimeoutError: pass - if request is not None and await request.is_disconnected(): + if await request.is_disconnected(): await self.abort(group_request_id) raise Exception(f"req_id {group_request_id} disconnected") @@ -184,10 +303,18 @@ async def generate( continue for 
sub_req_id, out_str, metadata, finish_status in req_status.out_token_info_list: + # pd master 节点需要这个做统计信息, 所以放在元数据中返回给 pd master 节点 metadata["prompt_tokens"] = prompt_tokens + # p 节点返回 prompt_ids 信息,防止 d 节点重新 encode + if self.pd_mode == NodeRole.P and is_first_token: + metadata["prompt_ids"] = prompt_ids + + if is_first_token: + first_token_cost_ms = (time.time() - start_time) * 1000 + is_first_token = False + self.first_time_costs.add(first_token_cost_ms) + out_token_counter += 1 - first_token_cost_ms = (time.time() - start_time) * 1000 if is_first_token else first_token_cost_ms - is_first_token = False yield sub_req_id, out_str, metadata, finish_status # 如果有子请求完成,就更新计数 @@ -198,7 +325,7 @@ async def generate( if unfinished_count == 0: try: del self.req_id_to_out_inf[group_request_id] - await self._release_multimodal_resources(multimodal_params) + await self._release_multimodal_resources(req_status.multimodal_params) except: pass total_cost_time_ms = (time.time() - start_time) * 1000 @@ -218,9 +345,6 @@ async def generate( f"prompt_cache_len:{prompt_cache_len} " f"prompt_cache_ratio:{prompt_cache_ratio} " ) - self.metric_client.histogram_observe( - "lightllm_request_validation_duration", verify_time_end - verify_time_begin - ) self.metric_client.histogram_observe( "lightllm_request_inference_duration", total_cost_time_ms / 1000.0 ) @@ -252,6 +376,8 @@ async def abort(self, group_request_id): return async def handle_loop(self): + asyncio.create_task(self.timer_to_pd_master()) + while True: recv_ans: BatchStrOut = await self.recv_from_detokenization.recv_pyobj() assert isinstance(recv_ans, BatchStrOut), f"error recv type {type(recv_ans)}" @@ -270,6 +396,48 @@ async def handle_loop(self): pass return + async def timer_to_pd_master(self): + if self.pd_mode not in [NodeRole.P, NodeRole.D]: + return + + self.host_ip = get_hostname_ip() + if self.host_ip is None: + self.host_ip = self.args.host + + while True: + try: + uri = f"ws://{self.args.pd_master_ip}:{self.args.pd_master_port}/register_and_keep_alive" + async with websockets.connect(uri) as websocket: + args_dict = vars(self.args) + args_dict["host"] = self.host_ip + # 发送注册信息 + regist_json = { + "node_id": self.args.pd_node_id, + "client_ip_port": f"{self.host_ip}:{self.args.port}", + "mode": self.pd_mode.value, + "start_args": args_dict, + } + + await websocket.send(json.dumps(regist_json)) + logger.info(f"Sent registration JSON: {regist_json}") + + log_count = 0 + while True: + heartbeat_message = {"type": "heartbeat"} + await websocket.send(json.dumps(heartbeat_message)) + if log_count % 10 == 0: + logger.info(f"Sent heartbeat: {heartbeat_message}") + log_count += 1 + await asyncio.sleep(3) + if log_count % 5 == 0: + logger.info(f"mean first cost: {self.first_time_costs.average()} ms") + + except Exception as e: + logger.error("connetion to pd_master has error") + logger.exception(str(e)) + await asyncio.sleep(10) + logger.info("reconnection to pd_master") + class ReqStatus: def __init__(self, req_id, multimodal_params) -> None: @@ -277,4 +445,4 @@ def __init__(self, req_id, multimodal_params) -> None: self.multimodal_params = multimodal_params self.lock = asyncio.Lock() self.event = asyncio.Event() - self.out_token_info_list = [] + self.out_token_info_list: List[Tuple[int, str, dict, FinishStatus]] = [] diff --git a/lightllm/server/httpserver_for_pd_master/__init__.py b/lightllm/server/httpserver_for_pd_master/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/server/httpserver_for_pd_master/manager.py 
b/lightllm/server/httpserver_for_pd_master/manager.py new file mode 100644 index 000000000..9203280b6 --- /dev/null +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -0,0 +1,312 @@ +import sys +import zmq +import zmq.asyncio +import asyncio +import uvloop +import rpyc +import time +import hashlib +import datetime +import aiohttp +import ujson as json + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +from typing import Union, List, Tuple, Dict +from ..io_struct import FinishStatus +from ..pd_io_struct import PD_Client_Obj, UpKVStatus +from ..sampling_params import SamplingParams +from ..multimodal_params import MultimodalParams +from ..req_id_generator import ReqIDGenerator +from fastapi import Request +from lightllm.utils.log_utils import init_logger +from lightllm.server.metrics.manager import MetricClient +from lightllm.utils.statics_utils import MovingAverage + +logger = init_logger(__name__) + + +class HttpServerManagerForPDMaster: + def __init__( + self, + args, + metric_port, + ): + self.args = args + self.metric_client = MetricClient(metric_port) + self.id_gen = ReqIDGenerator() + self.prefill_nodes: List[PD_Client_Obj] = [] + self.decode_nodes: List[PD_Client_Obj] = [] + self.url_to_pd_nodes: Dict[str, PD_Client_Obj] = {} + + self.id_to_event: Dict[int, asyncio.Event] = {} + self.session = None + self.first_time_costs = MovingAverage() + self.create_session_costs = MovingAverage() + return + + async def register_pd(self, pd_info_json): + pd_client = PD_Client_Obj(**pd_info_json) + self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client + if pd_client.mode == "prefill": + self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] + self.prefill_nodes.append(pd_client) + elif pd_client.mode == "decode": + self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] + self.decode_nodes.append(pd_client) + else: + assert False + + logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} registed") + return + + async def remove_pd(self, pd_info_json): + pd_client = PD_Client_Obj(**pd_info_json) + try: + del self.url_to_pd_nodes[pd_client.client_ip_port] + except: + pass + self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] + self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] + logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed") + return + + async def update_req_status(self, upkv_status: UpKVStatus): + try: + event = self.id_to_event[upkv_status.group_request_id] + event.set() + del self.id_to_event[upkv_status.group_request_id] + except: + pass + return + + def tokens(self, prompt: str): + # to do + raise NotImplementedError("tokens is not implements") + + async def select_p_d_node( + self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams + ) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + import random + + p_node = random.choice(self.prefill_nodes) + d_node = random.choice(self.decode_nodes) + return p_node, d_node + + async def generate( + self, + prompt: Union[str, List[int]], + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + ) -> Tuple[int, str, dict, FinishStatus]: + start_time = time.time() + group_request_id = self.id_gen.generate_id() + try: + sampling_params.group_request_id = group_request_id + # 记录请求到达的相关信息 + await self._log_req_header(request, 
group_request_id) + # 监控 + self.metric_client.counter_inc("lightllm_request_count") + self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens) + + p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params) + + results_generator = self._wait_to_token_package( + p_node, + d_node, + start_time, + prompt, + sampling_params, + multimodal_params, + request, + ) + async for sub_req_id, request_output, metadata, finish_status in results_generator: + yield sub_req_id, request_output, metadata, finish_status + finally: + await self.remove_req(group_request_id) + return + + async def _log_req_header(self, request: Request, group_request_id: int): + x_request_id = request.headers.get("X-Request-Id", "") + x_session_id = request.headers.get("X-Session-Id", "") + format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") + logger.info( + f"recieved req X-Request-Id:{x_request_id} " + f"X-Session-Id:{x_session_id} start_time:{format_in_time} " + f"lightllm_req_id:{group_request_id} " + ) + return + + async def _to_req_info( + self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams + ): + req = { + "inputs": prompt, + "parameters": sampling_params.to_origin_dict(), + "multimodal_params": multimodal_params.to_origin_dict(), + } + return req + + async def fetch_stream( + self, + p_node: PD_Client_Obj, + d_node: PD_Client_Obj, + prompt: Union[str, List[int]], + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + ): + group_request_id = sampling_params.group_request_id + event = asyncio.Event() + self.id_to_event[group_request_id] = event + # 初始化连接池 + if self.session is None: + self.session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=2000, verify_ssl=False)) + await self.session.__aenter__() + + d_start_args = d_node.start_args + decode_node_dict = { + "node_id": d_start_args["pd_node_id"], + "ip": d_start_args["host"], + "rpyc_port": d_start_args["pd_decode_rpyc_port"], + "max_new_tokens": sampling_params.max_new_tokens - 1, + } + + try: + old_max_new_tokens = sampling_params.max_new_tokens + sampling_params.max_new_tokens = 1 + sampling_params.move_kv_to_decode_node = decode_node_dict if old_max_new_tokens != 1 else None + + req = await self._to_req_info(prompt, sampling_params, multimodal_params) + create_start_time = time.time() + async with self.session.post(p_node.to_llm_url(), json=req) as response: + self.create_session_costs.add((time.time() - create_start_time) * 1000) + if response.status == 200: + async for line in response.content: + line = line.decode("utf-8").strip() + if line.startswith("data:"): + data = line[len("data:") :].strip() + sub_req_id, request_output, metadata, finish_status = json.loads(data) + if old_max_new_tokens != 1: + finish_status = FinishStatus.NO_FINISH + else: + finish_status = FinishStatus(finish_status) + # 得到 p 节点返回的 prompt_ids 信息 + if metadata.get("prompt_ids", None) is not None: + prompt_ids = metadata.get("prompt_ids") + prompt_ids.append(metadata.get("id")) + yield sub_req_id, request_output, metadata, finish_status + else: + logger.error(f"fetch_stream error: {response.status}") + raise Exception(f"group_req_id {group_request_id} connection error: {response}") + + # 如果只需要一个输出 token,prefill 完就直接结束掉吧 + if old_max_new_tokens == 1: + return + + try: + await asyncio.wait_for(event.wait(), timeout=60) + except asyncio.TimeoutError: + logger.warning(f"group_request_id: 
{group_request_id} time out err") + raise Exception("server is busy") + # raise Exception(f"group_request_id: {group_request_id} time out err, maybe kv move get questions") + + sampling_params.move_kv_to_decode_node = None + sampling_params.max_new_tokens = old_max_new_tokens - 1 + req = await self._to_req_info(prompt_ids, sampling_params, multimodal_params) + async with self.session.post(d_node.to_llm_url(), json=req) as response: + if response.status == 200: + async for line in response.content: + line = line.decode("utf-8").strip() + if line.startswith("data:"): + data = line[len("data:") :].strip() + sub_req_id, request_output, metadata, finish_status = json.loads(data) + yield sub_req_id, request_output, metadata, FinishStatus(finish_status) + else: + logger.error(f"fetch_stream error: {response.status}") + raise Exception(f"group_req_id {group_request_id} connection error: {response}") + finally: + await self.remove_req(group_request_id) + return + + async def _wait_to_token_package( + self, + p_node: PD_Client_Obj, + d_node: PD_Client_Obj, + start_time: float, + prompt: str, + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + ): + out_token_counter = 0 + first_token_cost_ms = sys.float_info.max + group_request_id = sampling_params.group_request_id + unfinished_count = sampling_params.best_of + is_first_token = True + + async for sub_req_id, out_str, metadata, finish_status in self.fetch_stream( + p_node, d_node, prompt, sampling_params, multimodal_params + ): + if await request.is_disconnected(): + await self.abort(group_request_id) + raise Exception(f"req_id {group_request_id} disconnected") + prompt_tokens = metadata["prompt_tokens"] + out_token_counter += 1 + if is_first_token: + first_token_cost_ms = (time.time() - start_time) * 1000 + is_first_token = False + self.first_time_costs.add(first_token_cost_ms) + + yield sub_req_id, out_str, metadata, finish_status + if finish_status.is_finished(): + unfinished_count -= 1 + if unfinished_count == 0: + break + + total_cost_time_ms = (time.time() - start_time) * 1000 + mean_per_token_cost_time_ms = (total_cost_time_ms - first_token_cost_ms) / out_token_counter + x_request_id = request.headers.get("X-Request-Id", "") + x_session_id = request.headers.get("X-Session-Id", "") + prompt_cache_len = metadata.pop("prompt_cache_len", 0) + prompt_cache_ratio = prompt_cache_len / prompt_tokens + format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") + logger.info( + f"X-Request-Id:{x_request_id} " + f"X-Session-Id:{x_session_id} start_time:{format_start_time} " + f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms " + f"total_cost_time:{total_cost_time_ms}ms,out_token_counter:{out_token_counter} " + f"mean_per_token_cost_time: {mean_per_token_cost_time_ms}ms " + f"prompt_token_num:{prompt_tokens} " + f"prompt_cache_len:{prompt_cache_len} " + f"prompt_cache_ratio:{prompt_cache_ratio} " + ) + self.metric_client.histogram_observe("lightllm_request_inference_duration", total_cost_time_ms / 1000.0) + self.metric_client.histogram_observe( + "lightllm_request_mean_time_per_token_duration", mean_per_token_cost_time_ms / 1000.0 + ) + self.metric_client.histogram_observe("lightllm_request_first_token_duration", first_token_cost_ms / 1000.0) + self.metric_client.histogram_observe("lightllm_request_generated_tokens", out_token_counter) + self.metric_client.counter_inc("lightllm_request_success") + return + + async def abort(self, group_request_id): + 
logger.warning(f"aborted group_request_id {group_request_id}") + try: + del self.id_to_event[group_request_id] + except: + pass + return + + async def remove_req(self, group_request_id): + try: + del self.id_to_event[group_request_id] + except: + pass + + async def handle_loop(self): + while True: + # 可以做一个定时任务 + await asyncio.sleep(20) + logger.info(f"mean first cost: {self.first_time_costs.average()} ms") + logger.info(f"create_session_costs: {self.create_session_costs.average()} ms") + return diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py index 6f73bb092..9d717455d 100644 --- a/lightllm/server/io_struct.py +++ b/lightllm/server/io_struct.py @@ -1,10 +1,14 @@ import time import asyncio import enum +from dataclasses import dataclass from .sampling_params import SamplingParams from .multimodal_params import MultimodalParams from typing import Dict, List, Optional, Tuple, Union from lightllm.server.req_id_generator import convert_sub_id_to_group_id +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class ReqRunStatus(enum.Enum): diff --git a/lightllm/server/metrics/metrics.py b/lightllm/server/metrics/metrics.py index 42cc5da59..8a2d7bc3e 100644 --- a/lightllm/server/metrics/metrics.py +++ b/lightllm/server/metrics/metrics.py @@ -67,7 +67,7 @@ def init_metrics(self, args): self.create_counter("lightllm_request_failure") self.create_counter("lightllm_batch_inference_count", labelnames=["method"]) - max_req_input_len = args.max_req_input_len + max_req_input_len = args.max_req_total_len input_len_buckets = [max_req_input_len / 100.0 * (i + 1) for i in range(-1, 100)] self.create_histogram("lightllm_request_input_length", input_len_buckets) self.create_histogram("lightllm_cache_length", input_len_buckets) diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index bc78ec853..2c3965eb0 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -26,6 +26,8 @@ def preload(self): try: if self._type == "url": timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) + # 这个地方获取数据有问题,应该修改为异步协程方式获取,否则会阻塞原有的线程 + # to do ret = requests.get(self._data, timeout=timeout) img_data = ret.content elif self._type == "base64": @@ -56,6 +58,15 @@ def to_dict(self): ret["token_num"] = self.token_num return ret + def to_origin_dict(self): + """ + 将内容转换为原始请求的形式,主要用于请求转发 + """ + ret = {} + ret["type"] = self._type + ret["data"] = self._data + return ret + class MultimodalParams: def __init__( @@ -74,3 +85,11 @@ def to_dict(self): ret = {} ret["images"] = [i.to_dict() for i in self.images] return ret + + def to_origin_dict(self): + """ + 将内容转换为原始请求的形式,主要用于请求转发 + """ + ret = {} + ret["images"] = [i.to_origin_dict() for i in self.images] + return ret diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py new file mode 100644 index 000000000..03cb628b7 --- /dev/null +++ b/lightllm/server/pd_io_struct.py @@ -0,0 +1,92 @@ +import enum +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union +from lightllm.server.req_id_generator import convert_sub_id_to_group_id +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +# 节点的行为 +class NodeRole(enum.Enum): + P = "prefill" + D = "decode" + NORMAL = "normal" + PD_MASTER = "pd_master" + + def is_P_or_NORMAL(self): + return (self == NodeRole.P) or (self == NodeRole.NORMAL) + + +@dataclass +class PD_Client_Obj: + node_id: str + client_ip_port: str + mode: str # 只能是 prefill 或者 
decode 节点 + start_args: object # 节点的启动参数信息,用于做匹配性的校验,防止运行过程中出现问题。 + + def __post_init__(self): + if self.mode not in ["prefill", "decode"]: + error_info = f"""mode must in ["prefill", "decode"], but get {self.mode}""" + logger.error(error_info) + raise ValueError(error_info) + return + + def to_llm_url(self): + return f"http://{self.client_ip_port}/pd_generate_stream" + + +@dataclass +class UpKVStatus: + type: str = "kv_move_status" + group_request_id: int = None + + def __post_init__(self): + if self.type != "kv_move_status": + error_info = "type only can be 'kv_move_status'" + logger.error(error_info) + raise ValueError(error_info) + + if not isinstance(self.group_request_id, int): + error_info = "group_request_id only can be int" + logger.error(error_info) + raise ValueError(error_info) + return + + +@dataclass +class DecodeNodeInfo: + node_id: str + ip: str + rpyc_port: str + max_new_tokens: int + + +@dataclass +class KVMoveTask: + group_request_id: int + input_tokens: List[int] # 代表输入的token_id 序列 + prefill_token_indexes: List[int] # 在prefill节点上 mem manager kv buffer中的token index + # 在decode节点上 mem manager kv buffer中的token index, 其代表的是真实占用的额外token,并不与prefill_token_indexes 一样长 + decode_token_indexes: List[int] + move_kv_len: int # 因为 prompt cache 的原因,当prefill节点和decode节点沟通后,传输的kv的数量可能少于 prefill_value 的长度 + prefill_node_id: str + decode_node: DecodeNodeInfo + + def __post_init__(self): + if len(self.input_tokens) <= 0: + error_info = "key must len >= 1" + logger.error(error_info) + raise ValueError(error_info) + + def to_prefill_log_info(self): + v_len = None if self.prefill_token_indexes is None else len(self.prefill_token_indexes) + log = f"id: {self.group_request_id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len}" + return log + + def to_decode_log_info(self): + v_len = None if self.decode_token_indexes is None else len(self.decode_token_indexes) + log = f"id: {self.group_request_id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len}" + return log + + def id(self): + return self.group_request_id diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index deb1d722c..9216e43c7 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -133,7 +133,9 @@ def insert(self, key, value=None): if value is None: value = key - assert len(key) == len(value) and len(key) >= 1 + assert len(key) == len(value) # and len(key) >= 1 + if len(key) == 0: + return 0 return self._insert_helper(self.root_node, key, value) def _insert_helper(self, node: TreeNode, key, value): @@ -302,12 +304,23 @@ def clear_tree_nodes(self): self.refed_tokens_num.arr[0] = 0 return - def dec_node_ref_counter(self, node): + def dec_node_ref_counter(self, node: TreeNode): + if node is None: + return + # 如果减引用的是叶节点,需要先从 evict_tree_set 中移除 + old_node = node + if old_node.is_leaf(): + self.evict_tree_set.discard(old_node) + while node is not None: if node.ref_counter == 1: self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value) node.ref_counter -= 1 node = node.parent + + # 加回。 + if old_node.is_leaf(): + self.evict_tree_set.add(old_node) return def get_refed_tokens_num(self): diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index c7ef6dea1..e53dba811 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -3,11 +3,13 @@ import uuid import uvloop import asyncio +import 
torch import rpyc asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) import zmq import zmq.asyncio +import torch.multiprocessing as mp from typing import Dict, List, Optional from ..sampling_params import SamplingParams from ..io_struct import Req, NormalReq, SplitFuseReq, TokenHealingReq, Batch @@ -26,6 +28,7 @@ from lightllm.server.router.token_load import TokenLoad from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.server.metrics.manager import MetricClient +from lightllm.common.basemodel.infer_lock import g_router_lock logger = init_logger(__name__) @@ -44,7 +47,9 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr self.radix_cache_client = None # 共享变量,用于存储router端调度分析得到的机器负载信息 - self.shared_token_load = TokenLoad(f"{str(args.nccl_port)}_shared_token_load") + self.shared_token_load = TokenLoad(f"{str(args.nccl_port)}_shared_token_load", 1) + self.shared_token_load.set_estimated_peak_token_count(0) + self.shared_token_load.set_frozened_token_count(0) self.shared_token_load.set_current_load(0.0) self.shared_token_load.set_logical_max_load(0.0) self.shared_token_load.set_dynamic_max_load(0.0) @@ -69,13 +74,32 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr self.stats_tool = Stats(not args.disable_log_stats, args.log_stats_interval) self.metric_client = MetricClient(metric_port) + self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode"] + # p d 分离模式下,需要调度锁来同步调度端和推理端的一些数据操作 + # 主要是为了防止调度失误,造成 OOM 等错误 + self.router_lock = mp.Lock() + g_router_lock.obj = self.router_lock return async def wait_to_model_ready(self): # 初始化模型 self.model_rpcs: List[ModelRpcClient] = [] + # 用于 kv move 管理进程 和 推理进程进行tensor信息的交互。 + self.info_queues: List[torch.multiprocessing.Queue] = [ + torch.multiprocessing.Queue() for _ in range(self.world_size) + ] + self.mem_queues: List[torch.multiprocessing.Queue] = [ + torch.multiprocessing.Queue() for _ in range(self.world_size) + ] for rank_id in range(self.world_size): - rpc_model = await start_model_process(port=self.model_rpc_ports[rank_id], world_size=self.world_size) + rpc_model = await start_model_process( + args=self.args, + port=self.model_rpc_ports[rank_id], + world_size=self.world_size, + info_queue=self.info_queues[rank_id], + mem_queue=self.mem_queues[rank_id], + router_lock=self.router_lock, + ) self.model_rpcs.append(rpc_model) init_model_ret = [] @@ -107,6 +131,7 @@ async def wait_to_model_ready(self): "disable_cudagraph": self.args.disable_cudagraph, "mem_fraction": self.args.mem_fraction, "batch_max_tokens": self.args.batch_max_tokens, + "pd_rpyc_port": self.args.pd_tp_infer_rpyc_ports[rank_id], # 非 pd 模式可以不设置 } init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) @@ -120,6 +145,23 @@ async def wait_to_model_ready(self): ) self.req_queue = build_req_queue(self.args, self) logger.info(f"use req queue {self.req_queue.__class__.__name__}") + + if self.args.run_mode == "prefill": + # 启动 prefill kv move 管理进程 + from lightllm.server.router.model_infer.mode_backend.continues_batch.prefill_node_impl import ( + start_prefill_kv_move_manager_process, + ) + + start_prefill_kv_move_manager_process(self.args, self.info_queues, self.mem_queues) + + if self.args.run_mode == "decode": + # 启动 decode kv move 管理进程 + from lightllm.server.router.model_infer.mode_backend.continues_batch.decode_node_impl import ( + start_decode_kv_move_manager_process, + ) + + start_decode_kv_move_manager_process(self.args, self.info_queues, self.mem_queues) + return 
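
# --- a small standalone sketch of the queue wiring pattern used above: the router creates
# one queue pair per tensor-parallel rank and hands them to worker processes, which later
# push objects back (here plain strings; in lightllm, KV-move task info and mem managers).
# Plain multiprocessing is used here instead of torch.multiprocessing for brevity; names
# are illustrative only. ---
# import multiprocessing as mp
#
# def worker(rank: int, info_q: mp.Queue, mem_q: mp.Queue):
#     info_q.put(f"rank {rank}: kv move task info")
#     mem_q.put(f"rank {rank}: mem manager handle")
#
# if __name__ == "__main__":
#     world_size = 2
#     info_queues = [mp.Queue() for _ in range(world_size)]
#     mem_queues = [mp.Queue() for _ in range(world_size)]
#     procs = [mp.Process(target=worker, args=(r, info_queues[r], mem_queues[r]))
#              for r in range(world_size)]
#     for p in procs:
#         p.start()
#     for r in range(world_size):
#         print(info_queues[r].get(), "|", mem_queues[r].get())
#     for p in procs:
#         p.join()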
def add_req( @@ -191,9 +233,7 @@ async def loop_for_fwd( f"token used ratio: {token_ratio1} not contain prompt cache tree unrefed tokens\n" f"token used ratio: {token_ratio2} contain prompt cache tree unrefed tokens" ) - self.shared_token_load.set_current_load(token_ratio1) - self.req_queue.update_token_load(self.running_batch) - pass + self.req_queue.update_token_load(self.running_batch, force_update=False) self.stats_tool.print_stats() self.metric_client.gauge_set("lightllm_batch_current_size", len(self.running_batch.reqs)) self.metric_client.gauge_set("lightllm_batch_pause_size", len(self.req_queue.pause_req_dict)) @@ -203,8 +243,7 @@ async def loop_for_fwd( int(self.shared_token_load.get_dynamic_max_load() * self.max_total_token_num), ) else: - self.shared_token_load.set_dynamic_max_load(0.0) - self.shared_token_load.set_current_load(0.0) + self.req_queue.update_token_load(self.running_batch, force_update=True) if counter_count % 300 == 0: self.metric_client.gauge_set("lightllm_batch_current_size", 0.0) self.metric_client.gauge_set("lightllm_batch_pause_size", 0.0) @@ -405,6 +444,10 @@ def _update_out_status_to_batch(self, batch: Batch, req_to_out_status): return def _can_decode(self, batch: Batch): + # p d 分离模式下,目前只能使用保守调度,保证请求放入进行decode的时候 + # 显存token肯定是够用的 + if self.is_pd_run_mode: + return True return batch.batch_decode_need_tokens + self.get_used_tokens() <= self.max_total_token_num def _send_to_detokenization_proc(self, batch: Batch, req_ans): diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 1dd954fb7..501c235ab 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -40,6 +40,7 @@ def __init__( input_penalty: bool = False, regular_constraint: Optional[str] = None, allowed_token_ids: Optional[List[int]] = None, + move_kv_to_decode_node: Optional[bool] = None, ) -> None: self.best_of = best_of self.do_sample = do_sample @@ -62,6 +63,8 @@ def __init__( self.regex_guide = None self.fsm_current_state: int = 0 self.allowed_token_ids = allowed_token_ids + # p d mode use params + self.move_kv_to_decode_node = move_kv_to_decode_node # this check is not very good to placed here. to do... 
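        # --- toy version of the router's _can_decode guard shown above: in normal mode a batch
        # may only run a decode step when the tokens it needs still fit into the KV pool, while
        # the prefill/decode-separated mode relies on conservative admission (tokens frozen up
        # front) and always allows the step. Names here are illustrative, not the Router API. ---
        # def can_decode(is_pd_run_mode: bool, batch_decode_need_tokens: int,
        #                used_tokens: int, max_total_token_num: int) -> bool:
        #     if is_pd_run_mode:
        #         # pd mode: admission control already froze enough tokens per request
        #         return True
        #     return batch_decode_need_tokens + used_tokens <= max_total_token_num
        #
        # print(can_decode(False, 8, 990, 1000))  # False: would overflow the KV pool
        # print(can_decode(True, 8, 990, 1000))   # True: pd mode trusts admission control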
if self.allowed_token_ids is not None: if not all(e < vocab_size for e in self.allowed_token_ids): @@ -313,7 +316,7 @@ def init_batch( value_tensor = value_tensor.long().cuda() mem_manager.add_refs(value_tensor) # 加 refs req_manager.req_to_token_indexs[r_obj.req_idx, 0:ready_cache_len] = value_tensor - r_obj.cur_kv_len = ready_cache_len + r_obj.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64 # 初始化之后 所有请求状态置换为 RUNNING 状态 r_obj.req_status = ReqRunStatus.RUNNING @@ -420,7 +423,8 @@ def pause_reqs(self, pause_reqs: List[str]): req.req_status = pause_way self.request_ids.remove(request_id) if pause_way == ReqRunStatus.PAUSED_AND_OFFLOAD: - self._free_a_req_mem(free_token_index, req) + # 不支持多输出的情况 + self._free_a_req_mem(free_token_index, req, is_group_finished=True) req.cur_kv_len = 0 if len(free_token_index) != 0: diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index d35c459e6..bcbdd023c 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -7,3 +7,5 @@ from .continues_batch.impl_for_token_healing import TokenHealingBackend from .continues_batch.impl_for_simple_constraint_mode import SimpleConstraintBackend from .continues_batch.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend +from .continues_batch.prefill_node_impl.prefill_impl import ContinuesBatchBackendForPrefillNode +from .continues_batch.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index cefaa15cf..9c9cff922 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -50,6 +50,8 @@ from lightllm.utils.log_utils import init_logger from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams, requests_mapping +from lightllm.server.router.token_load import TokenLoad +from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock class ModeBackend: @@ -62,6 +64,9 @@ def init_model(self, kvargs): world_size = kvargs["world_size"] self.args = kvargs.get("args", None) + # p d 分离模式下会有特殊的一些初始化, 所以需要传递 + # 模式参数到模型的初始化过程中进行控制 + self.run_mode = "normal" if self.args is None else self.args.run_mode self.is_multimodal = False self.tp_rank = kvargs["rank_id"] self.world_size = kvargs["world_size"] @@ -77,6 +82,10 @@ def init_model(self, kvargs): self.logger = init_logger(__name__) self.weight_dir = kvargs["weight_dir"] + nccl_port_str = str(kvargs["nccl_port"]) + self.shared_token_load = TokenLoad(f"{nccl_port_str}_shared_token_load", 1) + # p d 分离模式,decode节点才会使用的参数 + self.pd_rpyc_port = kvargs.get("pd_rpyc_port", None) max_total_token_num = kvargs["max_total_token_num"] dist.init_process_group( @@ -84,6 +93,14 @@ def init_model(self, kvargs): ) torch.cuda.set_device(self.tp_rank) + # 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在 + # init_process_group 之后调用 + g_infer_state_lock.obj = InferStateLock(name=nccl_port_str) + self.infer_state_lock = g_infer_state_lock + # 防止InferStateLock 中的全局共享信息被重复异常初始化,导致同步异常的问题。 + # 所以做一次barrier等待 + dist.barrier() + model_cfg, _ = PretrainedConfig.get_config_dict(self.weight_dir) model_kvargs = { @@ -104,6 +121,7 @@ def init_model(self, 
kvargs): "disable_cudagraph": kvargs.get("disable_cudagraph", False), "mem_fraction": kvargs.get("mem_fraction", 0.9), "batch_max_tokens": kvargs.get("batch_max_tokens", None), + "run_mode": self.run_mode, } is_weight_only_quant = any("w6a16" in mode_ or "w8a16" in mode_ or "w4a16" in mode_ for mode_ in self.mode) @@ -219,11 +237,8 @@ def init_model(self, kvargs): self.is_multimodal = True else: raise Exception(f"can not support {self.model_type} now") - except Exception as e: - self.logger.error(f"load model error: {str(e)} {e} {type(e)}") - import traceback - - traceback.print_exc() + except BaseException as e: + self.logger.exception(str(e)) raise e set_random_seed(2147483647) @@ -256,6 +271,7 @@ def decode_batch(self, batch_id): # @calculate_time(show=True, min_cost_ms=0.1) def add_batch(self, batch_id, reqs): + g_infer_state_lock.acquire() batch_data = InferBatch.init_batch( batch_id, reqs, @@ -266,6 +282,7 @@ def add_batch(self, batch_id, reqs): self.radix_cache, ) self.cache[batch_id] = batch_data + g_infer_state_lock.release() # 将更新后的状态返回给调用方用于router中请求的状态 ans = {} @@ -284,17 +301,21 @@ def add_batch(self, batch_id, reqs): # @calculate_time(show=True, min_cost_ms=0.1) def filter_batch(self, batch_id, req_id_list, finished_req_id_list): + g_infer_state_lock.acquire() batch = self.cache.pop(batch_id) filter_batch = batch.filter(req_id_list, finished_req_id_list) del batch self.cache[batch_id] = filter_batch + g_infer_state_lock.release() return def pause_reqs(self, batch_id, req_list): + g_infer_state_lock.acquire() batch1 = self.cache.pop(batch_id) batch2 = batch1.pause_reqs(req_list) self.cache[batch_id] = batch2 del batch1 + g_infer_state_lock.release() return # @calculate_time(show=True, min_cost_ms=0.1) @@ -309,7 +330,9 @@ def merge_batch(self, batch_id1, batch_id2): # @calculate_time(show=True, min_cost_ms=10) def remove_batch(self, batch_id): + g_infer_state_lock.acquire() batch = self.cache.pop(batch_id) batch.free_self() del batch + g_infer_state_lock.release() return diff --git a/lightllm/server/router/model_infer/mode_backend/beamsearch/pre_process.py b/lightllm/server/router/model_infer/mode_backend/beamsearch/pre_process.py index f9e82538d..53d4b6101 100644 --- a/lightllm/server/router/model_infer/mode_backend/beamsearch/pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/beamsearch/pre_process.py @@ -1,6 +1,12 @@ import torch import numpy as np -from lightllm.server.router.model_infer.infer_batch import requests_mapping, group_mapping, InferReqGroup, InferReq, InferBatch +from lightllm.server.router.model_infer.infer_batch import ( + requests_mapping, + group_mapping, + InferReqGroup, + InferReq, + InferBatch, +) from lightllm.server.io_struct import ReqRunStatus from lightllm.utils.infer_utils import calculate_time from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache @@ -49,11 +55,18 @@ def prepare_prefill_inputs(batch: InferBatch, radix_cache: RadixCache, is_multim nopad_b_start_loc = torch.tensor(nopad_b_start_loc, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda") + + # dynamic prompt cache 准备 token + if radix_cache is not None: + radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + mem_indexes = batch.req_manager.mem_manager.alloc(input_ids.shape[0]) + kwargs = { "batch_size": nopad_b_seq_len.shape[0], "total_token_num": nopad_total_token_num, "max_len_in_batch": 
nopad_max_len_in_batch, "input_ids": input_ids, + "mem_indexes": mem_indexes, "b_req_idx": nopad_b_req_idx, "b_start_loc": nopad_b_start_loc, "b_seq_len": nopad_b_seq_len, @@ -63,10 +76,6 @@ def prepare_prefill_inputs(batch: InferBatch, radix_cache: RadixCache, is_multim if is_multimodal: kwargs["multimodal_params"] = batch_multimodal_params - # dynamic prompt cache 准备 token - if radix_cache is not None: - radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - return kwargs, run_reqs_group @@ -103,18 +112,21 @@ def prepare_decode_inputs(batch: InferBatch, radix_cache: RadixCache): nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") nopad_b_start_loc = torch.tensor(nopad_b_start_loc, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") + + # dynamic prompt cache 准备 token + if radix_cache is not None: + radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + mem_indexes = batch.req_manager.mem_manager.alloc(input_ids.shape[0]) kwargs = { "batch_size": nopad_b_seq_len.shape[0], "total_token_num": nopad_total_token_num, "max_len_in_batch": nopad_max_len_in_batch, "input_ids": input_ids, + "mem_indexes": mem_indexes, "b_req_idx": nopad_b_req_idx, "b_start_loc": nopad_b_start_loc, "b_seq_len": nopad_b_seq_len, "is_prefill": False, } - # dynamic prompt cache 准备 token - if radix_cache is not None: - radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) return kwargs, run_req_groups diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/__init__.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/__init__.py new file mode 100644 index 000000000..4b40544fe --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/__init__.py @@ -0,0 +1,2 @@ +from .decode_kv_move_manager import start_decode_kv_move_manager_process +from .decode_trans_process import start_decode_trans_process diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_impl.py new file mode 100644 index 000000000..cf62dbc78 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_impl.py @@ -0,0 +1,124 @@ +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +import threading +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from typing import List +from lightllm.utils.infer_utils import set_random_seed +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams, requests_mapping +from lightllm.server.io_struct import ReqRunStatus, FinishStatus +from lightllm.server.pd_io_struct import UpKVStatus +from lightllm.utils.log_utils import init_logger +from ..pre_process import prepare_prefill_inputs, prepare_decode_inputs +from ..post_process import sample +from .up_status import UpStatusManager +from rpyc.utils.server import ThreadedServer +from lightllm.common.basemodel.infer_lock import g_infer_state_lock, g_router_lock +from .decode_task_cache import g_success_kv_move_task_cache + +logger = init_logger(__name__) + + +class ContinuesBatchBackendForDecodeNode(ModeBackend): + def __init__(self, info_queue: mp.Queue, 
mem_queue: mp.Queue) -> None: + super().__init__() + self.info_queue: mp.Queue = info_queue + self.mem_queue: mp.Queue = mem_queue + + def init_custom(self): + self.lock_nccl_group = dist.new_group(backend="gloo") + from .decode_infer_rpyc import PDDecodeInferRpcServer + + t = ThreadedServer(PDDecodeInferRpcServer(self), port=self.pd_rpyc_port, protocol_config={"allow_pickle": True}) + threading.Thread(target=lambda: t.start(), daemon=True).start() + return + + @calculate_time(show=False, min_cost_ms=300) + def prefill_batch(self, batch_id): + """ + 检查请求的 kv len 将可能有问题的请求立即结束掉 + """ + output_dict = {} + batch: InferBatch = self.cache.pop(batch_id) + + g_infer_state_lock.acquire() + remove_count = 0 + estimated_peak_token_count = 0 + for request_id in batch.request_ids: + if request_id in g_success_kv_move_task_cache: + task, share_node, _ = g_success_kv_move_task_cache.pop(request_id) + self.radix_cache.dec_node_ref_counter(share_node) + req_all_len = len(task.input_tokens) + task.decode_node.max_new_tokens + remove_count += req_all_len + estimated_peak_token_count += req_all_len + else: + # 对于不合法的请求,直接模拟将其finished掉 + req_obj: InferReq = requests_mapping[request_id] + req_obj.finish_status = FinishStatus.FINISHED_STOP + metadata = { + "id": 0, + "logprob": 0.0, + } + output_dict[req_obj.r_id] = ( + req_obj.req_status, + req_obj.cur_kv_len, + req_obj.get_output_len(), + [(0, metadata)], + req_obj.finish_status.value, # 转化为整数,避免传送大对象, + None, + ) + logger.error( + f"req_id: {req_obj.group_req_id} forced to finished, it not in g_success_kv_move_task_cache" + ) + + if self.tp_rank == 0: + with g_router_lock.obj: + self.shared_token_load.add_frozened_token_count(-remove_count) + self.shared_token_load.add_estimated_peak_token_count(estimated_peak_token_count) + g_infer_state_lock.release() + + self.cache[batch.batch_id] = batch + return output_dict + + @calculate_time(show=True, min_cost_ms=200) + def decode_batch(self, batch_id): + return self.forward(batch_id, is_prefill=False) + + def forward(self, batch_id, is_prefill): + # special code for return all prompt_logprobs + output_dict = {} + batch: InferBatch = self.cache.pop(batch_id) + if is_prefill: + kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.is_multimodal) + else: + kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache) + + logits = self.model.forward(**kwargs) + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + + for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): + # prefill and decode is same + req_obj: InferReq = req_obj + req_obj.cur_kv_len = len(req_obj.input_token_ids) + req_obj.input_token_ids.append(next_token_id) + req_obj.out_token_id_count[next_token_id] += 1 + req_obj.update_finish_status(self.eos_id) + + metadata = { + "id": int(next_token_id), + "logprob": float(next_token_logprob), + } + output_dict[req_obj.r_id] = ( + req_obj.req_status, + req_obj.cur_kv_len, + req_obj.get_output_len(), + [(int(next_token_id), metadata)], + req_obj.finish_status.value, # 转化为整数,避免传送大对象, + None, + ) # 请求状态, 当前占用的kv的长度, 当前输出token的数量, 输出的token的id和元信息列表, 是否推理结束的状态, 额外保留参数 + + self.cache[batch.batch_id] = batch + return output_dict diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_infer_rpyc.py 
b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_infer_rpyc.py new file mode 100644 index 000000000..0a54828da --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_infer_rpyc.py @@ -0,0 +1,154 @@ +import torch +import torch.distributed as dist +import rpyc +import time +from typing import Dict, List, Tuple, Optional, Union +from rpyc.utils.classic import obtain +from .decode_impl import ContinuesBatchBackendForDecodeNode +from lightllm.common.basemodel.infer_lock import acquire_lock_until_ready, release_acquired_lock, g_router_lock +from .decode_task_cache import g_kv_move_task_cache, g_success_kv_move_task_cache +from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class PDDecodeInferRpcServer(rpyc.Service): + def __init__(self, backend: ContinuesBatchBackendForDecodeNode) -> None: + super().__init__() + self.backend = backend + self.rank_id = self.backend.tp_rank + return + + def on_connect(self, conn): + torch.cuda.set_device(f"cuda:{self.rank_id}") + return + + def judge_token_is_ok(self, key_len, max_new_token): + if self.rank_id == 0: + with g_router_lock.obj: + shared_token_load = self.backend.shared_token_load + peak_num = shared_token_load.get_estimated_peak_token_count() + peak_num += shared_token_load.get_frozened_token_count() + peak_num += key_len + max_new_token + + if peak_num < self.backend.get_max_total_token_num(): + object_list = [True] + shared_token_load.add_frozened_token_count(key_len + max_new_token) + else: + object_list = [False] + dist.broadcast_object_list(object_list, src=0, group=self.backend.lock_nccl_group) + else: + object_list = [None] + dist.broadcast_object_list(object_list, src=0, group=self.backend.lock_nccl_group) + return object_list[0] + + def recover_frozen_token(self, key_len, max_new_token): + if self.rank_id == 0: + with g_router_lock.obj: + shared_token_load = self.backend.shared_token_load + shared_token_load.add_frozened_token_count(-(key_len + max_new_token)) + return + + # 返回 None 代表服务繁忙已经无法调度新的请求进入了 + def exposed_alloc_to_frozen_some_tokens(self, move_task: KVMoveTask) -> Optional[List[int]]: + logger.info("exposed_alloc_to_frozen_some_tokens start") + move_task = obtain(move_task) + acquire_lock_until_ready(self.backend.lock_nccl_group) + try: + is_ok = self.judge_token_is_ok(len(move_task.input_tokens), move_task.decode_node.max_new_tokens) + if not is_ok: + return None + + key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu") + tree_node, kv_len, fused_token_indexes = self.backend.radix_cache.match_prefix(key, update_refs=True) + # 如果没匹配到,说明长度是0, 将fused_token_indexes做一下转换 + fused_token_indexes = [] if fused_token_indexes is None else fused_token_indexes.tolist() + need_len = len(move_task.input_tokens) - kv_len + if need_len == 0: + alloc_token_indexes = [] + else: + self.backend.radix_cache.free_radix_cache_to_get_enough_token(need_len) + alloc_token_indexes = self.backend.model.mem_manager.alloc(need_len) + if alloc_token_indexes is not None: + alloc_token_indexes = alloc_token_indexes.detach().cpu().tolist() + + if alloc_token_indexes is None: + self.backend.radix_cache.dec_node_ref_counter(tree_node) + self.recover_frozen_token(len(move_task.input_tokens), move_task.decode_node.max_new_tokens) + return None + + move_task.decode_token_indexes = alloc_token_indexes + move_task.move_kv_len = need_len + + 
g_kv_move_task_cache[move_task.group_request_id] = (move_task, tree_node, fused_token_indexes) + return move_task.decode_token_indexes + except BaseException as e: + logger.exception(str(e)) + return -1 + finally: + release_acquired_lock() + logger.info("exposed_alloc_to_frozen_some_tokens end") + + def exposed_put_kv_received_to_radix_cache(self, group_req_id: int): + group_req_id = obtain(group_req_id) + acquire_lock_until_ready(self.backend.lock_nccl_group) + move_task, tree_node, fused_token_indexes = g_kv_move_task_cache.pop(group_req_id) + radix_cache = self.backend.radix_cache + key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu") + value = torch.tensor(fused_token_indexes + move_task.decode_token_indexes, dtype=torch.int64, device="cpu") + prefix_len = radix_cache.insert(key, value) + assert len(fused_token_indexes) <= prefix_len + self.backend.model.mem_manager.free(value[len(fused_token_indexes) : prefix_len].cuda()) + self.backend.radix_cache.dec_node_ref_counter(tree_node) + + # 申请一段key,把 radix cache 锁住,防止极端情况下被刷掉, decode 端通过减两次引用计数来修正。 + tree_node, kv_len, _ = self.backend.radix_cache.match_prefix(key, update_refs=True) + assert len(key) == kv_len + g_success_kv_move_task_cache[group_req_id] = (move_task, tree_node, time.time()) + release_acquired_lock() + return + + def exposed_fail_to_realese_forzen_tokens(self, group_req_id: int): + group_req_id = obtain(group_req_id) + acquire_lock_until_ready(self.backend.lock_nccl_group) + move_task, tree_node, fused_token_indexes = g_kv_move_task_cache.pop(group_req_id) + value = torch.tensor(move_task.decode_token_indexes, dtype=torch.int64, device="cpu") + self.backend.model.mem_manager.free(value.cuda()) + self.backend.radix_cache.dec_node_ref_counter(tree_node) + self.recover_frozen_token(len(move_task.input_tokens), move_task.decode_node.max_new_tokens) + release_acquired_lock() + return + + def exposed_put_mem_manager_to_mem_queue(self): + self.backend.mem_queue.put(self.backend.model.mem_manager) + logger.info("put mem manager to info_queues ok") + return + + def exposed_unfrozen_time_out_reqs_tokens(self): + acquire_lock_until_ready(self.backend.lock_nccl_group) + if self.rank_id == 0: + need_release_reqs = [] + for req_id, (task, tree_node, time_mark) in g_success_kv_move_task_cache.items(): + # 4s 这个请求都没有被调度使用,就会主动被删除掉锁定,释放其锁定的token + if time.time() - time_mark > 4: + need_release_reqs.append(req_id) + logger.info(f"kv time out reqs: {need_release_reqs}") + dist.broadcast_object_list([need_release_reqs], src=0, group=self.backend.lock_nccl_group) + else: + receive_objs = [None] + dist.broadcast_object_list(receive_objs, src=0, group=self.backend.lock_nccl_group) + need_release_reqs = receive_objs[0] + + remove_tokens = 0 + for req_id in need_release_reqs: + task, tree_node, _ = g_success_kv_move_task_cache.pop(req_id) + self.backend.radix_cache.dec_node_ref_counter(tree_node) + remove_tokens += len(task.input_tokens) + task.decode_node.max_new_tokens + + if self.rank_id == 0 and remove_tokens != 0: + with g_router_lock.obj: + self.backend.shared_token_load.add_frozened_token_count(-remove_tokens) + + release_acquired_lock() + return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_kv_move_manager.py new file mode 100644 index 000000000..640110d46 --- /dev/null +++ 
b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_kv_move_manager.py @@ -0,0 +1,258 @@ +import rpyc +import asyncio +import sys +import os +import signal +import time +import psutil +import threading +from rpyc.utils.classic import obtain +from dataclasses import dataclass +from typing import List, Dict, Optional +from rpyc import ThreadedServer +from lightllm.utils.log_utils import init_logger +from .decode_infer_rpyc import PDDecodeInferRpcServer +from lightllm.common.mem_manager import MemoryManager +import torch.multiprocessing as mp +from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus +from lightllm.utils.retry_utils import retry +import numpy as np +import queue +from rpyc import AsyncResult +from ..prefill_node_impl.prefill_kv_move_manager import DecodeBusyError + +logger = init_logger(__name__) + +thread_local_data = threading.local() + + +@dataclass +class TransProcessObj: + prefill_node_id: str = None + process: mp.Process = None + task_in_queue: mp.Queue = None + task_out_queue: mp.Queue = None + nccl_ip: str = None + nccl_port: str = None + device_index: int = None + + def create(self, prefill_node_id: str, nccl_ip: str, nccl_port: int, manager: "DecodeKVMoveManager"): + from .decode_trans_process import start_decode_trans_process + + task_in_queue = mp.Queue() + task_out_queue = mp.Queue() + device_index = manager.get_next_device_index() + proc = start_decode_trans_process( + manager.args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, manager.mem_queues + ) + assert task_out_queue.get(timeout=30) == "proc_start" + with manager.infer_rpyc_lock: + for obj in manager.infer_rpyc_objs: + obj.put_mem_manager_to_mem_queue() + assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" + assert task_out_queue.get(timeout=60) == "nccl_ok" + + self.prefill_node_id = prefill_node_id + self.process = proc + self.task_in_queue = task_in_queue + self.task_out_queue = task_out_queue + self.nccl_ip = nccl_ip + self.nccl_port = nccl_port + self.device_index = device_index + return + + def check_trans_process(self): + process = psutil.Process(self.process.pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + raise Exception(f"trans process: {self.process.pid} is dead") + return + + def __del__(self): + # 强制关闭连接和杀掉传输进程 + if self.process is not None: + logger.warning(f"trans kv process {self.process.pid} is killed") + os.kill(self.process.pid, signal.SIGKILL) + pass + + +class DecodeKVMoveManager(rpyc.Service): + def __init__(self, args, info_queues: List[mp.Queue], mem_queues: List[mp.Queue]): + super().__init__() + self.args = args + self.info_queues = info_queues + self.mem_queues = mem_queues + self.infer_rpyc_lock = threading.Lock() + self.infer_rpyc_objs: List[PDDecodeInferRpcServer] = [] + self.node_id_to_trans_obj: Dict[str, TransProcessObj] = {} + for port in self.args.pd_tp_infer_rpyc_ports: + con = retry(max_attempts=20, wait_time=2)(rpyc.connect)("localhost", port, config={"allow_pickle": True}) + self.infer_rpyc_objs.append(con.root) + logger.info(f"rpyc connect to port: {port} ok") + + from .up_status import start_up_kv_status_process + + self.up_status_in_queue = mp.Queue() + self.up_status_out_queue = mp.Queue() + start_up_kv_status_process(self.args, self.up_status_in_queue, self.up_status_out_queue) + + # 开启tp个线程和队列来处理,每个队列处理一张卡上的任务 + self.task_queues = [queue.Queue() for _ in range(self.args.tp)] + for i in range(self.args.tp): + threading.Thread(target=self.handle_loop, 
args=(self.task_queues[i],), daemon=True).start()
+        threading.Thread(target=self.timer_loop, daemon=True).start()
+        return
+
+    async def wait_all_future_finish(self, futures: List[AsyncResult]):
+        await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures])
+        return
+
+    def on_connect(self, conn):
+        thread_local_data.prefill_node_id = None
+        pass
+
+    def on_disconnect(self, conn):
+        if thread_local_data.prefill_node_id is not None:
+            self.node_id_to_trans_obj.pop(thread_local_data.prefill_node_id, None)
+            logger.info(f"prefill node id {thread_local_data.prefill_node_id} disconnect")
+        pass
+
+    def exposed_check_alive(self):
+        # Used by the prefill node to check that the rpyc connection is still alive.
+        return
+
+    def exposed_build_trans_process(self, prefill_node_id, nccl_ip, nccl_port):
+        prefill_node_id, nccl_ip, nccl_port = list(map(obtain, [prefill_node_id, nccl_ip, nccl_port]))
+        thread_local_data.prefill_node_id = prefill_node_id
+
+        logger.info(f"build trans infos {prefill_node_id} {nccl_ip} {nccl_port}")
+        if prefill_node_id in self.node_id_to_trans_obj:
+            self.node_id_to_trans_obj.pop(prefill_node_id, None)
+        tran_obj = TransProcessObj()
+        tran_obj.create(prefill_node_id, nccl_ip, nccl_port, self)
+        self.node_id_to_trans_obj[prefill_node_id] = tran_obj
+        return
+
+    # Returning None means the decode node is busy and the kv transfer for this task is given up.
+    def exposed_request_data_transfer(self, task: KVMoveTask) -> Optional[int]:
+        task = obtain(task)
+        logger.info(f"exposed_request_data_transfer in {task.to_decode_log_info()}")
+        try:
+            trans_obj = self.get_trans_obj(task)
+            assert trans_obj is not None
+            device_index = trans_obj.device_index
+            value_list = []
+            with self.infer_rpyc_lock:
+                futures: List[AsyncResult] = []
+                for conn in self.infer_rpyc_objs:
+                    futures.append(rpyc.async_(conn.alloc_to_frozen_some_tokens)(task))
+                asyncio.run(self.wait_all_future_finish(futures))
+                value_list = [obtain(future.value) for future in futures]
+
+            # The service is too busy to allocate kv resources, so this request has to be rejected.
+            if value_list[0] is None:
+                raise DecodeBusyError("token is full, busy")
+
+            task.decode_token_indexes = value_list[0]
+            task.move_kv_len = len(value_list[0])
+        except DecodeBusyError as e:
+            logger.error(str(e))
+            return None
+
+        except BaseException as e:
+            # Remove the communication object for this prefill node.
+            self.node_id_to_trans_obj.pop(task.prefill_node_id, None)
+            trans_obj = None
+            logger.exception(str(e))
+            raise e
+
+        self.task_queues[device_index].put(task)
+        return task.move_kv_len
+
+    def get_next_device_index(self):
+        counts = [0 for _ in range(self.args.tp)]
+        for obj in self.node_id_to_trans_obj.values():
+            counts[obj.device_index] += 1
+        device_index = int(np.argmin(counts))
+        return device_index
+
+    def get_trans_obj(self, task: KVMoveTask):
+        return self.node_id_to_trans_obj[task.prefill_node_id]
+
+    def handle_loop(self, task_queue: queue.Queue):
+        try:
+            while True:
+                task = task_queue.get()
+                if not isinstance(task, KVMoveTask):
+                    logger.error("receive task type is not KVMoveTask")
+                    sys.exit(-1)
+
+                logger.info(f"decode node get task {task.to_decode_log_info()}")
+                try:
+                    trans_obj = self.get_trans_obj(task)
+                    trans_obj.task_in_queue.put(task, timeout=10)
+                    assert trans_obj.task_out_queue.get(timeout=30) == "ok"
+                    logger.info(f"decode node transfer kv ok {task.to_decode_log_info()}")
+                    # On success, commit the received kv into the radix cache of every inference process.
+                    with self.infer_rpyc_lock:
+                        futures: List[AsyncResult] = []
+                        for conn in self.infer_rpyc_objs:
+                            futures.append(rpyc.async_(conn.put_kv_received_to_radix_cache)(task.group_request_id))
+                        asyncio.run(self.wait_all_future_finish(futures))
+
+                    logger.info(f"decode node put kv to radix cache ok, req_id: {task.id()}")
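+                    # Notify the pd_master, through the up_status process, that the kv for this group
+                    # request is now resident on the decode node (an UpKVStatus message over the websocket).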
+                    self.up_status_in_queue.put(UpKVStatus(group_request_id=task.group_request_id))
+                    logger.info("decode node up kv status finished")
+                except BaseException as e:
+                    logger.exception(str(e))
+                    # On failure the frozen tokens still have to be released.
+                    with self.infer_rpyc_lock:
+                        futures: List[AsyncResult] = []
+                        for conn in self.infer_rpyc_objs:
+                            futures.append(rpyc.async_(conn.fail_to_realese_forzen_tokens)(task.group_request_id))
+                        asyncio.run(self.wait_all_future_finish(futures))
+                    logger.error(f"decode kv move task {task.to_decode_log_info()} has error, remove the trans_obj")
+                    self.node_id_to_trans_obj.pop(task.prefill_node_id, None)
+                finally:
+                    # Drop the reference, otherwise the trans process cannot exit by itself.
+                    trans_obj = None
+        except BaseException as e:
+            logger.exception(str(e))
+            raise e
+
+    def timer_loop(self):
+        while True:
+            with self.infer_rpyc_lock:
+                futures: List[AsyncResult] = []
+                for conn in self.infer_rpyc_objs:
+                    futures.append(rpyc.async_(conn.unfrozen_time_out_reqs_tokens)())
+                asyncio.run(self.wait_all_future_finish(futures))
+            time.sleep(3.5)
+
+
+def _init_env(args, info_queues: List[mp.Queue], mem_queues: List[mp.Queue], event: mp.Event):
+    # Register the graceful-shutdown handler.
+    from lightllm.utils.graceful_utils import graceful_registry
+    import inspect
+
+    graceful_registry(inspect.currentframe().f_code.co_name)
+
+    manager = DecodeKVMoveManager(args, info_queues, mem_queues)
+    t = ThreadedServer(manager, port=args.pd_decode_rpyc_port, protocol_config={"allow_pickle": True})
+    threading.Thread(target=lambda: t.start(), daemon=True).start()
+
+    event.set()
+
+    # Enter the main loop.
+    while True:
+        time.sleep(10)
+    return
+
+
+def start_decode_kv_move_manager_process(args, info_queues: List[mp.Queue], mem_queues: List[mp.Queue]):
+    event = mp.Event()
+    proc = mp.Process(target=_init_env, args=(args, info_queues, mem_queues, event))
+    proc.start()
+    event.wait()
+    assert proc.is_alive()
+    logger.info("decode kv move manager process started")
+    return
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_task_cache.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_task_cache.py
new file mode 100644
index 000000000..48df4b86f
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_task_cache.py
@@ -0,0 +1,10 @@
+# This module declares global variables that let the inference processes cache the kv move task data
+# exchanged with other processes. It reduces serialization overhead: some calls only need to pass a
+# request id instead of the full payload, which speeds up rpyc calls. Used only by decode_impl.py and decode_infer_rpyc.py.
+from typing import Dict, List, Tuple
+from lightllm.server.pd_io_struct import KVMoveTask
+from lightllm.server.router.dynamic_prompt.radix_cache import TreeNode
+
+g_kv_move_task_cache: Dict[int, Tuple[KVMoveTask, TreeNode, List[int]]] = {}
+
+g_success_kv_move_task_cache: Dict[int, Tuple[KVMoveTask, TreeNode, float]] = {}  # the third float is a timestamp used to decide expiration.
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_trans_process.py
new file mode 100644
index 000000000..989d9c217
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_trans_process.py
@@ -0,0 +1,98 @@
+import torch
+import time
+import sys
+from typing import List, Dict
+from lightllm.utils.log_utils import init_logger
+from lightllm.common.mem_manager import MemoryManager
+import torch.multiprocessing as mp
+from lightllm.server.pd_io_struct import KVMoveTask
+
+logger = init_logger(__name__)
+
+
+def
_init_env( + args, + device_index: int, + nccl_ip, + nccl_port, + task_in_queue: mp.Queue, + task_out_queue: mp.Queue, + mem_queues: List[mp.Queue], +): + import os + + # os.environ["NCCL_DEBUG"] = "INFO" + os.environ["NCCL_MAX_NCHANNELS"] = "2" + os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1" + os.environ["NCCL_SOCKET_NTHREADS"] = "1" + torch.backends.cudnn.enabled = False + + try: + # 注册graceful 退出的处理 + from lightllm.utils.graceful_utils import graceful_registry + import inspect + + graceful_registry(inspect.currentframe().f_code.co_name) + task_out_queue.put("proc_start") + mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] + assert len(mem_managers) == args.tp + task_out_queue.put("get_mem_managers_ok") + import torch.distributed as dist + from datetime import timedelta + + dist.init_process_group( + "nccl", init_method=f"tcp://{nccl_ip}:{nccl_port}", rank=1, world_size=2, timeout=timedelta(seconds=60) + ) + task_out_queue.put("nccl_ok") + while True: + move_task: KVMoveTask = task_in_queue.get() + try: + start = time.time() + if move_task.move_kv_len != 0: + cur_mem = mem_managers[device_index] + recive_buffer = cur_mem.get_layer_buffer_by_token_num(move_task.move_kv_len) + logger.info(f"trans start: {move_task.to_decode_log_info()}") + for i, mem in enumerate(mem_managers): + for layer_index in range(mem.layer_num): + dist.recv(recive_buffer, src=0) + if i == device_index: + mem.write_to_layer_buffer(move_task.decode_token_indexes, recive_buffer, layer_index) + else: + move_size = recive_buffer.numel() + new_recive_buffer = mem.kv_move_buffer.view(-1)[0:move_size].view(recive_buffer.shape) + from torch.cuda import comm + + comm.broadcast(recive_buffer, out=[new_recive_buffer]) + mem.write_to_layer_buffer( + move_task.decode_token_indexes, new_recive_buffer, layer_index + ) + logger.info(f"trans finished: {move_task.to_decode_log_info()}") + torch.cuda.synchronize() + logger.info(f"trans cost time: {(time.time() - start)}, {move_task.to_decode_log_info()}") + task_out_queue.put("ok") + except BaseException as e: + logger.exception(str(e)) + task_out_queue.put("fail") + raise e + except BaseException as e: + logger.exception(str(e)) + sys.exit(-1) + return + + +def start_decode_trans_process( + args, + device_index: int, + nccl_ip, + nccl_port, + task_in_queue: mp.Queue, + task_out_queue: mp.Queue, + mem_queues: List[mp.Queue], +): + proc = mp.Process( + target=_init_env, args=(args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, mem_queues) + ) + proc.start() + assert proc.is_alive() + logger.info(f"decode trans kv process start, nccl_ip: {nccl_ip}, nccl_port: {nccl_port}") + return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/up_status.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/up_status.py new file mode 100644 index 000000000..69c08a84a --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/up_status.py @@ -0,0 +1,67 @@ +import time +import json +import asyncio +import threading +import websockets +from typing import List +from dataclasses import asdict +from lightllm.server.pd_io_struct import UpKVStatus +from lightllm.utils.log_utils import init_logger +import torch.multiprocessing as mp + +logger = init_logger(__name__) + + +class UpStatusManager: + def __init__(self, args, task_in_queue: mp.Queue, task_out_queue: mp.Queue): + self.args = args + self.task_queue: mp.Queue[UpKVStatus] = 
task_in_queue + self.task_out_queue = task_out_queue + self.daemon_thread = threading.Thread(target=self.thread_loop, daemon=True) + self.daemon_thread.start() + + def thread_loop(self): + asyncio.run(self.loop()) + + async def loop(self): + while True: + try: + uri = f"ws://{self.args.pd_master_ip}:{self.args.pd_master_port}/kv_move_status" + async with websockets.connect(uri) as websocket: + while True: + try: + loop = asyncio.get_event_loop() + upkv_status: UpKVStatus = await loop.run_in_executor(None, self.task_queue.get) + await websocket.send(json.dumps(asdict(upkv_status))) + logger.info(f"up status: {upkv_status}") + # self.task_out_queue.put("ok") + except BaseException as e: + logger.error(str(e)) + # self.task_out_queue.put("fail") + raise e + + except Exception as e: + logger.error(f"connetion to pd_master has error: {str(e)}") + logger.exception(str(e)) + await asyncio.sleep(10) + logger.info("reconnection to pd_master") + + +def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue): + from lightllm.utils.graceful_utils import graceful_registry + import inspect + + graceful_registry(inspect.currentframe().f_code.co_name) + up_kv_manager = UpStatusManager(args, task_in_queue, task_out_queue) + logger.info(f"up kv manager {str(up_kv_manager)} start ok") + while True: + time.sleep(10) + return + + +def start_up_kv_status_process(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue): + proc = mp.Process(target=_init_env, args=(args, task_in_queue, task_out_queue)) + proc.start() + assert proc.is_alive() + logger.info("up_kv_status_process start") + return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pre_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pre_process.py index 398042dcb..2976e80be 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pre_process.py @@ -5,6 +5,7 @@ from lightllm.utils.infer_utils import calculate_time from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.mem_manager import MemoryManager +from lightllm.common.basemodel.infer_lock import g_infer_state_lock # @calculate_time(show=True, min_cost_ms=1) def prepare_prefill_inputs(batch: InferBatch, radix_cache: RadixCache, is_multimodal=False): @@ -46,11 +47,20 @@ def prepare_prefill_inputs(batch: InferBatch, radix_cache: RadixCache, is_multim nopad_b_start_loc = torch.tensor(nopad_b_start_loc, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda") + + # dynamic prompt cache 准备 token + g_infer_state_lock.acquire() + if radix_cache is not None: + radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + mem_indexes = batch.req_manager.mem_manager.alloc(input_ids.shape[0]) + g_infer_state_lock.release() + kwargs = { "batch_size": len(batch), "total_token_num": nopad_total_token_num, "max_len_in_batch": nopad_max_len_in_batch, "input_ids": input_ids, + "mem_indexes": mem_indexes, "b_req_idx": nopad_b_req_idx, "b_start_loc": nopad_b_start_loc, "b_seq_len": nopad_b_seq_len, @@ -60,10 +70,6 @@ def prepare_prefill_inputs(batch: InferBatch, radix_cache: RadixCache, is_multim if is_multimodal: kwargs["multimodal_params"] = batch_multimodal_params - # dynamic prompt cache 准备 token - if radix_cache is not None: - 
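+        # Evict unreferenced radix-cache nodes until there are enough free kv slots for this batch;
+        # the allocation below happens under g_infer_state_lock so the pd kv-move path cannot race with it.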
radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - return kwargs, run_reqs @@ -96,18 +102,23 @@ def prepare_decode_inputs(batch: InferBatch, radix_cache: RadixCache): nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") nopad_b_start_loc = torch.tensor(nopad_b_start_loc, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") + + # dynamic prompt cache 准备 token + g_infer_state_lock.acquire() + if radix_cache is not None: + radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + mem_indexes = batch.req_manager.mem_manager.alloc(input_ids.shape[0]) + g_infer_state_lock.release() + kwargs = { "batch_size": len(batch), "total_token_num": nopad_total_token_num, "max_len_in_batch": nopad_max_len_in_batch, "input_ids": input_ids, + "mem_indexes": mem_indexes, "b_req_idx": nopad_b_req_idx, "b_start_loc": nopad_b_start_loc, "b_seq_len": nopad_b_seq_len, "is_prefill": False, } - # dynamic prompt cache 准备 token - if radix_cache is not None: - radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - return kwargs, run_reqs diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/__init__.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/__init__.py new file mode 100644 index 000000000..4100e14ed --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/__init__.py @@ -0,0 +1,2 @@ +from .prefill_trans_process import start_prefill_trans_process +from .prefill_kv_move_manager import start_prefill_kv_move_manager_process diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_impl.py new file mode 100644 index 000000000..9d9153869 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_impl.py @@ -0,0 +1,130 @@ +import threading +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +from typing import List +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.utils.infer_utils import set_random_seed +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams, requests_mapping +from lightllm.server.io_struct import ReqRunStatus, FinishStatus +from lightllm.server.pd_io_struct import KVMoveTask, DecodeNodeInfo +from lightllm.utils.log_utils import init_logger +from ..pre_process import prepare_prefill_inputs, prepare_decode_inputs +from ..post_process import sample +from lightllm.common.basemodel.infer_lock import g_router_lock, g_infer_state_lock +from rpyc.utils.server import ThreadedServer +from .prefill_task_cache import g_kv_move_task_cache + +logger = init_logger(__name__) + + +class ContinuesBatchBackendForPrefillNode(ModeBackend): + def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: + super().__init__() + self.info_queue: mp.Queue = info_queue + self.mem_queue: mp.Queue = mem_queue + + def init_custom(self): + self.lock_nccl_group = dist.new_group(backend="gloo") + from .prefill_infer_rpyc import PDPrefillInferRpcServer + + t = ThreadedServer( + PDPrefillInferRpcServer(self), port=self.pd_rpyc_port, protocol_config={"allow_pickle": True} 
+ ) + threading.Thread(target=lambda: t.start(), daemon=True).start() + return + + @calculate_time(show=False, min_cost_ms=300) + def prefill_batch(self, batch_id): + ans = self.forward(batch_id, is_prefill=True) + return ans + + @calculate_time(show=True, min_cost_ms=200) + def decode_batch(self, batch_id): + return self.forward(batch_id, is_prefill=False) + + def forward(self, batch_id, is_prefill): + # special code for return all prompt_logprobs + output_dict = {} + batch: InferBatch = self.cache.pop(batch_id) + if is_prefill: + kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.is_multimodal) + else: + kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache) + + logits = self.model.forward(**kwargs) + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + + for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): + # prefill and decode is same + req_obj: InferReq = req_obj + req_obj.cur_kv_len = len(req_obj.input_token_ids) + req_obj.input_token_ids.append(next_token_id) + req_obj.out_token_id_count[next_token_id] += 1 + req_obj.update_finish_status(self.eos_id) + + metadata = { + "id": int(next_token_id), + "logprob": float(next_token_logprob), + } + + output_dict[req_obj.r_id] = ( + req_obj.req_status, + req_obj.cur_kv_len, + req_obj.get_output_len(), + [(int(next_token_id), metadata)], + req_obj.finish_status.value, # 转化为整数,避免传送大对象, + None, + ) # 请求状态, 当前占用的kv的长度, 当前输出token的数量, 输出的token的id和元信息列表, 是否推理结束的状态, 额外保留参数 + + if is_prefill: + self.prefill_req_handle_and_frozen_tokens(run_reqs) + + self.cache[batch.batch_id] = batch + return output_dict + + def prefill_req_handle_and_frozen_tokens(self, run_reqs: List[InferReq]): + # 提前在radix cache中回收相关的信息,并添加引用信息 + logger.info("prefill_req_handle_and_frozen_tokens") + g_infer_state_lock.acquire() + try: + for req in run_reqs: + key = torch.tensor(req.input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() + prefix_len = self.radix_cache.insert(key, value) + self.model.mem_manager.free(self.model.req_manager.req_to_token_indexs[req.req_idx][:prefix_len]) + if req.shared_kv_node is not None: + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + req.shared_kv_node = None + req.cur_kv_len = 0 + if req.sampling_param.move_kv_to_decode_node is not None: + if self.tp_rank == 0: + g_router_lock.acquire() + self.shared_token_load.add_frozened_token_count(len(key)) + g_router_lock.release() + + share_node, kv_len, value = self.radix_cache.match_prefix(key, update_refs=True) + assert len(key) == len(value) + # 将下面的请求放入到任务队列中, 注意要使用raidx cache 返回的value + decode_node_info = DecodeNodeInfo(**req.sampling_param.move_kv_to_decode_node) + task = KVMoveTask( + group_request_id=req.group_req_id, + input_tokens=key.tolist(), + prefill_token_indexes=value.tolist(), + decode_token_indexes=None, + prefill_node_id=self.args.pd_node_id, + decode_node=decode_node_info, + move_kv_len=None, + ) + g_kv_move_task_cache[task.group_request_id] = (task, share_node) + # 只有 0 进程发送真正的数据到队列中。 + if self.tp_rank == 0: + self.info_queue.put(task) + except BaseException as e: + logger.exception(str(e)) + g_infer_state_lock.release() + logger.info("prefill_req_handle_and_frozen_tokens end") + return diff --git 
a/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_infer_rpyc.py new file mode 100644 index 000000000..0e3b31672 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_infer_rpyc.py @@ -0,0 +1,45 @@ +import torch +import torch.distributed as dist +import rpyc +from typing import Dict, List, Tuple +from rpyc.utils.classic import obtain +from .prefill_impl import ContinuesBatchBackendForPrefillNode +from lightllm.common.basemodel.infer_lock import g_router_lock, acquire_lock_until_ready, release_acquired_lock +from .prefill_task_cache import g_kv_move_task_cache +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class PDPrefillInferRpcServer(rpyc.Service): + def __init__(self, backend: ContinuesBatchBackendForPrefillNode) -> None: + super().__init__() + self.backend = backend + return + + def on_connect(self, conn): + self.rank_id = dist.get_rank() + torch.cuda.set_device(f"cuda:{self.rank_id}") + return + + # pd 分离模式会使用的一些接口,用于做一些全局信息管理 + def exposed_remove_req_refs_from_prompt_cache(self, group_req_id: int): + group_req_id = obtain(group_req_id) + acquire_lock_until_ready(self.backend.lock_nccl_group) + task, share_node = g_kv_move_task_cache.pop(group_req_id) + if share_node is not None: + self.backend.radix_cache.dec_node_ref_counter(share_node) + logger.info(f"unfrozen tokens for req id: {group_req_id}") + + # 更新元数据 + if self.rank_id == 0: + with g_router_lock.obj: + self.backend.shared_token_load.add_frozened_token_count(-len(task.input_tokens)) + + release_acquired_lock() + return + + def exposed_put_mem_manager_to_mem_queue(self): + self.backend.mem_queue.put(self.backend.model.mem_manager) + logger.info("put mem manager to mem_queue ok") + return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_kv_move_manager.py new file mode 100644 index 000000000..138dff759 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_kv_move_manager.py @@ -0,0 +1,217 @@ +import asyncio +import rpyc +import sys +import os +import gc +import signal +import copy +import numpy as np +import psutil +from dataclasses import dataclass +from typing import List, Dict +from lightllm.utils.log_utils import init_logger +from .prefill_infer_rpyc import PDPrefillInferRpcServer +from lightllm.common.mem_manager import MemoryManager +import torch.multiprocessing as mp +from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.utils.net_utils import find_available_port +from lightllm.utils.retry_utils import retry +from rpyc.utils.classic import obtain +from rpyc import AsyncResult +from lightllm.utils.net_utils import get_hostname_ip + +logger = init_logger(__name__) + + +class DecodeBusyError(Exception): + def __init__(self, *args: object) -> None: + super().__init__(*args) + + +@dataclass +class TransProcessObj: + decode_node_id: str = None + rpyc_conn: object = None # rpyc_con 的连接对象 + process: mp.Process = None + task_in_queue: mp.Queue = None + task_out_queue: mp.Queue = None + nccl_ip: str = None + nccl_port: str = None + device_index: str = None # 使用的gpu序号 + + def create( + self, decode_node_id: str, decode_node_ip: str, decode_node_rpyc_port: int, 
manager: "PrefillKVMoveManager" + ): + con = rpyc.connect( + host=decode_node_ip, port=decode_node_rpyc_port, config={"allow_pickle": True}, keepalive=True + ) + nccl_ip = manager.host_ip + nccl_port = find_available_port(manager.args.pd_p_allowed_port_min, manager.args.pd_p_allowed_port_max) + if nccl_port is None: + raise Exception("no pd nccl port can be used") + + from .prefill_trans_process import start_prefill_trans_process + + device_index = manager.get_next_device_index() # 分配 trans 进程使用的显卡 + task_in_queue = mp.Queue() + task_out_queue = mp.Queue() + proc = start_prefill_trans_process( + manager.args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, manager.mem_queues + ) + assert task_out_queue.get(timeout=30) == "proc_start" + for obj in manager.infer_rpyc_objs: + obj.put_mem_manager_to_mem_queue() + assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" + prefill_node_id = manager.args.pd_node_id + con.root.build_trans_process(prefill_node_id, nccl_ip, nccl_port) # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 + assert task_out_queue.get(timeout=60) == "nccl_ok" + + self.decode_node_id = decode_node_id + self.rpyc_conn = con + self.process = proc + self.task_in_queue = task_in_queue + self.task_out_queue = task_out_queue + self.nccl_port = nccl_port + self.nccl_ip = nccl_ip + self.device_index = device_index + return + + def check_trans_process(self): + process = psutil.Process(self.process.pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + raise Exception(f"trans process: {self.process.pid} is dead") + return + + def __del__(self): + # 强制关闭连接和杀掉传输进程 + if self.process is not None: + logger.warning(f"prefill trans process {self.process.pid} is killed") + os.kill(self.process.pid, signal.SIGKILL) + pass + + +class PrefillKVMoveManager: + def __init__(self, args, info_queues: List[mp.Queue], mem_queues: List[mp.Queue]): + self.args = args + self.info_queues = info_queues + self.mem_queues = mem_queues + self.infer_rpyc_objs: List[PDPrefillInferRpcServer] = [] + self.node_id_to_trans_obj: Dict[str, TransProcessObj] = {} + for port in self.args.pd_tp_infer_rpyc_ports: + con = retry(max_attempts=20, wait_time=2)(rpyc.connect)("localhost", port, config={"allow_pickle": True}) + self.infer_rpyc_objs.append(con.root) + logger.info(f"rpyc connect to infer rpyc port: {port} ok") + self.host_ip = get_hostname_ip() + if self.host_ip is None: + self.host_ip = args.host + return + + def get_next_device_index(self): + counts = [0 for _ in range(self.args.tp)] + for obj in self.node_id_to_trans_obj.values(): + counts[obj.device_index] += 1 + device_index = int(np.argmin(counts)) + return device_index + + def get_trans_obj(self, task: KVMoveTask): + if task.decode_node.node_id not in self.node_id_to_trans_obj: + # 先遍历删除老的不能用的连接 + self.remove_dead_trans_obj() + trans_obj = TransProcessObj() + trans_obj.create(task.decode_node.node_id, task.decode_node.ip, task.decode_node.rpyc_port, self) + self.node_id_to_trans_obj[task.decode_node.node_id] = trans_obj + return self.node_id_to_trans_obj[task.decode_node.node_id] + + def remove_dead_trans_obj(self): + del_node_ids = [] + for node_id, t_obj in self.node_id_to_trans_obj.items(): + try: + t_obj.rpyc_conn.root.check_alive() + except BaseException as e: + logger.error(f"check error {str(e)}") + del_node_ids.append(node_id) + for node_id in del_node_ids: + self.node_id_to_trans_obj.pop(node_id) + + if len(del_node_ids) != 0: + gc.collect() + return + + def handle_loop(self): + try: + while True: + 
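+                # Only tp rank 0 of the inference processes publishes KVMoveTask objects, so a single
+                # queue is drained here; each task is then negotiated with the decode node and handed
+                # to the trans process for the actual nccl send.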
move_task = self.info_queues[0].get() + if not isinstance(move_task, KVMoveTask): + logger.error("receive type is not KVMoveTask") + sys.exit(-1) + + logger.info(f"prefill node get task {move_task.to_prefill_log_info()}") + try: + trans_obj = self.get_trans_obj(move_task) + # 申请传输 + trans_move_task = copy.copy(move_task) + # 不需要发送prefill节点的token index信息给decode节点 + trans_move_task.prefill_token_indexes = None + # 申请发送,并收到发送长度 move_kv_len. + move_kv_len = obtain(trans_obj.rpyc_conn.root.request_data_transfer(trans_move_task)) + # 代表对方已经很繁忙了,放弃这次发送,改为用 + if move_kv_len is None: + raise DecodeBusyError(f"decode_node_id {trans_obj.decode_node_id} is busy") + + move_task.move_kv_len = move_kv_len + logger.info(f"prefill node request_data_transfer ok, {move_task.to_prefill_log_info()}") + # 开始传输直到完成 + trans_obj.task_in_queue.put(move_task, timeout=10) + assert trans_obj.task_out_queue.get(timeout=30) == "ok" + logger.info(f"prefill node transfer data ok, req_id: {move_task.id()}") + + except DecodeBusyError as e: + logger.error(str(e)) + + except BaseException as e: + logger.exception(str(e)) + logger.error(f"kv move task {move_task.to_prefill_log_info()} has error, remove the trans_obj") + self.node_id_to_trans_obj.pop(move_task.decode_node.node_id, None) + + finally: + # 去引用否则进程无法杀掉 + trans_obj = None + # 解除对prefill token的占用状态。 + futures: List[AsyncResult] = [] + for infer_rpyc in self.infer_rpyc_objs: + futures.append( + rpyc.async_(infer_rpyc.remove_req_refs_from_prompt_cache)(move_task.group_request_id) + ) + asyncio.run(self.wait_all_future_finish(futures)) + + except (BaseException, RuntimeError) as e: + logger.exception(str(e)) + raise e + + async def wait_all_future_finish(self, futures: List[AsyncResult]): + await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures]) + return + + +def _init_env(args, info_queues: List[mp.Queue], mem_queues: List[mp.Queue], event: mp.Event): + # 注册graceful 退出的处理 + from lightllm.utils.graceful_utils import graceful_registry + import inspect + + graceful_registry(inspect.currentframe().f_code.co_name) + + manager = PrefillKVMoveManager(args, info_queues, mem_queues) + event.set() + # 进入主循环 + manager.handle_loop() + return + + +def start_prefill_kv_move_manager_process(args, info_queues: List[mp.Queue], mem_queues: List[mp.Queue]): + event = mp.Event() + proc = mp.Process(target=_init_env, args=(args, info_queues, mem_queues, event)) + proc.start() + event.wait() + assert proc.is_alive() + logger.info("prefill kv move manager process started") + return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_task_cache.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_task_cache.py new file mode 100644 index 000000000..afa8e87f4 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_task_cache.py @@ -0,0 +1,8 @@ +# 这个里面声明了一个全局变量,主要用于推理进程缓存发送给其他进程的Kv move 任务的缓存数据 +# 为了减少一些调用时候的序列化开销。有些调用就只需要传输一个请求id就可以了,不用传输特别的 +# 数据了,提升rpyc 调用的速度, 只用在 prefill_impl.py 和 prefill_infer_rpyc.py 文件中 +from typing import Dict, Tuple +from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.server.router.dynamic_prompt.radix_cache import TreeNode + +g_kv_move_task_cache: Dict[int, Tuple[KVMoveTask, TreeNode]] = {} diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_trans_process.py 
b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_trans_process.py new file mode 100644 index 000000000..ee83966ed --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/prefill_node_impl/prefill_trans_process.py @@ -0,0 +1,101 @@ +import torch +import time +import sys +from typing import List, Dict +from lightllm.utils.log_utils import init_logger +from lightllm.common.mem_manager import MemoryManager +import torch.multiprocessing as mp +from lightllm.server.pd_io_struct import KVMoveTask + +logger = init_logger(__name__) + + +# device_index 是用来指示,当前传输进程使用的用于数据传输的显卡id +# 当模型是多卡推理的时候,需要传输的 kv 需要先移动到 device_index +# 指定的显卡上,然后再进行传输,因为torch nccl 限制了只能操作一张显卡上的数据 + + +def _init_env( + args, + device_index: int, + nccl_ip, + nccl_port, + task_in_queue: mp.Queue, + task_out_queue: mp.Queue, + mem_queues: List[mp.Queue], +): + import os + + # os.environ["NCCL_DEBUG"] = "INFO" + os.environ["NCCL_MAX_NCHANNELS"] = "2" + os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1" + os.environ["NCCL_SOCKET_NTHREADS"] = "1" + torch.backends.cudnn.enabled = False + + try: + # 注册graceful 退出的处理 + from lightllm.utils.graceful_utils import graceful_registry + import inspect + + graceful_registry(inspect.currentframe().f_code.co_name) + task_out_queue.put("proc_start") + mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] + assert len(mem_managers) == args.tp + task_out_queue.put("get_mem_managers_ok") + import torch.distributed as dist + from datetime import timedelta + + dist.init_process_group( + "nccl", init_method=f"tcp://{nccl_ip}:{nccl_port}", rank=0, world_size=2, timeout=timedelta(seconds=60) + ) + task_out_queue.put("nccl_ok") + while True: + move_task: KVMoveTask = task_in_queue.get() + try: + start = time.time() + if move_task.move_kv_len != 0: + logger.info(f"trans start: {move_task.to_prefill_log_info()}") + token_indexes = move_task.prefill_token_indexes[-move_task.move_kv_len :] + cur_mem = mem_managers[device_index] + for i, mem in enumerate(mem_managers): + for layer_index in range(mem.layer_num): + move_buffer = mem.read_from_layer_buffer(token_indexes, layer_index) + if i == device_index: + dist.send(move_buffer, dst=1) + else: + move_size = move_buffer.numel() + new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape) + from torch.cuda import comm + + comm.broadcast(move_buffer, out=[new_move_buffer]) + dist.send(new_move_buffer, dst=1) + logger.info(f"trans finished: {move_task.to_prefill_log_info()}") + torch.cuda.synchronize() + logger.info(f"trans cost time: {(time.time() - start)}, {move_task.to_prefill_log_info()}") + task_out_queue.put("ok") + except BaseException as e: + logger.exception(str(e)) + task_out_queue.put("fail") + raise e + except BaseException as e: + logger.exception(str(e)) + sys.exit(-1) + return + + +def start_prefill_trans_process( + args, + device_index: int, + nccl_ip, + nccl_port, + task_in_queue: mp.Queue, + task_out_queue: mp.Queue, + mem_queues: List[mp.Queue], +): + proc = mp.Process( + target=_init_env, args=(args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, mem_queues) + ) + proc.start() + assert proc.is_alive() + logger.info(f"trans kv process start, nccl_ip: {nccl_ip}, nccl_port: {nccl_port}") + return proc diff --git a/lightllm/server/router/model_infer/mode_backend/splitfuse/pre_process.py b/lightllm/server/router/model_infer/mode_backend/splitfuse/pre_process.py index c809be293..7a2a98f48 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/splitfuse/pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/splitfuse/pre_process.py @@ -57,8 +57,14 @@ def splitfuse_prepare_decode_inputs(batch: InferBatch, splitfuse_block_size, rad input_ids.extend(req.input_token_ids[seq_len - split_len : seq_len]) input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") + + if radix_cache is not None: + radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + mem_indexes = batch.req_manager.mem_manager.alloc(input_ids.shape[0]) + kwargs = { "input_ids": input_ids, + "mem_indexes": mem_indexes, "decode_req_num": decode_req_num, "decode_total_token_num": decode_total_token_num, "decode_b_req_idx": torch.tensor(decode_b_req_idx, dtype=torch.int32, device="cuda"), @@ -75,8 +81,4 @@ def splitfuse_prepare_decode_inputs(batch: InferBatch, splitfuse_block_size, rad "prefill_b_seq_len": torch.tensor(prefill_b_seq_len, dtype=torch.int32, device="cuda"), } - # dynamic prompt cache 准备 token - if radix_cache is not None: - radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - return kwargs, decode_reqs, prefill_reqs diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 2f6cc55c2..63420e7af 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -1,5 +1,7 @@ import asyncio import rpyc +import torch +import torch.multiprocessing as mp from datetime import timedelta from typing import Dict, List, Tuple from rpyc.utils.classic import obtain @@ -13,6 +15,8 @@ TokenHealingBackend, SimpleConstraintBackend, FirstTokenConstraintBackend, + ContinuesBatchBackendForPrefillNode, + ContinuesBatchBackendForDecodeNode, ) from lightllm.utils.log_utils import init_logger @@ -20,6 +24,13 @@ class ModelRpcServer(rpyc.Service): + def __init__(self, args, info_queue: mp.Queue, mem_queue: mp.Queue): + super().__init__() + self.args = args + self.info_queue = info_queue + self.mem_queue = mem_queue + return + def exposed_init_model(self, kvargs): self.world_size = kvargs["world_size"] if self.world_size != 1: @@ -35,11 +46,18 @@ def exposed_init_model(self, kvargs): is_first_token_constraint_mode = kvargs.get("is_first_token_constraint_mode", False) if kvargs.get("args", None) is not None: is_simple_constraint_mode = kvargs.get("args", None).simple_constraint_mode + is_prefill_node = kvargs.get("args", None).run_mode == "prefill" + is_decode_node = kvargs.get("args", None).run_mode == "decode" else: is_simple_constraint_mode = False + is_prefill_node = False + is_decode_node = False # use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False) - - if use_reward_model: + if is_prefill_node: + self.backend = ContinuesBatchBackendForPrefillNode(self.info_queue, self.mem_queue) + elif is_decode_node: + self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue) + elif use_reward_model: self.backend = RewardModelBackend() elif is_splitfuse_mode: self.backend = SplitFuseBackend() @@ -225,28 +243,31 @@ async def get_max_total_token_num(self): return ans -def _init_env(port): +def _init_env(args, port, info_queue, mem_queue, router_lock): # 注册graceful 退出的处理 from lightllm.utils.graceful_utils import graceful_registry import inspect graceful_registry(inspect.currentframe().f_code.co_name) + # 将调度锁注册到全局的共享变量中 + from lightllm.common.basemodel.infer_lock import g_router_lock + + g_router_lock.obj = router_lock + from 
rpyc.utils.server import ThreadedServer - t = ThreadedServer(ModelRpcServer(), port=port, protocol_config={"allow_pickle": True}) + t = ThreadedServer(ModelRpcServer(args, info_queue, mem_queue), port=port, protocol_config={"allow_pickle": True}) t.start() return -async def start_model_process(port, world_size): +async def start_model_process(args, port, world_size, info_queue: mp.Queue, mem_queue: mp.Queue, router_lock: mp.Queue): # 单卡时不使用 rpc if world_size == 1: - return ModelRpcClient(ModelRpcServer(), world_size) - - import multiprocessing + return ModelRpcClient(ModelRpcServer(args, info_queue, mem_queue), world_size) - proc = multiprocessing.Process(target=_init_env, args=(port,)) + proc = mp.Process(target=_init_env, args=(args, port, info_queue, mem_queue, router_lock)) proc.start() await asyncio.sleep(2) repeat_count = 0 diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index dd2e1e3a8..d25cd4a76 100644 --- a/lightllm/server/router/req_queue/__init__.py +++ b/lightllm/server/router/req_queue/__init__.py @@ -1,9 +1,12 @@ from .continues_batch.impl import ContinuesBatchQueue from .continues_batch.beam_impl import BeamContinuesBatchQueue from .splitfuse.impl import SplitFuseQueue +from .continues_batch.pd_decode_impl import ContinuesBatchQueueForPDDecode def build_req_queue(args, router): + if args.run_mode == "decode": + return ContinuesBatchQueueForPDDecode(args, router) if args.splitfuse_mode: return SplitFuseQueue(args, router) if args.beam_mode: diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py index c84ac4f04..5aca1d131 100644 --- a/lightllm/server/router/req_queue/base_queue.py +++ b/lightllm/server/router/req_queue/base_queue.py @@ -5,6 +5,7 @@ from lightllm.utils.infer_utils import calculate_time from lightllm.server.io_struct import Batch, Req from lightllm.server.io_struct import ReqRunStatus, FinishStatus +from lightllm.common.basemodel.infer_lock import g_router_lock class BaseQueue: @@ -43,7 +44,9 @@ def is_busy(self): # 计算当前所有的token使用量, 如果使用了dynamic prompt cache, 使用的token量中不包含,cache tree 中未被引用的数据。 cur_all_used_tokens = self.router.get_used_tokens() # 判断当前服务是否处于token使用率过高的状态,过高的情况下,调度要偏向保守 - cur_token_ratio = cur_all_used_tokens / self.max_total_tokens + cur_token_ratio = ( + cur_all_used_tokens + self.router.shared_token_load.get_frozened_token_count() + ) / self.max_total_tokens is_busy = cur_token_ratio >= self.router_token_ratio return is_busy @@ -51,9 +54,20 @@ def generate_new_batch(self, current_batch: Batch): raise NotImplementedError() def calcu_batch_token_load(self, current_batch: Batch): + if current_batch is None: + return 0, self.router.shared_token_load.get_frozened_token_count() / self.max_total_tokens + else: + return self._calcu_batch_token_load_batch_not_none(current_batch) + + def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): raise NotImplementedError() - def update_token_load(self, current_batch: Batch): - if self.router.shared_token_load.need_update_dynamic_max_load(): - self.router.shared_token_load.set_dynamic_max_load(self.calcu_batch_token_load(current_batch)) + def update_token_load(self, current_batch: Batch, force_update=False): + if self.router.shared_token_load.need_update_dynamic_max_load() or force_update: + estimated_peak_token_count, dynamic_max_load = self.calcu_batch_token_load(current_batch) + token_ratio1 = self.router.get_used_tokens() / self.router.max_total_token_num + with g_router_lock.obj: + 
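+                # Publish the measured load, the estimated peak token count and the dynamic max load
+                # into the shared arrays under the router lock, so other processes read a consistent view.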
self.router.shared_token_load.set_current_load(token_ratio1) + self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count) + self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load) return diff --git a/lightllm/server/router/req_queue/continues_batch/beam_impl.py b/lightllm/server/router/req_queue/continues_batch/beam_impl.py index 8e28cf9f6..6df4bcec8 100644 --- a/lightllm/server/router/req_queue/continues_batch/beam_impl.py +++ b/lightllm/server/router/req_queue/continues_batch/beam_impl.py @@ -50,7 +50,9 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new new_batch_first_router_need_tokens += req.cur_output_len new_batch_first_router_need_tokens += req.input_len - ok_token_num = need_max_token_num < self.max_total_tokens + ok_token_num = ( + need_max_token_num + self.router.shared_token_load.get_frozened_token_count() < self.max_total_tokens + ) if req.req_status != ReqRunStatus.PAUSED_AND_OFFLOAD: ok_req_num = len(self.cache_len_list) + len(self.pause_req_dict) <= self.running_max_req_size @@ -64,7 +66,10 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new ok_prefill = new_batch_first_router_need_tokens <= self.batch_max_tokens if ok_token_num and ok_req_num and ok_prefill: - self.router.shared_token_load.set_dynamic_max_load(need_max_token_num / self.max_total_tokens) + self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num) + self.router.shared_token_load.set_dynamic_max_load( + (need_max_token_num + self.router.shared_token_load.get_frozened_token_count()) / self.max_total_tokens + ) return True, new_batch_first_router_need_tokens else: return False, new_batch_first_router_need_tokens @@ -136,9 +141,7 @@ def _add_to_group(self, cur_group_reqs, req: Req): else: return False - def calcu_batch_token_load(self, current_batch: Batch): - if current_batch is None: - return 0.0 + def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): is_busy = self.is_busy() self._init_cache_list(current_batch, is_busy) self.cache_len_list.sort(key=lambda x: -x[1][1]) @@ -155,4 +158,7 @@ def calcu_batch_token_load(self, current_batch: Batch): assert cur_input_len - req.input_len >= 0 cumsum_len += cur_input_len - req.input_len # 减去共享的部分 need_max_token_num = max(need_max_token_num, cumsum_len + index * cur_ouput_len) - return need_max_token_num / self.max_total_tokens + return ( + need_max_token_num, + (need_max_token_num + self.router.shared_token_load.get_frozened_token_count()) / self.max_total_tokens, + ) diff --git a/lightllm/server/router/req_queue/continues_batch/impl.py b/lightllm/server/router/req_queue/continues_batch/impl.py index b32e78bf4..8101f15dd 100644 --- a/lightllm/server/router/req_queue/continues_batch/impl.py +++ b/lightllm/server/router/req_queue/continues_batch/impl.py @@ -6,6 +6,7 @@ from lightllm.server.io_struct import Batch, Req from lightllm.server.io_struct import ReqRunStatus from lightllm.server.router.req_queue.base_queue import BaseQueue +from lightllm.common.basemodel.infer_lock import g_router_lock class ContinuesBatchQueue(BaseQueue): @@ -33,23 +34,31 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens size_array = np.arange(1, len(self.cache_len_list) + 1, 1) need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - ok_token_num = need_max_token_num < self.max_total_tokens - if req.req_status != ReqRunStatus.PAUSED_AND_OFFLOAD: - ok_req_num = len(self.cache_len_list) + 
len(self.pause_req_dict) <= self.running_max_req_size - else: - # 因为存在重复的项 - ok_req_num = len(self.cache_len_list) + len(self.pause_req_dict) - 1 <= self.running_max_req_size + with g_router_lock.obj: + ok_token_num = ( + need_max_token_num + self.router.shared_token_load.get_frozened_token_count() < self.max_total_tokens + ) + + if req.req_status != ReqRunStatus.PAUSED_AND_OFFLOAD: + ok_req_num = len(self.cache_len_list) + len(self.pause_req_dict) <= self.running_max_req_size + else: + # 因为存在重复的项 + ok_req_num = len(self.cache_len_list) + len(self.pause_req_dict) - 1 <= self.running_max_req_size - new_batch_first_router_need_tokens += req.get_first_router_need_tokens() - # prefill ok - ok_prefill = new_batch_first_router_need_tokens <= self.batch_max_tokens + new_batch_first_router_need_tokens += req.get_first_router_need_tokens() + # prefill ok + ok_prefill = new_batch_first_router_need_tokens <= self.batch_max_tokens - if ok_token_num and ok_req_num and ok_prefill: - self.router.shared_token_load.set_dynamic_max_load(need_max_token_num / self.max_total_tokens) - return True, new_batch_first_router_need_tokens - else: - return False, new_batch_first_router_need_tokens + if ok_token_num and ok_req_num and ok_prefill: + self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num) + self.router.shared_token_load.set_dynamic_max_load( + (need_max_token_num + self.router.shared_token_load.get_frozened_token_count()) + / self.max_total_tokens + ) + return True, new_batch_first_router_need_tokens + else: + return False, new_batch_first_router_need_tokens # @calculate_time(show=True, min_cost_ms=10) def generate_new_batch(self, current_batch: Batch): @@ -88,9 +97,7 @@ def generate_new_batch(self, current_batch: Batch): else: return None - def calcu_batch_token_load(self, current_batch: Batch): - if current_batch is None: - return 0.0 + def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): is_busy = self.is_busy() self._init_cache_list(current_batch, is_busy) self.cache_len_list.sort(key=lambda x: -x[1]) @@ -99,4 +106,7 @@ def calcu_batch_token_load(self, current_batch: Batch): cum_run_len_array = np.cumsum(has_run_len_array) size_array = np.arange(1, len(self.cache_len_list) + 1, 1) need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - return need_max_token_num / self.max_total_tokens + return ( + need_max_token_num, + (need_max_token_num + self.router.shared_token_load.get_frozened_token_count()) / self.max_total_tokens, + ) diff --git a/lightllm/server/router/req_queue/continues_batch/pd_decode_impl.py b/lightllm/server/router/req_queue/continues_batch/pd_decode_impl.py new file mode 100644 index 000000000..39896a8f0 --- /dev/null +++ b/lightllm/server/router/req_queue/continues_batch/pd_decode_impl.py @@ -0,0 +1,65 @@ +import time +import uuid +import numpy as np +from typing import List +from lightllm.utils.infer_utils import calculate_time +from lightllm.server.io_struct import Batch, Req +from lightllm.server.io_struct import ReqRunStatus +from lightllm.server.router.req_queue.base_queue import BaseQueue +from lightllm.common.basemodel.infer_lock import g_router_lock + + +class ContinuesBatchQueueForPDDecode(BaseQueue): + def __init__(self, args, router) -> None: + super().__init__(args, router) + + def _init_cache_list(self, current_batch: Batch, is_busy): + if current_batch is not None: + self.cache_len_list = [ + req.get_tuple_tokens(is_busy, self.router_max_new_token_len) for req in current_batch.reqs + ] + else: + 
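+            # No running batch: start token accounting from an empty list.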
self.cache_len_list = [] + return + + # @calculate_time(show=True, min_cost_ms=10) + def generate_new_batch(self, current_batch: Batch): + # 如果当前已经被调度的请求数量超过了上限,直接不调度新的请求了。 + exist_req_num = (0 if current_batch is None else len(current_batch.reqs)) + len(self.pause_req_dict) + req_is_full = exist_req_num >= self.running_max_req_size + if req_is_full: + return None + + can_run_list = [] + aborted_count = 0 + for req in self.waiting_req_list: + if req.finish_status.is_aborted() and req.req_status == ReqRunStatus.WAIT_IN_QUEUE: + # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. + # 暂停的请求需要恢复后,由 router manager 部分来过滤。暂时保持这种处理方法, 否则会导致管理token和管理req对象的泄漏 + aborted_count += 1 + continue + if exist_req_num + len(can_run_list) + 1 <= self.batch_max_tokens: + can_run_list.append(req) + else: + break + + if len(can_run_list) != 0: + new_batch = Batch(uuid.uuid4().hex, can_run_list) + self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] + return new_batch + else: + return None + + def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): + is_busy = self.is_busy() + self._init_cache_list(current_batch, is_busy) + self.cache_len_list.sort(key=lambda x: -x[1]) + left_out_len_array = np.array([e[1] for e in self.cache_len_list]) + has_run_len_array = np.array([e[0] for e in self.cache_len_list]) + cum_run_len_array = np.cumsum(has_run_len_array) + size_array = np.arange(1, len(self.cache_len_list) + 1, 1) + need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() + return ( + need_max_token_num, + (need_max_token_num + self.router.shared_token_load.get_frozened_token_count()) / self.max_total_tokens, + ) diff --git a/lightllm/server/router/req_queue/splitfuse/impl.py b/lightllm/server/router/req_queue/splitfuse/impl.py index 3f2f8f569..98aaa94be 100644 --- a/lightllm/server/router/req_queue/splitfuse/impl.py +++ b/lightllm/server/router/req_queue/splitfuse/impl.py @@ -32,7 +32,9 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens size_array = np.arange(1, len(self.cache_len_list) + 1, 1) need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - ok_token_num = need_max_token_num < self.max_total_tokens + ok_token_num = ( + need_max_token_num + self.router.shared_token_load.get_frozened_token_count() < self.max_total_tokens + ) if req.req_status != ReqRunStatus.PAUSED_AND_OFFLOAD: ok_req_num = len(self.cache_len_list) + len(self.pause_req_dict) <= self.running_max_req_size @@ -44,7 +46,10 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens ok_splitfuse_decode = new_batch_first_router_need_tokens <= self.batch_max_tokens if ok_token_num and ok_req_num and ok_splitfuse_decode: - self.router.shared_token_load.set_dynamic_max_load(need_max_token_num / self.max_total_tokens) + self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num) + self.router.shared_token_load.set_dynamic_max_load( + (need_max_token_num + self.router.shared_token_load.get_frozened_token_count()) / self.max_total_tokens + ) return True, new_batch_first_router_need_tokens else: return False, new_batch_first_router_need_tokens @@ -82,7 +87,7 @@ def generate_new_batch(self, current_batch: Batch): self.pause_req_dict.pop(req.request_id) else: break - + if len(can_run_list) != 0: new_batch = Batch(uuid.uuid4().hex, can_run_list) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] @@ -90,9 +95,7 @@ def generate_new_batch(self, current_batch: 
Batch): else: return None - def calcu_batch_token_load(self, current_batch: Batch): - if current_batch is None: - return 0.0 + def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): is_busy = self.is_busy() self._init_cache_list(current_batch, is_busy) self.cache_len_list.sort(key=lambda x: -x[1]) @@ -101,4 +104,7 @@ def calcu_batch_token_load(self, current_batch: Batch): cum_run_len_array = np.cumsum(has_run_len_array) size_array = np.arange(1, len(self.cache_len_list) + 1, 1) need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - return need_max_token_num / self.max_total_tokens + return ( + need_max_token_num, + (need_max_token_num + self.router.shared_token_load.get_frozened_token_count()) / self.max_total_tokens, + ) diff --git a/lightllm/server/router/token_load.py b/lightllm/server/router/token_load.py index a08272769..0b599e98a 100644 --- a/lightllm/server/router/token_load.py +++ b/lightllm/server/router/token_load.py @@ -4,39 +4,84 @@ class TokenLoad: - def __init__(self, name) -> None: - self.shared_token_load = SharedArray(name, shape=(3,), dtype=np.float64) + def __init__(self, name, dp_size) -> None: + # 为数据并行保留的接口 + self.dp_size = dp_size + self.shared_token_load = SharedArray( + name, + shape=( + self.dp_size, + 3, + ), + dtype=np.float64, + ) + # 用来保存调度需要使用到的一些信息 + self.shared_token_infos = SharedArray( + f"{name}_ext_infos", + shape=( + self.dp_size, + 2, + ), + dtype=np.int64, + ) self.last_dynamic_max_load_update_time = time.time() + # 记录系统调度器估计的峰值token使用量 + def set_estimated_peak_token_count(self, obj: int, index: int = 0): + self.shared_token_infos.arr[index, 0] = obj + self.last_dynamic_max_load_update_time = time.time() + return + + def add_estimated_peak_token_count(self, value: int, index: int = 0): + self.shared_token_infos.arr[index, 0] += value + self.last_dynamic_max_load_update_time = time.time() + return + + def get_estimated_peak_token_count(self, index: int = 0) -> int: + return self.shared_token_infos.arr[index, 0] + + # 记录系统被临时固定的不能被使用的token数,主要在于 pd 分离的模式下 + # 推理系统需要在 kv 传输时临时固定一些 token, 防止调度系统估计失误,导致调度问题 + def set_frozened_token_count(self, obj: int, index: int = 0): + self.shared_token_infos.arr[index, 1] = obj + return + + def get_frozened_token_count(self, index: int = 0) -> int: + return self.shared_token_infos.arr[index, 1] + + def add_frozened_token_count(self, value: int, index: int = 0): + self.shared_token_infos.arr[index, 1] += value + return + # current_load 当前使用token量,估计的负载 - def set_current_load(self, value): - self.shared_token_load.arr[0] = value + def set_current_load(self, value, index: int = 0): + self.shared_token_load.arr[index, 0] = value return - def get_current_load(self): - return self.shared_token_load.arr[0] + def get_current_load(self, index: int = 0): + return self.shared_token_load.arr[index, 0] # logical_max_load 朴素估计的负载,简单将当前请求的输入和输出长度想加得到, 目前已未使用,其值与dynamic_max_load一样 - def set_logical_max_load(self, value): - self.shared_token_load.arr[1] = value + def set_logical_max_load(self, value, index: int = 0): + self.shared_token_load.arr[index, 1] = value return - def get_logical_max_load(self): - return self.shared_token_load.arr[1] + def get_logical_max_load(self, index: int = 0): + return self.shared_token_load.arr[index, 1] # dynamic_max_load 动态估计的最大负载,考虑请求中途退出的情况,估计的最大token使用量 - def set_dynamic_max_load(self, value): - self.shared_token_load.arr[2] = value - self.set_logical_max_load(value) + def set_dynamic_max_load(self, value, index: int = 0): + 
self.shared_token_load.arr[index, 2] = value
+        self.set_logical_max_load(value, index=index)
         self.last_dynamic_max_load_update_time = time.time()
         return

-    def get_dynamic_max_load(self):
-        return self.shared_token_load.arr[2]
+    def get_dynamic_max_load(self, index: int = 0):
+        return self.shared_token_load.arr[index, 2]

-    def need_update_dynamic_max_load(self):
-        # 5s 需要进行一次更新
-        if time.time() - self.last_dynamic_max_load_update_time >= 5.0:
+    def need_update_dynamic_max_load(self, index: int = 0):
+        # 3s 需要进行一次更新
+        if time.time() - self.last_dynamic_max_load_update_time >= 3.0:
             return True
         else:
             return False
diff --git a/lightllm/server/sampling_params.py b/lightllm/server/sampling_params.py
index be7c538f5..2dd4d3cc6 100644
--- a/lightllm/server/sampling_params.py
+++ b/lightllm/server/sampling_params.py
@@ -36,6 +36,10 @@ def __init__(
         # processor which only retains scores for the given token ids. Defaults to None.
         # allowed_token_ids only can be used in "--simple_constraint_mode" started server.
         allowed_token_ids: Optional[List[int]] = None,
+        # params used in pd (prefill/decode) mode
+        group_request_id: Optional[int] = None,
+        # move kv to decode node, only used in pd mode
+        move_kv_to_decode_node: Optional[dict] = None,
     ) -> None:
         self.best_of = best_of
         self.n = n
@@ -56,6 +60,8 @@ def __init__(
         self.print_eos_token = print_eos_token
         self.regular_constraint = regular_constraint
         self.allowed_token_ids = allowed_token_ids
+        self.group_request_id = group_request_id
+        self.move_kv_to_decode_node = move_kv_to_decode_node
         if self.do_sample is False:
             self.temperature = 1.0
             self.top_p = 1.0
@@ -134,6 +140,12 @@ def verify(self):
             except Exception as e:
                 raise ValueError(f"regular_expression '{self.regular_constraint}' has parse_pattern_error: {str(e)}")

+        if not (self.group_request_id is None or isinstance(self.group_request_id, int)):
+            raise ValueError(f"group_request_id must be None or int, but got {self.group_request_id}")
+
+        if not (self.move_kv_to_decode_node is None or isinstance(self.move_kv_to_decode_node, dict)):
+            raise ValueError(f"move_kv_to_decode_node must be None or dict, but got {self.move_kv_to_decode_node}")
+
         self._verify_stop_sentences()

         self._verify_allowed_token_ids()
@@ -205,4 +217,10 @@ def to_dict(self):
         ret["input_penalty"] = self.input_penalty
         ret["regular_constraint"] = self.regular_constraint
         ret["allowed_token_ids"] = self.allowed_token_ids
+        ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node
+        return ret
+
+    def to_origin_dict(self):
+        ret = self.to_dict()
+        ret["group_request_id"] = self.group_request_id
         return ret
diff --git a/lightllm/utils/health_check.py b/lightllm/utils/health_check.py
index 590a14b2a..91bc5b2df 100644
--- a/lightllm/utils/health_check.py
+++ b/lightllm/utils/health_check.py
@@ -3,26 +3,30 @@
 from lightllm.server.sampling_params import SamplingParams
 from lightllm.server.multimodal_params import MultimodalParams
 from lightllm.server.httpserver.manager import HttpServerManager
-from fastapi.responses import Response
+from fastapi import Request
+from lightllm.server.req_id_generator import ReqIDGenerator
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)


-async def health_check(httpserver_manager: HttpServerManager, g_id_gen, request):
+_g_health_req_id_gen = ReqIDGenerator()
+_g_health_req_id_gen.generate_id()
+
+
+async def health_check(args, httpserver_manager: HttpServerManager, request: Request):
     try:
         request_dict = {"inputs": "你好!", "parameters": {"do_sample": True, "temperature": 0.8, "max_new_tokens": 2}}
         prompt 
= request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] sampling_params = SamplingParams(**sample_params_dict) sampling_params.verify() + if args.run_mode in ["prefill", "decode"]: + sampling_params.group_request_id = -_g_health_req_id_gen.generate_id() # health monitor 的 id 是负的 multimodal_params_dict = request_dict.get("multimodal_params", {}) multimodal_params = MultimodalParams(**multimodal_params_dict) - group_request_id = g_id_gen.generate_id() - results_generator = httpserver_manager.generate( - prompt, sampling_params, group_request_id, multimodal_params, request=request - ) + results_generator = httpserver_manager.generate(prompt, sampling_params, multimodal_params, request) async for _, _, _, _ in results_generator: pass return True diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index b58e0b093..e2c2ea6c6 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -1,9 +1,13 @@ import socket +import subprocess +from lightllm.utils.log_utils import init_logger +logger = init_logger(__name__) -def alloc_can_use_network_port(num=3, used_nccl_ports=None): + +def alloc_can_use_network_port(num=3, used_nccl_ports=None, from_port_num=10000): port_list = [] - for port in range(10000, 65536): + for port in range(from_port_num, 65536): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: result = s.connect_ex(("localhost", port)) if result != 0 and port not in used_nccl_ports: @@ -12,3 +16,33 @@ def alloc_can_use_network_port(num=3, used_nccl_ports=None): if len(port_list) == num: return port_list return None + + +def alloc_can_use_port(min_port, max_port): + port_list = [] + for port in range(min_port, max_port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + result = s.connect_ex(("localhost", port)) + if result != 0: + port_list.append(port) + return port_list + + +def find_available_port(start_port, end_port): + for port in range(start_port, end_port + 1): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + result = sock.connect_ex(("localhost", port)) + if result != 0: + return port + return None + + +def get_hostname_ip(): + try: + result = subprocess.run(["hostname", "-i"], capture_output=True, text=True, check=True) + result = result.stdout.strip() + logger.info(f"get hostname ip {result}") + return result + except subprocess.CalledProcessError as e: + logger.exception(f"Error executing command: {e}") + return None diff --git a/lightllm/utils/retry_utils.py b/lightllm/utils/retry_utils.py new file mode 100644 index 000000000..6a3bb9403 --- /dev/null +++ b/lightllm/utils/retry_utils.py @@ -0,0 +1,31 @@ +import time +import functools +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def retry(max_attempts=3, wait_time=1): + """ + 被修饰的函数调用失败需要自己抛异常 + :param max_attempts: 最大重试次数 + :param wait_time: 每次重试之间的等待时间(秒) + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + attempts = 0 + while attempts < max_attempts: + try: + return func(*args, **kwargs) + except Exception as e: + attempts += 1 + logger.info(f"try {func.__name__} {attempts}/{max_attempts} fail: {str(e)}") + if attempts < max_attempts: + time.sleep(wait_time) + raise Exception(f"{func.__name__} try all failed") + + return wrapper + + return decorator diff --git a/lightllm/utils/statics_utils.py b/lightllm/utils/statics_utils.py new file mode 100644 index 000000000..9d7aa1032 --- /dev/null +++ b/lightllm/utils/statics_utils.py @@ -0,0 +1,13 @@ +class 
MovingAverage: + def __init__(self): + self.total = 0.0 + self.count = 0 + + def add(self, value): + self.total += value + self.count += 1 + + def average(self): + if self.count == 0: + return 0.0 + return self.total / self.count diff --git a/requirements.txt b/requirements.txt index af973b2cd..b69e98670 100644 --- a/requirements.txt +++ b/requirements.txt @@ -80,3 +80,4 @@ psutil==5.9.4 prometheus_client==0.20.0 outlines==0.0.46 cchardet==2.1.7 +ujson==5.10.0 \ No newline at end of file
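
For reference, the retry helper added above in lightllm/utils/retry_utils.py expects the wrapped function to raise an exception on failure; it logs each failed attempt, sleeps wait_time seconds between attempts, and raises a final Exception once max_attempts is exhausted. A minimal usage sketch (the fetch_decode_node_status function below is an illustrative placeholder, not part of this patch):

    from lightllm.utils.retry_utils import retry

    @retry(max_attempts=3, wait_time=1)
    def fetch_decode_node_status():
        # Signal failure by raising; the decorator retries the call and, after the
        # third failed attempt, raises Exception("fetch_decode_node_status try all failed").
        raise ConnectionError("decode node not reachable yet")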