diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh index 60a3978f9abd7..0a759d303238b 100644 --- a/.github/workflows/scripts/build.sh +++ b/.github/workflows/scripts/build.sh @@ -13,8 +13,6 @@ $python_executable -m pip install -r requirements-cuda.txt # Limit the number of parallel jobs to avoid OOM export MAX_JOBS=1 -# Make sure punica is built for the release (for LoRA) -export VLLM_INSTALL_PUNICA_KERNELS=1 # Make sure release wheels are built for the following architectures export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" # Build diff --git a/CMakeLists.txt b/CMakeLists.txt index bf00a36edc500..7ca6513f3f1d8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -222,61 +222,7 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) -# -# _punica_C extension -# - -set(VLLM_PUNICA_EXT_SRC - "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu" - "csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" - "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" - "csrc/punica/punica_ops.cu" - "csrc/punica/torch_bindings.cpp") - -# -# Copy GPU compilation flags+update for punica -# -set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS}) -list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS - "-D__CUDA_NO_HALF_OPERATORS__" - "-D__CUDA_NO_HALF_CONVERSIONS__" - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" - "-D__CUDA_NO_HALF2_OPERATORS__") - -# -# Filter out CUDA architectures < 8.0 for punica. -# -if (${VLLM_GPU_LANG} STREQUAL "CUDA") - set(VLLM_PUNICA_GPU_ARCHES) - foreach(ARCH ${VLLM_GPU_ARCHES}) - string_to_ver(CODE_VER ${ARCH}) - if (CODE_VER GREATER_EQUAL 8.0) - list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH}) - endif() - endforeach() - message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") -elseif(${VLLM_GPU_LANG} STREQUAL "HIP") - set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES}) - message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") -endif() -if (VLLM_PUNICA_GPU_ARCHES) - define_gpu_extension_target( - _punica_C - DESTINATION vllm - LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${VLLM_PUNICA_EXT_SRC} - COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} - ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} - USE_SABI 3 - WITH_SOABI) -else() - message(WARNING "Unable to create _punica_C target because none of the " - "requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0") -endif() # # Add the `default` target which detects which extensions should be @@ -300,12 +246,4 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling moe extension.") add_dependencies(default _moe_C) - # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or - # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and - # there are supported target arches. - if (VLLM_PUNICA_GPU_ARCHES AND - (ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS)) - message(STATUS "Enabling punica extension.") - add_dependencies(default _punica_C) - endif() endif() diff --git a/Dockerfile b/Dockerfile index b9a56e67e8d7b..db4453ab0efc9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -88,8 +88,6 @@ ENV MAX_JOBS=${max_jobs} # number of threads used by nvcc ARG nvcc_threads=8 ENV NVCC_THREADS=$nvcc_threads -# make sure punica kernels are built (for LoRA) -ENV VLLM_INSTALL_PUNICA_KERNELS=1 ARG buildkite_commit ENV BUILDKITE_COMMIT=${buildkite_commit} diff --git a/Dockerfile.rocm b/Dockerfile.rocm index ff39791456398..a620274cee4e1 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -138,8 +138,7 @@ RUN case "$(which python3)" in \ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install --upgrade numba scipy huggingface-hub[cli] -# Make sure punica kernels are built (for LoRA) -ENV VLLM_INSTALL_PUNICA_KERNELS=1 + # Workaround for ray >= 2.10.0 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 # Silences the HF Tokenizers warning diff --git a/csrc/punica/LICENSE b/csrc/punica/LICENSE deleted file mode 100644 index a46e2cdcadf7d..0000000000000 --- a/csrc/punica/LICENSE +++ /dev/null @@ -1,217 +0,0 @@ -Contains code from https://github.com/punica-ai/punica - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ------------------------------------------------------------------------------------- - -This product bundles various third-party components under other open source licenses. -This section summarizes those components and their licenses. See licenses/ -for text of these licenses. - - -Apache-2.0 -* third_party/nvbench (with LLVM exception) -* third_party/flashinfer - -BSD-3-Clause: -* third_party/cutlass \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu deleted file mode 100644 index 86846c274c90f..0000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu deleted file mode 100644 index de39c3121f5d3..0000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h deleted file mode 100644 index 2c8d007d8719f..0000000000000 --- a/csrc/punica/bgmv/bgmv_config.h +++ /dev/null @@ -1,218 +0,0 @@ -#pragma once - -template -void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale); - -// clang-format off - -#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ - f(in_T, out_T, W_T, narrow, 128) \ - f(in_T, out_T, W_T, narrow, 256) \ - f(in_T, out_T, W_T, narrow, 512) \ - f(in_T, out_T, W_T, narrow, 640) \ - f(in_T, out_T, W_T, narrow, 768) \ - f(in_T, out_T, W_T, narrow, 896) \ - f(in_T, out_T, W_T, narrow, 1024) \ - f(in_T, out_T, W_T, narrow, 1152) \ - f(in_T, out_T, W_T, narrow, 1216) \ - f(in_T, out_T, W_T, narrow, 1280) \ - f(in_T, out_T, W_T, narrow, 1536) \ - f(in_T, out_T, W_T, narrow, 1664) \ - f(in_T, out_T, W_T, narrow, 1728) \ - f(in_T, out_T, W_T, narrow, 1792) \ - f(in_T, out_T, W_T, narrow, 2048) \ - f(in_T, out_T, W_T, narrow, 2240) \ - f(in_T, out_T, W_T, narrow, 2304) \ - f(in_T, out_T, W_T, narrow, 2368) \ - f(in_T, out_T, W_T, narrow, 2432) \ - f(in_T, out_T, W_T, narrow, 2560) \ - f(in_T, out_T, W_T, narrow, 2752) \ - f(in_T, out_T, W_T, narrow, 2816) \ - f(in_T, out_T, W_T, narrow, 3072) \ - f(in_T, out_T, W_T, narrow, 3328) \ - f(in_T, out_T, W_T, narrow, 3456) \ - f(in_T, out_T, W_T, narrow, 3584) \ - f(in_T, out_T, W_T, narrow, 3712) \ - f(in_T, out_T, W_T, narrow, 4096) \ - f(in_T, out_T, W_T, narrow, 4480) \ - f(in_T, out_T, W_T, narrow, 4608) \ - f(in_T, out_T, W_T, narrow, 4736) \ - f(in_T, out_T, W_T, narrow, 4864) \ - f(in_T, out_T, W_T, narrow, 5120) \ - f(in_T, out_T, W_T, narrow, 5504) \ - f(in_T, out_T, W_T, narrow, 5632) \ - f(in_T, out_T, W_T, narrow, 5888) \ - f(in_T, out_T, W_T, narrow, 6144) \ - f(in_T, out_T, W_T, narrow, 6400) \ - f(in_T, out_T, W_T, narrow, 6848) \ - f(in_T, out_T, W_T, narrow, 6912) \ - f(in_T, out_T, W_T, narrow, 7168) \ - f(in_T, out_T, W_T, narrow, 7424) \ - f(in_T, out_T, W_T, narrow, 8192) \ - f(in_T, out_T, W_T, narrow, 8960) \ - f(in_T, out_T, W_T, narrow, 9216) \ - f(in_T, out_T, W_T, narrow, 9472) \ - f(in_T, out_T, W_T, narrow, 10240) \ - f(in_T, out_T, W_T, narrow, 11008) \ - f(in_T, out_T, W_T, narrow, 11264) \ - f(in_T, out_T, W_T, narrow, 12288) \ - f(in_T, out_T, W_T, narrow, 13696) \ - f(in_T, out_T, W_T, narrow, 13824) \ - f(in_T, out_T, W_T, narrow, 14336) \ - f(in_T, out_T, W_T, narrow, 14784) \ - f(in_T, out_T, W_T, narrow, 14848) \ - f(in_T, out_T, W_T, narrow, 15360) \ - f(in_T, out_T, W_T, narrow, 16384) \ - f(in_T, out_T, W_T, narrow, 18944) \ - f(in_T, out_T, W_T, narrow, 20480) \ - f(in_T, out_T, W_T, narrow, 22016) \ - f(in_T, out_T, W_T, narrow, 22528) \ - f(in_T, out_T, W_T, narrow, 24576) \ - f(in_T, out_T, W_T, narrow, 27392) \ - f(in_T, out_T, W_T, narrow, 27648) \ - f(in_T, out_T, W_T, narrow, 28672) \ - f(in_T, out_T, W_T, narrow, 29568) \ - f(in_T, out_T, W_T, narrow, 29696) \ - f(in_T, out_T, W_T, narrow, 32000) \ - f(in_T, out_T, W_T, narrow, 32256) \ - f(in_T, out_T, W_T, narrow, 32512) \ - f(in_T, out_T, W_T, narrow, 32768) \ - f(in_T, out_T, W_T, narrow, 33024) \ - f(in_T, out_T, W_T, narrow, 36864) \ - f(in_T, out_T, W_T, narrow, 43264) \ - f(in_T, out_T, W_T, narrow, 49152) \ - f(in_T, out_T, W_T, narrow, 49408) \ - f(in_T, out_T, W_T, narrow, 60544) \ - f(in_T, out_T, W_T, narrow, 60672) \ - f(in_T, out_T, W_T, narrow, 64000) \ - f(in_T, out_T, W_T, narrow, 64256) \ - f(in_T, out_T, W_T, narrow, 64512) \ - f(in_T, out_T, W_T, narrow, 102400) \ - f(in_T, out_T, W_T, narrow, 102656) \ - f(in_T, out_T, W_T, narrow, 102912) \ - f(in_T, out_T, W_T, narrow, 128000) \ - f(in_T, out_T, W_T, narrow, 128256) \ - f(in_T, out_T, W_T, narrow, 128512) \ - - -// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA -// and vllm/tests/lora/test_punica.py - -// Used for defining kernels going from the variety of -// dim in to the narrow dim out - // Using it for the fully sharded column - // parallel LoRA A which splits the rank dim -#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \ - f(in_T, out_T, W_T, 128, narrow) \ - f(in_T, out_T, W_T, 256, narrow) \ - f(in_T, out_T, W_T, 512, narrow) \ - f(in_T, out_T, W_T, 640, narrow) \ - f(in_T, out_T, W_T, 768, narrow) \ - f(in_T, out_T, W_T, 896, narrow) \ - f(in_T, out_T, W_T, 1024, narrow) \ - f(in_T, out_T, W_T, 1152, narrow) \ - f(in_T, out_T, W_T, 1216, narrow) \ - f(in_T, out_T, W_T, 1280, narrow) \ - f(in_T, out_T, W_T, 1536, narrow) \ - f(in_T, out_T, W_T, 1664, narrow) \ - f(in_T, out_T, W_T, 1728, narrow) \ - f(in_T, out_T, W_T, 1792, narrow) \ - f(in_T, out_T, W_T, 2048, narrow) \ - f(in_T, out_T, W_T, 2240, narrow) \ - f(in_T, out_T, W_T, 2304, narrow) \ - f(in_T, out_T, W_T, 2368, narrow) \ - f(in_T, out_T, W_T, 2432, narrow) \ - f(in_T, out_T, W_T, 2560, narrow) \ - f(in_T, out_T, W_T, 2752, narrow) \ - f(in_T, out_T, W_T, 2816, narrow) \ - f(in_T, out_T, W_T, 3072, narrow) \ - f(in_T, out_T, W_T, 3328, narrow) \ - f(in_T, out_T, W_T, 3456, narrow) \ - f(in_T, out_T, W_T, 3584, narrow) \ - f(in_T, out_T, W_T, 3712, narrow) \ - f(in_T, out_T, W_T, 4096, narrow) \ - f(in_T, out_T, W_T, 4480, narrow) \ - f(in_T, out_T, W_T, 4608, narrow) \ - f(in_T, out_T, W_T, 4736, narrow) \ - f(in_T, out_T, W_T, 4864, narrow) \ - f(in_T, out_T, W_T, 5120, narrow) \ - f(in_T, out_T, W_T, 5504, narrow) \ - f(in_T, out_T, W_T, 5632, narrow) \ - f(in_T, out_T, W_T, 5888, narrow) \ - f(in_T, out_T, W_T, 6144, narrow) \ - f(in_T, out_T, W_T, 6400, narrow) \ - f(in_T, out_T, W_T, 6848, narrow) \ - f(in_T, out_T, W_T, 6912, narrow) \ - f(in_T, out_T, W_T, 7168, narrow) \ - f(in_T, out_T, W_T, 7424, narrow) \ - f(in_T, out_T, W_T, 8192, narrow) \ - f(in_T, out_T, W_T, 8960, narrow) \ - f(in_T, out_T, W_T, 9216, narrow) \ - f(in_T, out_T, W_T, 9472, narrow) \ - f(in_T, out_T, W_T, 10240, narrow) \ - f(in_T, out_T, W_T, 11008, narrow) \ - f(in_T, out_T, W_T, 11264, narrow) \ - f(in_T, out_T, W_T, 12288, narrow) \ - f(in_T, out_T, W_T, 13696, narrow) \ - f(in_T, out_T, W_T, 13824, narrow) \ - f(in_T, out_T, W_T, 14336, narrow) \ - f(in_T, out_T, W_T, 14784, narrow) \ - f(in_T, out_T, W_T, 14848, narrow) \ - f(in_T, out_T, W_T, 15360, narrow) \ - f(in_T, out_T, W_T, 16384, narrow) \ - f(in_T, out_T, W_T, 18944, narrow) \ - f(in_T, out_T, W_T, 20480, narrow) \ - f(in_T, out_T, W_T, 22016, narrow) \ - f(in_T, out_T, W_T, 22528, narrow) \ - f(in_T, out_T, W_T, 24576, narrow) \ - f(in_T, out_T, W_T, 27392, narrow) \ - f(in_T, out_T, W_T, 27648, narrow) \ - f(in_T, out_T, W_T, 28672, narrow) \ - f(in_T, out_T, W_T, 29568, narrow) \ - f(in_T, out_T, W_T, 29696, narrow) \ - f(in_T, out_T, W_T, 32000, narrow) \ - f(in_T, out_T, W_T, 32256, narrow) \ - f(in_T, out_T, W_T, 32512, narrow) \ - f(in_T, out_T, W_T, 32768, narrow) \ - f(in_T, out_T, W_T, 33024, narrow) \ - f(in_T, out_T, W_T, 36864, narrow) \ - f(in_T, out_T, W_T, 43264, narrow) \ - f(in_T, out_T, W_T, 49152, narrow) \ - f(in_T, out_T, W_T, 49408, narrow) \ - f(in_T, out_T, W_T, 60544, narrow) \ - f(in_T, out_T, W_T, 60672, narrow) \ - f(in_T, out_T, W_T, 64000, narrow) \ - f(in_T, out_T, W_T, 64256, narrow) \ - f(in_T, out_T, W_T, 64512, narrow) \ - f(in_T, out_T, W_T, 102400, narrow) \ - f(in_T, out_T, W_T, 102656, narrow) \ - f(in_T, out_T, W_T, 102912, narrow) \ - f(in_T, out_T, W_T, 128000, narrow) \ - f(in_T, out_T, W_T, 128256, narrow) \ - f(in_T, out_T, W_T, 128512, narrow) \ -// Keep above in sync with vllm/lora/layers::SamplerWithLoRA - - -// Keep this in sync with vllm/config::LoRAConfig -#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) - - -#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ - FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \ - FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \ - FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \ - f(in_T, out_T, W_T, 8, 64) \ - f(in_T, out_T, W_T, 16, 64) \ - f(in_T, out_T, W_T, 32, 64) \ - f(in_T, out_T, W_T, 64, 64) - -// clang-format on diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu deleted file mode 100644 index d225a1eaa82b0..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu deleted file mode 100644 index b37d288a75561..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu deleted file mode 100644 index a1ab2deecbabf..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu deleted file mode 100644 index 0b35bf5699898..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh deleted file mode 100644 index 8a3b8403b4a6f..0000000000000 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ /dev/null @@ -1,451 +0,0 @@ -#pragma once - -#include -#ifndef USE_ROCM -#include -#else -#include -#endif -#ifndef USE_ROCM -#include -#endif -#include -#include -#include - -#include "vec_dtypes.cuh" - -namespace cg = cooperative_groups; - -#ifdef USE_ROCM -template -__host__ __device__ -inline void* memcpy_blocking(void *dst, const void *src) { - // Does not handle the case of long datatypes - char *d = reinterpret_cast(dst); - const char *s = reinterpret_cast(src); - size_t i = 0; -#pragma unroll - for (i = 0; i < len; ++i) { - d[i] = s[i]; - } - return dst; -} -#endif - -#ifndef USE_ROCM - -// nthrs = (32, 4) -template -__global__ void -bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - if (idx < 0) { - return; - } - - auto block = cg::this_thread_block(); - size_t j = blockIdx.x; - constexpr size_t num_pipeline_stages = 2; - constexpr size_t tile_size = tx * ty * vec_size; - __shared__ W_T W_shared[num_pipeline_stages * tile_size]; - __shared__ in_T X_shared[num_pipeline_stages * tile_size]; - __shared__ float y_warpwise[ty]; - - size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; - size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; - auto pipe = cuda::make_pipeline(); - - // pipeline load W/X and compute WX; - pipe.producer_acquire(); - cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(W_copy_size), pipe); - cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - X + (batch_idx * feat_in) + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(X_copy_size), pipe); - pipe.producer_commit(); - size_t copy_idx, compute_idx; - float y = 0.f; - vec_t x_vec; - vec_t w_vec; - size_t tile_idx; - -#pragma unroll - for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; - ++tile_idx) { - copy_idx = tile_idx % num_pipeline_stages; - // pipeline stage: async copy W fragment - pipe.producer_acquire(); - if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { - cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(W_copy_size), pipe); - cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size, - X + (batch_idx * feat_in) + tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(X_copy_size), pipe); - } - pipe.producer_commit(); - - compute_idx = (tile_idx - 1) % num_pipeline_stages; - // pipeline stage: compute WX - pipe.consumer_wait(); - block.sync(); - x_vec.load(X_shared + X_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W_shared + W_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += float(w_vec[i]) * float(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - y_warpwise[threadIdx.y] = sum; - block.sync(); -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y += y_warpwise[i]; - } - - block.sync(); - pipe.consumer_release(); - } - - compute_idx = (tile_idx - 1) % num_pipeline_stages; - // final pipeline stage - pipe.consumer_wait(); - block.sync(); - x_vec.load(X_shared + X_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W_shared + W_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += float(w_vec[i]) * float(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - y_warpwise[threadIdx.y] = - ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) - ? sum - : 0.f; - block.sync(); -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y += y_warpwise[i]; - } - - block.sync(); - pipe.consumer_release(); - - // write Y; - if (block.thread_rank() == 0) { - Y[batch_idx * full_y_size + y_offset + j] += static_cast(y); - } -} - -#else - -template -__global__ void -bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - if (idx < 0) { - return; - } - - size_t j = blockIdx.x; - constexpr size_t tile_size = tx * ty * vec_size; - constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; - __shared__ float y_warpwise[ty]; - - float y = 0; - vec_t x_vec; - vec_t w_vec; - size_t tile_idx; - -#pragma unroll - for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { - if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { - x_vec.load(X + (batch_idx * feat_in) + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W + (idx * feat_out + j) * feat_in + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size); - } - - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += VLLM_SHFL_DOWN_SYNC(sum, offset); - } - - __syncthreads(); - - if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { - y += sum; - } - } - - if (threadIdx.x == 0) { - y_warpwise[threadIdx.y] = y; - } - __syncthreads(); - - float y_write = 0.f; -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y_write += y_warpwise[i]; - } - - // write Y; - if (threadIdx.x == 0 && threadIdx.y == 0) { - size_t y_idx = batch_idx * full_y_size + y_offset + j; - Y[y_idx] = vllm_add(Y[y_idx], convert_type(y_write)); - } -} - -#endif - -// nthrs = (2, 16, 4) -template -__global__ void -bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - - if (idx < 0) { - return; - } - - auto block = cg::this_thread_block(); - size_t tile_idx = blockIdx.x; - - // load X; - vec_t x_vec; - x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); - - // load W; - vec_t w_vec; - w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + - block.thread_rank() * vec_size); - - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { -#ifndef USE_ROCM - sum += float(w_vec[i]) * float(x_vec[i]) * scale; -#else - sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; -#endif - } - - cg::thread_block_tile g = cg::tiled_partition(block); -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += g.shfl_down(sum, offset); - } - sum = g.shfl(sum, 0); - - if (threadIdx.x == 0) { -#ifndef USE_ROCM - Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + - threadIdx.z * ty + threadIdx.y] += static_cast(sum); -#else - size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + - threadIdx.z * ty + threadIdx.y; - Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); -#endif - } -} - -template -void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale) { - constexpr size_t vec_size = 8; - constexpr int tz = 4; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if constexpr (feat_in <= feat_out) { - static_assert(feat_in % vec_size == 0); - constexpr int tx = feat_in / vec_size; - - static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) || - (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) || - (8 % tx == 0 && feat_out % (8 / tx * tz) == 0)); - - if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { - constexpr int ty = 32 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) { - constexpr int ty = 16 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else { - constexpr int ty = 8 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } - } else { -#ifndef USE_ROCM - static_assert(feat_in % (vec_size * 32) == 0 || - feat_in % (vec_size * 16) == 0 || - feat_in % (vec_size * 8) == 0); - - if constexpr (feat_in % (vec_size * 32) == 0) { - constexpr int tx = 32; - constexpr int ty = 4; - - dim3 nblks(feat_out, batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) { - constexpr int tx = 32; - constexpr int ty = 4; - - dim3 nblks(feat_out, batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) { - constexpr int tx = 16; - constexpr int ty = 4; - - dim3 nblks(feat_out, batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } -#else - constexpr size_t rocm_warp_size = warpSize; - -#define CHECK_INPUT_TILEABLE_BY(vec_size_) \ - feat_in % (rocm_warp_size * vec_size_) == 0 - -#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \ - if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \ - constexpr size_t vec_size_shrink = vec_size_; \ - constexpr int tx = tx_; \ - constexpr int ty = ty_; \ - dim3 nblks(feat_out, batch_size); \ - dim3 nthrs(tx, ty); \ - bgmv_shrink_kernel \ - <<>>(Y, X, W, indicies, y_offset, \ - full_y_size, num_layers, layer_idx, \ - scale); \ - } - - static_assert(CHECK_INPUT_TILEABLE_BY(32) || - CHECK_INPUT_TILEABLE_BY(16) || - CHECK_INPUT_TILEABLE_BY( 8) || - CHECK_INPUT_TILEABLE_BY( 4) || - CHECK_INPUT_TILEABLE_BY( 2) || - CHECK_INPUT_TILEABLE_BY( 1)); - - LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1) - -#undef CHECK_INPUT_TILEABLE_BY -#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM -#endif - } -} - -#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \ - template void bgmv_kernel( \ - out_T * __restrict__ Y, const in_T *__restrict__ X, \ - const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \ - int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ - int64_t num_layers, int64_t layer_idx, float scale); - -#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \ - INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) - -#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ - INST_BGMV(narrow, wide, in_T, out_T, W_T) \ - INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py deleted file mode 100644 index 972df5a7208c2..0000000000000 --- a/csrc/punica/bgmv/generator.py +++ /dev/null @@ -1,48 +0,0 @@ -DTYPES = ["fp16", "bf16", "fp32"] -DTYPE_MAP = { - "fp16": "nv_half", - "bf16": "nv_bfloat16", - "fp32": "float", -} - -TEMPLATE = """ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype}) -""".lstrip() # noqa: E501 - -for input_dtype in DTYPES: - for output_dtype in DTYPES: - for weight_dtype in DTYPES: - if weight_dtype == "fp32": - # FP32 weights are not supported. - continue - if output_dtype == "fp32": - # LoRA A matrix. - if input_dtype != weight_dtype: - # NOTE(woosuk): While Punica supports the case where the - # input and weight dtypes are different, we only generate - # the kernels the same dtypes to reduce the binary size. - continue - elif input_dtype == "fp32": - # LoRA B matrix. - if output_dtype != weight_dtype: - # NOTE(woosuk): While Punica supports the case where the - # output and weight dtypes are different, we only generate - # the kernels the same dtypes to reduce the binary size. - continue - elif not (input_dtype == output_dtype == weight_dtype): - # NOTE(woosuk): While Punica supports mixed data types for - # input, output, and weight, we only generate the kernels with - # the same data types to reduce the binary size. - continue - - kernel_definition = TEMPLATE.format( - input_dtype=DTYPE_MAP[input_dtype], - output_dtype=DTYPE_MAP[output_dtype], - weight_dtype=DTYPE_MAP[weight_dtype]) - filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu" - with open(filename, "w") as f: - f.write(kernel_definition) diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh deleted file mode 100644 index 2738892e6dc4a..0000000000000 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ /dev/null @@ -1,1325 +0,0 @@ -#ifndef VEC_DTYPES_CUH_ -#define VEC_DTYPES_CUH_ - -#ifdef FLASHINFER_USE_FP8 -#include -#endif -#include - -#include - -#include "../type_convert.h" -#include "../../cuda_compat.h" - -#define FLASHINFER_INLINE \ - inline __attribute__((always_inline)) __device__ __host__ - -template -struct vec_t { - FLASHINFER_INLINE float_t &operator[](size_t i); - FLASHINFER_INLINE const float_t &operator[](size_t i) const; - FLASHINFER_INLINE void fill(float_t val); - FLASHINFER_INLINE void load(const float_t *ptr); - FLASHINFER_INLINE void store(float_t *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src); - template - FLASHINFER_INLINE void cast_load(const T *ptr); - template - FLASHINFER_INLINE void cast_store(T *ptr) const; - FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); -}; - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = tgt_float_t(src[i]); - } -} - -template -FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr, - vec_t &dst) { - if constexpr (std::is_same::value) { - dst.load(src_ptr); - } else { - vec_t tmp; - tmp.load(src_ptr); - dst.cast_from(tmp); - } -} - -template -FLASHINFER_INLINE void cast_store_impl(const vec_t &src, - tgt_float_t *dst_ptr) { - if constexpr (std::is_same::value) { - src.store(dst_ptr); - } else { - vec_t tmp; - tmp.cast_from(src); - tmp.store(dst_ptr); - } -} - -#ifdef FLASHINFER_USE_FP8 -/******************* vec_t<__nv_fp8_e4m3> *******************/ - -// __nv_fp8_e4m3 x 1 -template <> -struct vec_t<__nv_fp8_e4m3, 1> { - __nv_fp8_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) { - data = *ptr; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store( - __nv_fp8_e4m3 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *dst = *src; -} - -// __nv_fp8_e4m3 x 2 -template <> -struct vec_t<__nv_fp8_e4m3, 2> { - __nv_fp8x2_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { - data.__x = - (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) { - data = *((__nv_fp8x2_e4m3 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store( - __nv_fp8_e4m3 *ptr) const { - *((__nv_fp8x2_e4m3 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src); -} - -// __nv_fp8_e4m3 x 4 - -template <> -struct vec_t<__nv_fp8_e4m3, 4> { - __nv_fp8x4_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { - data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) { - data = *((__nv_fp8x4_e4m3 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store( - __nv_fp8_e4m3 *ptr) const { - *((__nv_fp8x4_e4m3 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src); -} - -// __nv_fp8_e4m3 x 8 - -template <> -struct vec_t<__nv_fp8_e4m3, 8> { - uint2 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { - ((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store( - __nv_fp8_e4m3 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8_e4m3 *)dst) = *((__nv_fp8_e4m3 *)src); -} - -// __nv_fp8_e4m3 x 16 or more -template -struct vec_t<__nv_fp8_e4m3, vec_size> { - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)data)[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)data)[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t<__nv_fp8_e5m2> *******************/ - -// __nv_fp8_e5m2 x 1 -template <> -struct vec_t<__nv_fp8_e5m2, 1> { - __nv_fp8_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) { - data = *ptr; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store( - __nv_fp8_e5m2 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *dst = *src; -} - -// __nv_fp8_e5m2 x 2 -template <> -struct vec_t<__nv_fp8_e5m2, 2> { - __nv_fp8x2_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { - data.__x = - (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) { - data = *((__nv_fp8x2_e5m2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store( - __nv_fp8_e5m2 *ptr) const { - *((__nv_fp8x2_e5m2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 4 - -template <> -struct vec_t<__nv_fp8_e5m2, 4> { - __nv_fp8x4_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { - data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) { - data = *((__nv_fp8x4_e5m2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store( - __nv_fp8_e5m2 *ptr) const { - *((__nv_fp8x4_e5m2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 8 - -template <> -struct vec_t<__nv_fp8_e5m2, 8> { - uint2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { - ((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store( - __nv_fp8_e5m2 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8_e5m2 *)dst) = *((__nv_fp8_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 16 or more - -template -struct vec_t<__nv_fp8_e5m2, vec_size> { - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)data)[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)data)[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; -#endif - -/******************* vec_t *******************/ - -// half x 1 -template <> -struct vec_t { - half data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { data = *ptr; } - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; } - -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { - *dst = *src; -} - -// half x 2 -template <> -struct vec_t { - half2 data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { - data = make_half2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { - data = *((half2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { - *((half2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { - *((half2 *)dst) = *((half2 *)src); -} - -// half x 4 - -template <> -struct vec_t { - uint2 data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { - *(half2 *)(&data.x) = make_half2(val, val); - *(half2 *)(&data.y) = make_half2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { - *((uint2 *)dst) = *((uint2 *)src); -} - -// half x 8 or more - -template -struct vec_t { - uint4 data[vec_size / 8]; - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)data)[i]; - } - FLASHINFER_INLINE void fill(half val) { -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - *(half2 *)(&(data[i].x)) = make_half2(val, val); - *(half2 *)(&(data[i].y)) = make_half2(val, val); - *(half2 *)(&(data[i].z)) = make_half2(val, val); - *(half2 *)(&(data[i].w)) = make_half2(val, val); - } - } - FLASHINFER_INLINE void load(const half *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(half *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t *******************/ - -// nv_bfloat16 x 1 -template <> -struct vec_t { - nv_bfloat16 data; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { - data = *ptr; -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *dst = *src; -} - -// nv_bfloat16 x 2 -template <> -struct vec_t { - nv_bfloat162 data; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - data = make_bfloat162(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { - data = *((nv_bfloat162 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *((nv_bfloat162 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src); -} - -// nv_bfloat16 x 4 - -template <> -struct vec_t { - uint2 data; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - *(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *((uint2 *)dst) = *((uint2 *)src); -} - -// nv_bfloat16 x 8 or more - -template -struct vec_t { - uint4 data[vec_size / 8]; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val); - } - } - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t *******************/ - -// float x 1 - -template <> -struct vec_t { - float data; - - FLASHINFER_INLINE float &operator[](size_t i) { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); -}; - -FLASHINFER_INLINE void vec_t::fill(float val) { data = val; } - -FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; } - -FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; } - -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { - *dst = *src; -} - -// float x 2 - -template <> -struct vec_t { - float2 data; - - FLASHINFER_INLINE float &operator[](size_t i) { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); -}; - -FLASHINFER_INLINE void vec_t::fill(float val) { - data = make_float2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const float *ptr) { - data = *((float2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(float *ptr) const { - *((float2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { - *((float2 *)dst) = *((float2 *)src); -} - -// float x 4 or more -template -struct vec_t { - float4 data[vec_size / 4]; - - FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; } - FLASHINFER_INLINE const float &operator[](size_t i) const { - return ((const float *)(data))[i]; - } - FLASHINFER_INLINE void fill(float val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = make_float4(val, val, val, val); - } - } - FLASHINFER_INLINE void load(const float *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = ((float4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(float *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)dst)[i] = ((float4 *)src)[i]; - } - } -}; - -/******************* vec_t type cast *******************/ - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = half(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)(&dst.data))[i] = - __bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = nv_bfloat16(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((nv_bfloat162 *)(&dst.data))[i] = - __float22bfloat162_rn(((float2 *)(&src.data))[i]); - } - } -} - -#ifdef FLASHINFER_USE_FP8 - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else if constexpr (vec_size == 2) { - *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e4m3, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = - __nv_fp8x4_e4m3(((float4 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e4m3, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - // NOTE(Zihao): need to double check if we properly handle flo and fhi - ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3( - ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else if constexpr (vec_size == 2) { - *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e5m2, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e5m2(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = - __nv_fp8x4_e5m2(((float4 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e5m2, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - // NOTE(Zihao): need to double check if we properly handle flo and fhi - ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2( - ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); - } - } -} - -#endif // FLASHINFER_USE_FP8 - -#endif // VEC_DTYPES_CUH_ diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cu deleted file mode 100644 index dd29820144b34..0000000000000 --- a/csrc/punica/punica_ops.cu +++ /dev/null @@ -1,569 +0,0 @@ -#include -#include -#include - -#include "type_convert.h" -#include "../cuda_compat.h" -#include "bgmv/bgmv_config.h" - - -//====== utils ====== - -inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, - const char *a_name, const char *b_name) { - TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", - a.dim(), " vs ", b.dim()); - for (int i = 0; i < a.dim(); ++i) { - TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, - ".size(", i, ")"); - } -} - -inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) { - return (uint64_t(a) << 32) | uint64_t(b); -} - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -#define CHECK_DIM(d, x) \ - TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") - -#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) - -#define CHECK_EQ(a, b) \ - TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) - -//====== bgmv ====== - -template -inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, - const int64_t *lora_indices, - uint32_t in_features, uint32_t out_features, - int64_t y_offset, int64_t full_y_size, - int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale) { - // NOTE(woosuk): While Punica supports various combinations of input/output - // data types, we limit the supported data types to reduce the binary size. - constexpr bool is_input_float = std::is_same::value; - constexpr bool is_output_float = std::is_same::value; - if (is_input_float) { - if (!std::is_same::value) { - return false; - } - } else if (is_output_float) { - if (!std::is_same::value) { - return false; - } - } else if (!(std::is_same::value && - std::is_same::value)) { - return false; - } - - switch (pack_u32(in_features, out_features)) { -#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ - case pack_u32(feat_in, feat_out): \ - bgmv_kernel(Y, X, W, lora_indices, y_offset, \ - full_y_size, batch_size, num_layers, \ - layer_idx, scale); \ - break; -#define CASE(_in_T, _out_T, _W_T, narrow, wide) \ - CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \ - CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) - - FOR_BGMV_WIDE_NARROW(CASE, _, _, _) - FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _) -#undef CASE -#undef CASE_ONESIDE - default: - return false; - } - return true; -} - -void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, double scale) { - CHECK_INPUT(y); - CHECK_INPUT(x); - CHECK_INPUT(w); - CHECK_INPUT(indicies); - - CHECK_DIM(2, y); - CHECK_DIM(2, x); - CHECK_DIM(4, w); - CHECK_DIM(1, indicies); - - int64_t B = x.size(0); - int64_t h_in = x.size(1); - int64_t h_out = y.size(1); - int64_t num_layers = w.size(1); - CHECK_EQ(w.size(3), h_in); - CHECK_EQ(w.size(2), h_out); - CHECK_EQ(indicies.size(0), x.size(0)); - CHECK_EQ(y.size(0), x.size(0)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - bool ok = false; - if (h_in <= 128512 && h_out <= 128512) { - // TODO: See if we can get rid of this massive nested switch - switch (x.scalar_type()) { - case at::ScalarType::Half: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - default: - break; - } - } - TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, - " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); -} - -void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, - double scale, int64_t h_in, int64_t h_out, - int64_t y_offset) { - CHECK_INPUT(y); - CHECK_INPUT(x); - CHECK_INPUT(w); - CHECK_INPUT(indicies); - - CHECK_DIM(2, y); - CHECK_DIM(2, x); - CHECK_DIM(4, w); - CHECK_DIM(1, indicies); - - int64_t B = x.size(0); - int64_t num_layers = w.size(1); - int64_t full_y_size = y.size(1); - CHECK_EQ(w.size(3), h_in); - CHECK_EQ(w.size(2), h_out); - CHECK_EQ(indicies.size(0), x.size(0)); - CHECK_EQ(y.size(0), x.size(0)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - bool ok = false; - if (h_in <= 128512 && h_out <= 128512) { - // TODO: See if we can get rid of this massive nested switch - switch (x.scalar_type()) { - case at::ScalarType::Half: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - default: - break; - } - } - TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, - " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); -} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h deleted file mode 100644 index 5d625d0564f75..0000000000000 --- a/csrc/punica/punica_ops.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include - -void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, double scale); - -void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, - double scale, int64_t h_in, int64_t h_out, - int64_t y_offset); diff --git a/csrc/punica/torch_bindings.cpp b/csrc/punica/torch_bindings.cpp deleted file mode 100644 index 894e229b6d9db..0000000000000 --- a/csrc/punica/torch_bindings.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "registration.h" -#include "punica_ops.h" - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { - m.def( - "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int " - "layer_idx, float scale) -> ()"); - m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); - - m.def( - "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w," - "Tensor indicies, int layer_idx," - "float scale, int h_in, int h_out," - "int y_offset) -> ()"); - m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h deleted file mode 100644 index dff7ce49283d7..0000000000000 --- a/csrc/punica/type_convert.h +++ /dev/null @@ -1,82 +0,0 @@ -#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ -#define CSRC__PUNICA__TYPE_CONVERT_H__ - -#ifndef USE_ROCM - -#include -#include - -#else - -#include -#include - -#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ - -typedef __half nv_half; -typedef __hip_bfloat16 nv_bfloat16; -typedef __hip_bfloat162 nv_bfloat162; - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { - return __hip_bfloat162{val, val}; -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { - return __hip_bfloat162{vall, valr}; -} - -template -__TYPE_CONVERT__HOST_DEVICE__ -inline T_dst convert_type(T_src val) { - return static_cast(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline float convert_type<__half, float>(__half val) { - return __half2float(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __half convert_type(float val) { - return __float2half(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { - return __bfloat162float(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 convert_type(float val) { - return __float2bfloat16(val); -} - -template -__TYPE_CONVERT__HOST_DEVICE__ -inline T vllm_add(T a, T b) { - return a + b; -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __half vllm_add<__half>(__half a, __half b) { - return __hadd(a, b); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { - return __hadd(a, b); -} - -#undef __TYPE_CONVERT__HOST_DEVICE__ - -#endif // USE_ROCM - -#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index fe041e03a1b6c..0253717da3cda 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -66,7 +66,6 @@ You can also build and install vLLM from source: $ git clone https://github.com/vllm-project/vllm.git $ cd vllm - $ # export VLLM_INSTALL_PUNICA_KERNELS=1 # optionally build for multi-LoRA capability $ pip install -e . # This may take 5-10 minutes. .. tip:: diff --git a/setup.py b/setup.py index 72ef26f15e405..63c1f466d2910 100644 --- a/setup.py +++ b/setup.py @@ -181,9 +181,6 @@ def configure(self, ext: CMakeExtension) -> None: # match. cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)] - if _install_punica(): - cmake_args += ['-DVLLM_INSTALL_PUNICA_KERNELS=ON'] - # # Setup parallelism and build tool # @@ -274,10 +271,6 @@ def _build_custom_ops() -> bool: return _is_cuda() or _is_hip() or _is_cpu() -def _install_punica() -> bool: - return envs.VLLM_INSTALL_PUNICA_KERNELS - - def get_hipcc_rocm_version(): # Run the hipcc --version command result = subprocess.run(['hipcc', '--version'], @@ -446,9 +439,6 @@ def _read_requirements(filename: str) -> List[str]: if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) - if _install_punica(): - ext_modules.append(CMakeExtension(name="vllm._punica_C")) - package_data = { "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] } diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py index e28f809309ec5..7212ba35fef0a 100644 --- a/tests/kernels/test_sampler.py +++ b/tests/kernels/test_sampler.py @@ -1,4 +1,5 @@ import gc +from unittest.mock import patch import pytest import torch @@ -6,10 +7,15 @@ import triton.language as tl from vllm.model_executor.layers.ops.sample import ( - MAX_TRITON_N_COLS, _uniform_to_exponential, get_num_triton_sampler_splits, - sample) + MAX_TRITON_N_COLS, + _sample_triton, + _uniform_to_exponential, + get_num_triton_sampler_splits, + sample, +) from vllm.model_executor.sampling_metadata import SamplingTensors from vllm.model_executor.utils import set_random_seed +from vllm.triton_utils.libentry import LibEntry SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 @@ -75,15 +81,20 @@ def test_sample_decoding_only(random_sampling, max_best_of, seeds = torch.randint(1, torch.iinfo(torch.long).max, (n_splits, bs), device="cuda").mul_(random_sampling_mask) - sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - _save_modified_probs=True) + #The current _sample_triton does not utilize the + # libentry decoration. The purpose of adding this patch is to test + # the correctness of libentry. + with patch("vllm.model_executor.layers.ops.sample._sample_triton", + LibEntry(_sample_triton)): + sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + _save_modified_probs=True) assert sampled_tokens.shape == (bs, max_best_of) for i in range(bs): assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) @@ -129,6 +140,7 @@ def test_sample_decoding_only(random_sampling, max_best_of, [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) def test_sample_prompt_logprobs(random_sampling, max_best_of, modify_greedy_probs, seed, vocab_size): + set_random_seed(seed) prompt_sizes = [16, 32, 64, 128] * 2 samples = 8 @@ -156,14 +168,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of, seeds = torch.randint(1, torch.iinfo(torch.long).max, (n_splits, samples), device="cuda").mul_(random_sampling_mask) - sampled_tokens, sampled_logprobs, _ = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=True) + #ditto + with patch("vllm.model_executor.layers.ops.sample._sample_triton", + LibEntry(_sample_triton)): + sampled_tokens, sampled_logprobs, _ = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=True) assert sampled_tokens.shape == (samples, max_best_of) assert sampled_logprobs.shape == (samples, max_best_of) for i, t in enumerate(sample_indices): diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 709246179bfe4..478bb86b78610 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files): expected_lora_output = [ "more important than knowledge.\nAuthor: Albert Einstein\n", "everyone else is already taken.\nAuthor: Oscar Wilde\n", - "so little time\nAuthor: Frank Zappa\n", + "so little time.\nAuthor: Frank Zappa\n", ] output1 = do_sample(llm, gemma_lora_files, lora_id=1) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 7207af6b1a4b3..6f33f56616fcd 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -26,7 +26,8 @@ VocabParallelEmbeddingWithLoRA) # yapf: enable from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, - PackedLoRALayerWeights, convert_mapping) + PackedLoRALayerWeights) +from vllm.lora.punica import PunicaWrapper from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -47,6 +48,9 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +# We will launch different triton kernels between the prefill and decode +# stages, so we need to verify this. prefill stage(True) or decode stage(False) +STAGES = [True, False] def get_random_id_to_index(num_loras: int, @@ -182,10 +186,12 @@ def create_random_inputs( @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) -def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: +@pytest.mark.parametrize("stage", STAGES) +def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: torch.set_default_device(device) max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -204,7 +210,7 @@ def create_random_embedding_layer(): id_to_index = get_random_id_to_index(num_loras, max_loras) embedding, lora_embedding = create_random_embedding_layer() - + lora_embedding.set_mapping(punica_wrapper) lora_dict, _ = populate_loras( id_to_index, layer=lora_embedding, @@ -217,12 +223,12 @@ def create_random_embedding_layer(): input_size=(200, ), input_range=(1, vocab_size), ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, vocab_size, lora_config.lora_extra_vocab_size) - lora_embedding.set_mapping(*mapping_info) lora_result = lora_embedding(torch.cat(inputs)) @@ -255,12 +261,12 @@ def create_random_embedding_layer(): input_size=(200, ), input_range=(1, vocab_size), ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, vocab_size, lora_config.lora_extra_vocab_size) - lora_embedding.set_mapping(*mapping_info, ) lora_result = lora_embedding(torch.cat(inputs)) expected_result = embedding(torch.cat(inputs)) @@ -278,11 +284,13 @@ def create_random_embedding_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +@pytest.mark.parametrize("stage", STAGES) def test_embeddings_with_new_embeddings(dist_init, num_loras, device, - vocab_size) -> None: + vocab_size, stage) -> None: torch.set_default_device(device) max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -318,6 +326,7 @@ def create_random_embedding_layer(): generate_embeddings_tensor=256, ) + lora_embedding.set_mapping(punica_wrapper) # All embeddings tensors have the same shape. embeddings_tensors = [ lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) @@ -334,8 +343,12 @@ def create_random_embedding_layer(): input_size=(200, ), input_range=(1, vocab_size), ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) original_inputs = deepcopy(inputs) # Force some of the inputs to be in the extended embeddings range @@ -349,11 +362,6 @@ def create_random_embedding_layer(): (embedding_id + 1) * embeddings_tensor_len - 1) original_input_[-2] = vocab_size + embeddings_tensor_len - 1 - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - lora_embedding.set_mapping(*mapping_info, ) - expanded_embedding.weight[vocab_size:vocab_size + (embeddings_tensor_len * max_loras)] = torch.cat(embeddings_tensors) @@ -390,15 +398,13 @@ def create_random_embedding_layer(): input_size=(200, ), input_range=(1, vocab_size), ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - original_inputs = deepcopy(inputs) - - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, vocab_size, lora_config.lora_extra_vocab_size) - lora_embedding.set_mapping(*mapping_info, ) - lora_result = lora_embedding(torch.cat(original_inputs)) expected_result = expanded_embedding(torch.cat(inputs)) @@ -413,11 +419,13 @@ def create_random_embedding_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) -def test_lm_head_logits_processor(dist_init, num_loras, device, - vocab_size) -> None: +@pytest.mark.parametrize("stage", STAGES) +def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, + stage) -> None: torch.set_default_device(device) max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -443,7 +451,7 @@ def _pretest(): id_to_index = get_random_id_to_index(num_loras, max_loras) linear, logits_processor, lora_logits_processor = _pretest() - + lora_logits_processor.set_mapping(punica_wrapper) # NOTE: all the generated loras share the same embeddings tensor. lora_dict, _ = populate_loras( id_to_index, @@ -461,17 +469,17 @@ def _pretest(): input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - - input_ = torch.rand(20, 1024) - mapping_info = convert_mapping( + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( lora_mapping, id_to_index, max_loras, vocab_size, lora_config.lora_extra_vocab_size, ) - lora_logits_processor.set_mapping(*mapping_info, ) + input_ = torch.rand(20, 1024) lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), @@ -510,12 +518,16 @@ def _pretest(): input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - lora_logits_processor.set_mapping(*mapping_info, ) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), @@ -538,10 +550,12 @@ def _pretest(): @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("stage", STAGES) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device) -> None: + device, stage) -> None: torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -575,7 +589,7 @@ def create_random_linear_parallel_layer(): id_to_index = get_random_id_to_index(num_loras, max_loras) linear, lora_linear = create_random_linear_parallel_layer() - + lora_linear.set_mapping(punica_wrapper) lora_dict, _ = populate_loras( id_to_index, layer=lora_linear, @@ -589,16 +603,16 @@ def create_random_linear_parallel_layer(): input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - - mapping_info = convert_mapping( + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size, ) - lora_linear.set_mapping(*mapping_info, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -628,11 +642,12 @@ def create_random_linear_parallel_layer(): input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size) - lora_linear.set_mapping(*mapping_info, ) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] @@ -649,10 +664,12 @@ def create_random_linear_parallel_layer(): @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("stage", STAGES) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device) -> None: + device, stage) -> None: torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -707,7 +724,7 @@ class FakeConfig: id_to_index = get_random_id_to_index(num_loras, max_loras) linear, lora_linear = create_column_parallel_packed_layer() - + lora_linear.set_mapping(punica_wrapper) lora_dict, sublora_dict = populate_loras( id_to_index, layer=lora_linear, @@ -722,16 +739,17 @@ class FakeConfig: input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) - mapping_info = convert_mapping( + punica_wrapper.update_metadata( lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size, ) - lora_linear.set_mapping(*mapping_info) lora_result = lora_linear(torch.cat(inputs))[0] @@ -762,16 +780,18 @@ class FakeConfig: input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) - mapping_info = convert_mapping( + punica_wrapper.update_metadata( lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size, ) - lora_linear.set_mapping(*mapping_info) + # lora_linear.set_mapping(*mapping_info) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] @@ -803,7 +823,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.set_default_device(device) - + punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -825,6 +845,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, is_neox_style, ) lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) + lora_rope.set_mapping(punica_wrapper) lora_rope.create_lora_weights(max_loras, lora_config) linear_rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { @@ -840,6 +861,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, input_range=(0, lora_config.lora_extra_vocab_size), input_type=torch.float16, ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) long_lora_context = LongContextLoRAContext(list(scaling_factors), rotary_dim) @@ -854,7 +876,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, for i in range(len(scaling_factors)): long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get( scaling_factors[i], 0) - mapping_info = convert_mapping( + punica_wrapper.update_metadata( lora_mapping, id_to_index, max_loras, @@ -862,7 +884,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, lora_config.lora_extra_vocab_size, long_lora_context=long_lora_context, ) - lora_rope.set_mapping(*mapping_info) + # lora_rope.set_mapping(*mapping_info) positions = torch.randint(0, max_position, (batch_size, seq_len)) query = torch.randn(batch_size, diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py deleted file mode 100644 index 3415d36b7e341..0000000000000 --- a/tests/lora/test_lora.py +++ /dev/null @@ -1,224 +0,0 @@ -import pytest -import torch - -from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice - -from .utils import DummyLoRAManager - -TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4] -QKV_TENSOR_SIZES = [ - (8192, 1024, 1024), - (8192 // 8, 1024 // 8, 1024 // 8), - (4096, 4096, 4096), - (4096 // 2, 4096 // 2, 4096 // 2), -] -BATCH_SIZES = [8, 32, 256] -RANKS = [8] -DTYPES = [torch.float16] -TOLERANCES = { - torch.float16: (5e-3, 5e-3), - torch.bfloat16: (3e-2, 2e-2), -} - - -@pytest.mark.parametrize("m", TENSOR_SIZES) -@pytest.mark.parametrize("n", TENSOR_SIZES) -@pytest.mark.parametrize("k", BATCH_SIZES) -@pytest.mark.parametrize("rank", RANKS) -@pytest.mark.parametrize("dtype", DTYPES) -def test_apply_lora(m, n, k, rank, dtype) -> None: - manager = DummyLoRAManager() - - module_name = "module" - weight = torch.rand([m, n], device="cuda", dtype=dtype) - - manager.init_random_lora(module_name, weight, rank=rank) - lora = manager.get_module_lora(module_name) - - input = torch.rand(k, n, device="cuda", dtype=dtype) - expected = input @ lora.lora_a @ lora.lora_b * lora.scaling - - lora_a_stack = torch.zeros(8, - 1, - lora.lora_a.shape[1], - lora.lora_a.shape[0], - device="cuda", - dtype=dtype) - lora_b_stack = torch.zeros(8, - 1, - lora.lora_b.shape[1], - lora.lora_b.shape[0], - device="cuda", - dtype=dtype) - for i in range(lora_a_stack.shape[0]): - lora_a_stack[i][0] = lora.lora_a.T - lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T - - output = torch.zeros(k, m, device="cuda", dtype=dtype) - _apply_lora( - input, lora_a_stack, lora_b_stack, - torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"), - output) - - rtol, atol = TOLERANCES[dtype] - assert torch.allclose(expected, output, rtol=rtol, atol=atol) - - output[:] = 0 - _apply_lora(input, lora_a_stack, lora_b_stack, - torch.full((len(input), ), -1, device="cuda"), output) - assert torch.allclose(torch.zeros_like(output), output) - - manager.reset_lora() - - -@pytest.mark.parametrize("m", TENSOR_SIZES) -@pytest.mark.parametrize("n", TENSOR_SIZES) -@pytest.mark.parametrize("k", BATCH_SIZES) -@pytest.mark.parametrize("rank", RANKS) -@pytest.mark.parametrize("dtype", DTYPES) -def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: - if m % 2 != 0: - pytest.skip("m must be divisible by 2") - if m // 2 not in TENSOR_SIZES: - pytest.skip("m//2 must be in TENSOR_SIZES") - - manager = DummyLoRAManager() - - module_name = "module" - weight = torch.rand([m // 2, n], device="cuda", dtype=dtype) - - manager.init_random_lora(module_name + "1", weight, rank=rank) - lora_1 = manager.get_module_lora(module_name + "1") - manager.init_random_lora(module_name + "2", weight, rank=rank) - lora_2 = manager.get_module_lora(module_name + "2") - - input = torch.rand(k, n, device="cuda", dtype=dtype) - expected = torch.cat([ - input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling, - input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling - ], - dim=1) - - lora_a_stacks = [ - torch.zeros(8, - 1, - lora_1.lora_a.shape[1], - lora_1.lora_a.shape[0], - device="cuda", - dtype=dtype) for i in range(2) - ] - lora_b_stacks = [ - torch.zeros(8, - 1, - lora_1.lora_b.shape[1], - lora_1.lora_b.shape[0], - device="cuda", - dtype=dtype) for i in range(2) - ] - for i in range(lora_a_stacks[0].shape[0]): - lora_a_stacks[0][i][0] = lora_1.lora_a.T - lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T - lora_a_stacks[1][i][0] = lora_2.lora_a.T - lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T - - output = torch.zeros(k, m, device="cuda", dtype=dtype) - _apply_lora_packed_nslice( - input, lora_a_stacks, lora_b_stacks, - torch.randint(0, - lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, (m // 2, m // 2)) - - rtol, atol = TOLERANCES[dtype] - assert torch.allclose(expected, output, rtol=rtol, atol=atol) - - output[:] = 0 - _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, - torch.full((len(input), ), -1, device="cuda"), - output, (m // 2, m // 2)) - assert torch.allclose(torch.zeros_like(output), output) - - manager.reset_lora() - - -@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES) -@pytest.mark.parametrize("n", TENSOR_SIZES) -@pytest.mark.parametrize("k", BATCH_SIZES) -@pytest.mark.parametrize("rank", RANKS) -@pytest.mark.parametrize("dtype", DTYPES) -def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: - manager = DummyLoRAManager() - - module_name = "module" - weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype) - weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype) - - manager.init_random_lora(module_name + "q", weight_q, rank=rank) - lora_q = manager.get_module_lora(module_name + "q") - manager.init_random_lora(module_name + "k", weight_kv, rank=rank) - lora_k = manager.get_module_lora(module_name + "k") - manager.init_random_lora(module_name + "v", weight_kv, rank=rank) - lora_v = manager.get_module_lora(module_name + "v") - - input = torch.rand(k, n, device="cuda", dtype=dtype) - expected = torch.cat([ - input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling, - input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling, - input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling - ], - dim=1) - - lora_a_stacks = [ - torch.zeros(8, - 1, - lora_q.lora_a.shape[1], - lora_q.lora_a.shape[0], - device="cuda", - dtype=dtype) - ] + [ - torch.zeros(8, - 1, - lora_k.lora_a.shape[1], - lora_k.lora_a.shape[0], - device="cuda", - dtype=dtype) for i in range(2) - ] - lora_b_stacks = [ - torch.zeros(8, - 1, - lora_q.lora_b.shape[1], - lora_q.lora_b.shape[0], - device="cuda", - dtype=dtype) - ] + [ - torch.zeros(8, - 1, - lora_k.lora_b.shape[1], - lora_k.lora_b.shape[0], - device="cuda", - dtype=dtype) for i in range(2) - ] - for i in range(lora_a_stacks[0].shape[0]): - lora_a_stacks[0][i][0] = lora_q.lora_a.T - lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T - lora_a_stacks[1][i][0] = lora_k.lora_a.T - lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T - lora_a_stacks[2][i][0] = lora_v.lora_a.T - lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T - - output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype) - _apply_lora_packed_nslice( - input, lora_a_stacks, lora_b_stacks, - torch.randint(0, - lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, (qkv[0], qkv[1], qkv[2])) - - rtol, atol = TOLERANCES[dtype] - assert torch.allclose(expected, output, rtol=rtol, atol=atol) - - output[:] = 0 - _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, - torch.full((len(input), ), -1, device="cuda"), - output, (qkv[0], qkv[1], qkv[2])) - assert torch.allclose(torch.zeros_like(output), output) - - manager.reset_lora() diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py deleted file mode 100644 index dbeb16cb21ad3..0000000000000 --- a/tests/lora/test_punica.py +++ /dev/null @@ -1,258 +0,0 @@ -# Based on code from https://github.com/punica-ai/punica - -import pytest -import torch - -import vllm.lora.punica as punica - - -def assert_close(a, b): - rtol, atol = { - torch.float16: (5e-3, 5e-3), - torch.bfloat16: (3e-2, 2e-2), - torch.float32: (None, None), - }[a.dtype] - torch.testing.assert_close(a, b, rtol=rtol, atol=atol) - - -def _lora_ref_impl( - y_final: torch.Tensor, - x: torch.Tensor, - wa_T_all: torch.Tensor, - wb_T_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, - scale: float, -): - y_stage_1 = torch.empty( - (x.size(0), wa_T_all.size(-2)), - dtype=torch.float32, - device=x.device, - ) - bs = x.shape[0] - s = torch.tensor(scale, dtype=torch.float32, device=x.device) - for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): - xi = x[i].unsqueeze(0).to(torch.float32) - wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) - if wb_T_all is not None: - wb = wb_T_all[lora_idx, layer_idx].transpose(-1, - -2).to(torch.float32) - - tmp = xi @ wa - y_stage_1[i] = tmp.squeeze(0) - y_final[i] += ((tmp @ wb).squeeze(0) * - s if wb_T_all is not None else y_stage_1[i]) - return y_final, y_stage_1 - - -H1 = H2 = [ - 128, - 256, - 512, - 896, - 1024, - 1152, - 1216, - 1280, - 1536, - 1664, - 2048, - 2240, - 2304, - 2368, - 2432, - 2560, - 2752, - 3072, - 3328, - 3456, - 3584, - 3712, - 4096, - 4480, - 4608, - 4736, - 4864, - 5120, - 5504, - 5632, - 5888, - 6144, - 6400, - 6848, - 6912, - 7168, - 7424, - 8192, - 8960, - 9216, - 9472, - 10240, - 11008, - 11264, - 13824, - 14336, - 14784, - 14848, - 15360, - 18944, - 22016, - 22528, - 24576, - 27392, - 27648, - 29568, - 29696, - 32000, - 32256, - 32512, - 32768, - 33024, - 36864, - 43264, - 49152, - 49408, - 60544, - 60672, - 64000, - 64256, - 102400, - 102656, - 128000, - 128256, -] -H2 = [64] + H2 -R = [1, 2, 4] -SEED = [0xabcdabcd987] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - - -@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) -@pytest.mark.parametrize("h1", H1) -@pytest.mark.parametrize("r", R) -@pytest.mark.parametrize("seed", SEED) -@torch.inference_mode() -def test_lora_a_extra_shapes(dtype_str, h1, r, seed): - torch.manual_seed(seed) - num_loras = 4 - num_layers = 1 - bs = 32 - dtype = getattr(torch, dtype_str) - device = torch.device("cuda") - - wa_T_all = torch.randn(num_loras, - num_layers, - r, - h1, - dtype=dtype, - device=device) - indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) - - for layer_idx in range(num_layers): - x = torch.randn(bs, h1, dtype=dtype, device=device) - y = torch.randn(bs, r, dtype=dtype, device=device) - - y_ref = y.clone() - _lora_ref_impl( - y_ref, - x, - wa_T_all, - None, - indices, - layer_idx, - 1.0, - ) - - y_our = y.clone() - punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0) - - assert_close(y_ref, y_our) - - -@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) -@pytest.mark.parametrize("h1", H1) -@pytest.mark.parametrize("h2", H2) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_lora_correctness(dtype_str, h1, h2, seed, device): - torch.manual_seed(seed) - num_loras = 4 - num_layers = 1 - r = 8 - bs = 32 - scale = 0.123 - dtype = getattr(torch, dtype_str) - torch.set_default_device(device) - - wa_T_all = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wb_T_all = torch.randn(num_loras, num_layers, h2, r, dtype=dtype) - indices = torch.randint(num_loras, (bs, ), dtype=torch.long) - - for layer_idx in range(num_layers): - x = torch.randn(bs, h1, dtype=dtype) - y = torch.randn(bs, h2, dtype=dtype) - - y_ref = y.clone() - _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale) - - y_our = y.clone() - punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx, - scale) - - assert_close(y_ref, y_our) - - -@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) -@pytest.mark.parametrize("h1", H1) -@pytest.mark.parametrize("h2", H2) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_lora_correctness_slice(dtype_str, h1, h2, seed, device): - if h2 % 3 != 0 or h2 // 3 not in H1: - pytest.skip("h2 must be divisible by 3 and in supported shapes") - torch.manual_seed(seed) - num_loras = 4 - num_layers = 1 - r = 8 - bs = 32 - scale = 0.123 - dtype = getattr(torch, dtype_str) - torch.set_default_device(device) - - wa_T_all_0 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wa_T_all_1 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wa_T_all_2 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wb_T_all_0 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype) - wb_T_all_1 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype) - wb_T_all_2 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype) - - indices = torch.randint(num_loras, (bs, ), dtype=torch.long) - - for layer_idx in range(num_layers): - x = torch.randn(bs, h1, dtype=dtype) - y = torch.randn(bs, h2, dtype=dtype) - s = h2 // 3 - - y_ref = y.clone() - _lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices, - layer_idx, scale) - _lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices, - layer_idx, scale) - _lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices, - layer_idx, scale) - - y_our = y.clone() - punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices, - layer_idx, scale, 0, s) - punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices, - layer_idx, scale, s, s) - punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices, - layer_idx, scale, s * 2, s) - - assert_close(y_ref[:, :s], y_our[:, :s]) - assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2]) - assert_close(y_ref[:, s * 2:], y_our[:, s * 2:]) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py new file mode 100644 index 0000000000000..c052568dc2e33 --- /dev/null +++ b/tests/lora/test_punica_sizes.py @@ -0,0 +1,408 @@ +""" +This script is mainly used to tests various hidden_sizes. We have collected the +hidden_sizes included in the LoRA models currently supported by vLLM. It tests +whether the corresponding Triton kernel can run normally when tensor parallelism +is set to [1, 2, 4, 8, 16, 32, 64]. +""" +import random +from unittest.mock import patch + +import pytest +import torch + +from vllm.lora.ops.bgmv_expand import bgmv_expand +from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice +from vllm.lora.ops.bgmv_shrink import bgmv_shrink +from vllm.lora.ops.sgmv_expand import sgmv_expand +from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice +from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.triton_utils.libentry import LibEntry + +from .utils import (generate_data, generate_data_for_expand_nslices, + ref_torch_groupgemm) + +HIDDEN_SIZES = [ + 128, + 256, + 512, + 896, + 1024, + 1152, + 1216, + 1280, + 1536, + 1664, + 2048, + 2240, + 2304, + 2368, + 2432, + 2560, + 2752, + 3072, + 3328, + 3456, + 3584, + 3712, + 4096, + 4480, + 4608, + 4736, + 4864, + 5120, + 5504, + 5632, + 5888, + 6144, + 6400, + 6848, + 6912, + 7168, + 7424, + 8192, + 8960, + 9216, + 9472, + 10240, + 11008, + 11264, + 13824, + 14336, + 14784, + 14848, + 15360, + 18944, + 22016, + 22528, + 24576, + 27392, + 27648, + 29568, + 29696, + 32000, + 32256, + 32512, + 32768, + 33024, + 36864, + 43264, + 49152, + 49408, + 60544, + 60672, + 64000, + 64256, + 102400, + 102656, + 128000, + 128256, +] +#The size of TP +divisibility = [1, 2, 4, 8, 16, 32, 64] + +all_hidden_size = [] +for div in divisibility: + for hidden_size in HIDDEN_SIZES: + all_hidden_size.append(hidden_size // div) + +HIDDEN_SIZES = list(set(all_hidden_size)) + +BATCHES = [4] +NUM_LORA = [4] +DTYPES = [torch.float16, torch.bfloat16] +MAX_RANKS = [32] +SCALES = [0.5] +SEED = [0] +CUDA_DEVICES = [f"cuda:{0}"] + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_sgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + seq_length = 128 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + max_seq_length = seq_len_tensor.max() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + if op_type == "shrink": + sgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + scaling, + ) + else: + sgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_out_tensor, + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + scaling if op_type == "shrink" else 1.0, + op_type, + ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_bgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel + from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel + + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + seq_length = 1 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + if op_type == "shrink": + # The current _bgmv_shrink_kernel does not require the libentry + # decoration. The purpose of adding this patch is to test the + # correctness of libentry. + with patch( + "vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel", + LibEntry(_bgmv_shrink_kernel), + ): + bgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + scaling, + ) + else: + # ditto + with patch( + "vllm.lora.ops.bgmv_expand._bgmv_expand_kernel", + LibEntry(_bgmv_expand_kernel), + ): + bgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_out_tensor, + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + scaling if op_type == "shrink" else 1.0, + op_type, + ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_expand_nslices( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel + + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + seq_length = 128 if op_type == "sgmv" else 1 + ( + inputs_tensor, + lora_weights_lst, + our_outputs, + ref_outputs, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data_for_expand_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + nslices, + device, + ) + max_seq_length = seq_len_tensor.max() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + if op_type == "sgmv": + sgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + slice_offset, + hidden_size, + add_inputs=True, + ) + else: + # The current _bgmv_expand_slice_kernel does not require the + # libentry decoration. The purpose of adding this patch is to test + # the correctness of libentry. + with patch( + "vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel", + LibEntry(_bgmv_expand_slice_kernel), + ): + bgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_outputs[:, slice_offset:slice_offset + hidden_size], + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + 1.0, + op_type="expand", + ) + + slice_offset += hidden_size + assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py new file mode 100644 index 0000000000000..7e73ea67ee5f4 --- /dev/null +++ b/tests/lora/test_punica_variation.py @@ -0,0 +1,342 @@ +""" +This script is mainly used to test whether trtion kernels can run normally +under different conditions, including various batches, numbers of LoRA , and +maximum ranks. +""" +import random +from unittest.mock import patch + +import pytest +import torch + +from vllm.lora.ops.bgmv_expand import bgmv_expand +from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice +from vllm.lora.ops.bgmv_shrink import bgmv_shrink +from vllm.lora.ops.sgmv_expand import sgmv_expand +from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice +from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.triton_utils.libentry import LibEntry + +from .utils import (generate_data, generate_data_for_expand_nslices, + ref_torch_groupgemm) + +HIDDEN_SIZES = [3424, 4096, 4097] + +BATCHES = [1, 4, 16, 32] +NUM_LORA = [1, 4, 8, 16, 32, 64, 128] +DTYPES = [torch.float16, torch.bfloat16] +MAX_RANKS = [1, 4, 8, 16, 32, 64, 128] +SCALES = [0.5] +SEED = [0] +CUDA_DEVICES = [f"cuda:{0}"] + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_sgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + seq_length = 128 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + max_seq_length = seq_len_tensor.max() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + if op_type == "shrink": + sgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + scaling, + ) + else: + sgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_out_tensor, + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + scaling if op_type == "shrink" else 1.0, + op_type, + ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_bgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel + from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel + + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + seq_length = 1 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + if op_type == "shrink": + # The current _bgmv_shrink_kernel does not require the libentry + # decoration. The purpose of adding this patch is to test the + # correctness of libentry. + with patch( + "vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel", + LibEntry(_bgmv_shrink_kernel), + ): + bgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + scaling, + ) + else: + # ditto + with patch( + "vllm.lora.ops.bgmv_expand._bgmv_expand_kernel", + LibEntry(_bgmv_expand_kernel), + ): + bgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_out_tensor, + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + scaling if op_type == "shrink" else 1.0, + op_type, + ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_expand_nslices( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel + + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + seq_length = 128 if op_type == "sgmv" else 1 + ( + inputs_tensor, + lora_weights_lst, + our_outputs, + ref_outputs, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data_for_expand_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + nslices, + device, + ) + max_seq_length = seq_len_tensor.max() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + if op_type == "sgmv": + sgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + slice_offset, + hidden_size, + add_inputs=True, + ) + else: + # The current _bgmv_expand_slice_kernel does not require the + # libentry decoration. The purpose of adding this patch is to test + # the correctness of libentry. + with patch( + "vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel", + LibEntry(_bgmv_expand_slice_kernel), + ): + bgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_outputs[:, slice_offset:slice_offset + hidden_size], + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + 1.0, + op_type="expand", + ) + + slice_offset += hidden_size + assert_close(our_outputs, ref_outputs) + + +if __name__ == "__main__": + from itertools import product + + lst = list( + product( + BATCHES, + NUM_LORA, + MAX_RANKS, + [1.0], + [torch.float16], + ["expand"], + SEED, + CUDA_DEVICES, + )) + for ele in lst: + test_punica_bgmv(*ele) + print(f"{ele},pass") diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 8fd968c69e58f..2370c693e9534 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -64,14 +64,16 @@ def test_quant_model_lora(tinyllama_lora_files, model, tp_size): # if torch.cuda.device_count() < tp_size: # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") - llm = vllm.LLM(model=model.model_path, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_model_len=400, - tensor_parallel_size=tp_size, - quantization=model.quantization, - trust_remote_code=True) + llm = vllm.LLM( + model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_model_len=400, + tensor_parallel_size=tp_size, + gpu_memory_utilization=0.2, #avoid OOM + quantization=model.quantization, + trust_remote_code=True) if model.quantization is None: expected_no_lora_output = [ @@ -156,24 +158,28 @@ def test_quant_model_tp_equality(tinyllama_lora_files, model): # if torch.cuda.device_count() < 2: # pytest.skip(f"Not enough GPUs for tensor parallelism {2}") - llm_tp1 = vllm.LLM(model=model.model_path, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=1, - quantization=model.quantization, - trust_remote_code=True) + llm_tp1 = vllm.LLM( + model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1, + gpu_memory_utilization=0.2, #avoid OOM + quantization=model.quantization, + trust_remote_code=True) output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) del llm_tp1 cleanup() - llm_tp2 = vllm.LLM(model=model.model_path, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=2, - quantization=model.quantization) + llm_tp2 = vllm.LLM( + model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=2, + gpu_memory_utilization=0.2, #avoid OOM + quantization=model.quantization) output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) del llm_tp2 diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b73cf5bf55324..00f8e26d1041f 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -86,3 +86,151 @@ def init_packed_lora( packed_lora = PackedLoRALayerWeights.pack(base_loras) self.set_module_lora(module_name, packed_lora) return packed_lora + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +def ref_torch_groupgemm( + out_tensor, + inputs, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + scaling, + op_type, +) -> torch.Tensor: + out_list = [] + current_offset = 0 + for lora_index, b_length in zip(range(batches), seq_len_tensor): + input_weight = inputs[current_offset:b_length + current_offset, :] + current_offset += b_length + lora_weight = lora_weights[lora_indices_tensor[lora_index]] + result = torch.nn.functional.linear(input_weight, lora_weight) + result *= scaling + out_list.append(result) + cat_result = torch.cat(out_list, dim=0) + if op_type == "expand": + out_tensor += cat_result + else: + out_tensor.copy_(cat_result) + return + + +def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype, + op_type, device): + seq_len_tensor = torch.randint(seq_length, seq_length + 1, + (batches, )).to(device) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), + dim=0, + ).to(device) + total_tokens = seq_len_tensor.sum() + if op_type == "shrink": + inputs_tensor = torch.rand((total_tokens, hidden_size), + dtype=dtype).to(device) + lora_weights = torch.rand( + (lora_nums, max_rank, hidden_size), # col-major + dtype=dtype, + ).to(device) + # shrink op need atomic_add, so output is initinized by 0 + ref_out_tensor = torch.zeros((total_tokens, max_rank), + dtype=dtype, + device=inputs_tensor.device) + # NOTE shrink kernel using torch.float32 as output type + our_out_tensor = torch.zeros((total_tokens, max_rank), + dtype=torch.float32).to(device) + else: + inputs_tensor = torch.rand( + (total_tokens, max_rank), + dtype=dtype, + ).to(device) + lora_weights = torch.rand( + (lora_nums, hidden_size, max_rank), # col-major + dtype=dtype, + ).to(device) + # expand op needs to complete y+=a@lora_b, so output is + # initinized randomly + ref_out_tensor = torch.rand( + (total_tokens, hidden_size), + dtype=dtype, + ).to(device) + # Ensure the same input. + our_out_tensor = ref_out_tensor.clone() + lora_indices_tensor = torch.randint(0, + lora_nums - 1 if lora_nums > 1 else 1, + (batches, )).to(device) + indices = torch.zeros((total_tokens), dtype=torch.long).to(device) + current_offset = 0 + for b_id in range(batches): + lora_index = lora_indices_tensor[b_id] + indices[current_offset:current_offset + + seq_len_tensor[b_id]].copy_(lora_index) + current_offset += seq_len_tensor[b_id].item() + return ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) + + +def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank, + seq_length, dtype, nslices, device): + seq_len_tensor = torch.randint(seq_length, seq_length + 1, + (batches, )).to(device) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), + dim=0, + ).to(device) + total_tokens = seq_len_tensor.sum() + inputs_tensor = torch.rand( + (total_tokens, max_rank), + dtype=dtype, + ).to(device) + lora_weights_lst = [] + for _ in range(nslices): + lora_weights_lst.append( + torch.rand( + (lora_nums, hidden_size, max_rank), # col-major + dtype=dtype, + ).to(device)) + # expand op needs to complete y+=a@lora_b, so output is + # initinized randomly + ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices), + dtype=dtype).to(device) + # Ensure the same input. + our_out_tensor = ref_out_tensor.clone() + lora_indices_tensor = torch.randint(0, + lora_nums - 1 if lora_nums > 1 else 1, + (batches, )) + indices = torch.zeros((total_tokens), dtype=torch.long).to(device) + current_offset = 0 + for b_id in range(batches): + lora_index = lora_indices_tensor[b_id] + indices[current_offset:current_offset + + seq_len_tensor[b_id]] = lora_index.item() + current_offset += seq_len_tensor[b_id].item() + + lora_indices_tensor = lora_indices_tensor.to(device) + return ( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e5151c070f2f7..6dfaefd3cb51e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -13,12 +13,9 @@ except ImportError as e: logger.warning("Failed to import from vllm._C with %r", e) -with contextlib.suppress(ImportError): - import vllm._moe_C - with contextlib.suppress(ImportError): # ruff: noqa: F401 - import vllm._punica_C + import vllm._moe_C def is_custom_op_supported(op_name: str) -> bool: @@ -505,43 +502,6 @@ def register_graph_buffers(fa: int, handles: List[str], torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) -# punica -def dispatch_bgmv( - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - indicies: torch.Tensor, - layer_idx: int, - scale: float, -) -> None: - torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, - scale) - - -def dispatch_bgmv_low_level( - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - indicies: torch.Tensor, - layer_idx: int, - scale: float, - h_in: int, - h_out: int, - y_offset: int, -) -> None: - torch.ops._punica_C.dispatch_bgmv_low_level( - y, - x, - w_t_all, - indicies, - layer_idx, - scale, - h_in, - h_out, - y_offset, - ) - - # temporary fix for https://github.com/vllm-project/vllm/issues/5456 # TODO: remove this in v0.6.0 names_and_values = globals() diff --git a/vllm/envs.py b/vllm/envs.py index 595992e51db87..8cfc42c558641 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -43,7 +43,6 @@ MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None VLLM_USE_PRECOMPILED: bool = False - VLLM_INSTALL_PUNICA_KERNELS: bool = False VLLM_NO_DEPRECATION_WARNING: bool = False CMAKE_BUILD_TYPE: Optional[str] = None VERBOSE: bool = False @@ -92,10 +91,6 @@ def get_default_config_root(): "VLLM_USE_PRECOMPILED": lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")), - # If set, vllm will install Punica kernels - "VLLM_INSTALL_PUNICA_KERNELS": - lambda: bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))), - # CMake build type # If not set, defaults to "Debug" or "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo" diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index d27171f720832..a7887a048746a 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -14,7 +14,6 @@ MergedQKVParallelLinearWithLora, QKVParallelLinearWithLora, RowParallelLinearWithLoRA) -from vllm.lora.punica import bgmv, dispatch_bgmv_low_level if TYPE_CHECKING: pass @@ -28,7 +27,7 @@ def _fully_sharded_can_replace(can_replace): def dec(*args, **kwargs): return (can_replace(*args, **kwargs) - and kwargs['lora_config'].fully_sharded_loras) + and kwargs["lora_config"].fully_sharded_loras) return dec @@ -59,25 +58,30 @@ def apply(self, x: torch.Tensor, x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape - buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), - dtype=torch.float32, - device=x.device) - - bgmv(buffer, x, self.lora_a_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + buffer = torch.zeros( + (x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device, + ) + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) buffer = tensor_model_parallel_all_gather(buffer) - bgmv(output, buffer, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + self.punica_wrapper.add_expand(output, + buffer, + self.lora_b_stacked, + add_input=True) # now have column partitioned output - output = output.view(*out_orig_shape) return output @classmethod @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( source_layer=source_layer, @@ -88,14 +92,14 @@ def can_replace_layer(cls, source_layer: nn.Module, ) -def _mcp_apply(x, bias, layer): +def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora): """ - MergedColumnParallelLinearWithShardedLoRA and - MergedQKVParallelLinearWithShardedLora share the same + MergedColumnParallelLinearWithShardedLoRA and + MergedQKVParallelLinearWithShardedLora share the same LoRa weight application method. The main difference is the step by shard_size for lora_b which can - vary for MergedQKVParallelLinearWithShardedLora but is constant for + vary for MergedQKVParallelLinearWithShardedLora but is constant for MergedColumnParallelLinearWithShardedLoRA. """ # expecting 2 for column parallel and 3 for qkv @@ -104,21 +108,27 @@ def _mcp_apply(x, bias, layer): x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape - buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device) + buffers = torch.zeros( + (n, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) for idx in range(n): - bgmv(buffers[idx], x, layer.lora_a_stacked[idx], - layer.indices[:layer.indices_len[0]], 0, 1.0) + layer.punica_wrapper.add_shrink(buffers[idx], x, + layer.lora_a_stacked[idx], 1.0) buffers = tensor_model_parallel_all_gather(buffers) left_offset = 0 for idx in range(n): shard_size = layer.lora_b_stacked[idx].shape[2] - dispatch_bgmv_low_level(output, buffers[idx], - layer.lora_b_stacked[idx], - layer.indices[:layer.indices_len[0]], 0, 1.0, - left_offset, shard_size) + layer.punica_wrapper.add_expand_slice( + output, + buffers[idx], + layer.lora_b_stacked[idx], + left_offset, + shard_size, + add_input=True, + ) left_offset += shard_size output = output.view(*out_orig_shape) @@ -129,7 +139,7 @@ def _mcp_apply(x, bias, layer): class MergedColumnParallelLinearWithShardedLoRA( MergedColumnParallelLinearWithLoRA): """ - Differs from MergedColumnParallelLinearWithLoRA by slicing the + Differs from MergedColumnParallelLinearWithLoRA by slicing the LoRA A's also. Based on S-LoRA, slicing happens along the rank dim. @@ -145,7 +155,8 @@ def slice_lora_a( lora_a = [ lora_a[0][:, output_start_idx:output_start_idx + output_shard_size], - lora_a[1][:, output_start_idx:output_start_idx + output_shard_size] + lora_a[1][:, + output_start_idx:output_start_idx + output_shard_size], ] return lora_a @@ -155,9 +166,13 @@ def apply(self, x: torch.Tensor, @classmethod @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( source_layer=source_layer, @@ -170,7 +185,7 @@ def can_replace_layer(cls, source_layer: nn.Module, class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): """ - Differs from QKVParallelLinearWithLora by slicing the + Differs from QKVParallelLinearWithLora by slicing the LoRA A's also. Based on S-LoRA, slicing happens along the rank dim. @@ -193,14 +208,13 @@ def apply(self, x: torch.Tensor, buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), dtype=torch.float32, device=x.device) - - bgmv(buffer, x, self.lora_a_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) buffer = tensor_model_parallel_all_gather(buffer) - bgmv(output, buffer, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + self.punica_wrapper.add_expand(output, + buffer, + self.lora_b_stacked, + add_input=True) # now have column partitioned output - output = output.view(*out_orig_shape) return output @@ -237,7 +251,7 @@ def slice_lora_a( lora_a = [ lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]], lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]], - lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]] + lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]], ] return lora_a @@ -247,9 +261,13 @@ def apply(self, x: torch.Tensor, @classmethod @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( source_layer=source_layer, @@ -262,11 +280,11 @@ def can_replace_layer(cls, source_layer: nn.Module, class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): """ - Differs from RowParallelLinearWithLoRA by slicing the + Differs from RowParallelLinearWithLoRA by slicing the LoRA B's also. Based on S-LoRA, slicing happens along the output dim. - This yields a combined partial sum from the row parallel base + This yields a combined partial sum from the row parallel base layer and column partitioned output from the LoRA. """ @@ -283,11 +301,13 @@ def apply(self, x: torch.Tensor) -> torch.Tensor: x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape - buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), - dtype=torch.float32, - device=x.device) - bgmv(buffer, x, self.lora_a_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + buffer = torch.zeros( + (x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device, + ) + + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) buffer = tensor_model_parallel_all_reduce(buffer) # following S-LoRA, allows the fusing of all_gather and all_reduce @@ -298,18 +318,21 @@ def apply(self, x: torch.Tensor) -> torch.Tensor: # reduced before being used shard_size = self.lora_b_stacked.shape[2] start_idx = self.tp_rank * shard_size - dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0, - start_idx, shard_size) - + self.punica_wrapper.add_expand_slice(output, buffer, + self.lora_b_stacked, start_idx, + shard_size) output = output.view(*out_orig_shape) return output @classmethod @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( source_layer=source_layer, diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 40de134c0a5ee..40a7e01011f7a 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -17,7 +17,7 @@ tensor_model_parallel_all_reduce, tensor_model_parallel_gather) from vllm.distributed.utils import divide -from vllm.lora.punica import add_lora, add_lora_slice, bgmv +from vllm.lora.punica import PunicaWrapper from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -55,88 +55,17 @@ def _not_fully_sharded_can_replace(can_replace): """ def dec(*args, **kwargs): - decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True - condition = (not kwargs['lora_config'].fully_sharded_loras + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = (not kwargs["lora_config"].fully_sharded_loras if decorate else True) return can_replace(*args, **kwargs) and condition return dec -def _apply_lora( - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - indices: torch.Tensor, - output: torch.Tensor, -): - """Applies lora to each input. - - This method applies all loras to each input. It uses the - indices vector to determine which lora yields the - correct output. An index of -1 means no lora should be - applied. This method adds the final lora results to the - output. - - Input shapes: - x: (batch_size, hidden_dim) - lora_a_stacked: (num_loras, lora_rank, hidden_dim) - lora_b_stacked: (num_loras, output_dim, lora_rank) - indices: (batch_size) - output: (batch_size, output_dim) - """ - org_output = output - x = x.view(-1, x.shape[-1]) - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) - return output.view_as(org_output) - - -def _apply_lora_packed_nslice( - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - indices: torch.Tensor, - output: torch.Tensor, - output_slices: Tuple[int, ...], -): - """Applies lora to each input. - - This method applies all loras to each input. It uses the - indices vector to determine which lora yields the - correct output. An index of -1 means no lora should be - applied. This method adds the final lora results to the - output. - - This method is used for layers that are composed of multiple sublayers - (slices) packed together. - - Input shapes: - x: (batch_size, hidden_dim) - lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim) - lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - x = x.view(-1, x.shape[-1]) - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - offset_left = 0 - for slice_idx in range(len(output_slices)): - add_lora_slice(output, x, lora_a_stacked[slice_idx], - lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, - output_slices[slice_idx]) - offset_left += output_slices[slice_idx] - return output.view_as(org_output) - - @dataclass class LoRAMapping(AdapterMapping): - pass + is_prefill: bool = False class BaseLayerWithLoRA(nn.Module): @@ -154,10 +83,11 @@ def slice_lora_b( ... def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: """Initializes lora matrices.""" ... @@ -177,20 +107,18 @@ def set_lora( def set_mapping( self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], + punica_wrapper: PunicaWrapper, ): - """Sets the mapping indices.""" - ... + self.punica_wrapper: PunicaWrapper = punica_wrapper @classmethod - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: """Returns True if the layer can be replaced by this LoRA layer.""" raise NotImplementedError @@ -259,10 +187,6 @@ def create_lora_weights( self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], self.lora_a_stacked.shape[2], ) - # Lazily initialized. - self.indices: torch.Tensor - self.indices_len: List[int] - self.embeddings_indices: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -285,40 +209,27 @@ def set_lora( if embeddings_tensor is not None: self.embeddings_tensors[ index, :embeddings_tensor.shape[0], :embeddings_tensor. - shape[1]].copy_(embeddings_tensor, non_blocking=True) + shape[1], ].copy_(embeddings_tensor, non_blocking=True) if self.embeddings_slice is not None: # TODO(yard1): Optimize this copy, we don't need to copy # everything, just the modified part embeddings = self.embeddings_tensors.view( self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[1], - self.embeddings_tensors.shape[2] + self.embeddings_tensors.shape[2], )[self.embeddings_slice[0]:self.embeddings_slice[1]] assert self.embeddings_weights is not None self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) - def set_mapping( - self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], - ): - self.indices = base_indices - self.embeddings_indices = embeddings_indices - self.indices_len = indices_len - def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 - embedding_len = self.indices_len[3] - indices = self.embeddings_indices[1][:embedding_len].view_as(x) + embeddings_indices = self.punica_wrapper.embeddings_indices + indices = embeddings_indices[1].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, self.lora_a_stacked_2d, ) - indices = self.embeddings_indices[0][:embedding_len].view_as(x) + indices = embeddings_indices[0].view_as(x) full_output = self.base_layer.forward( x.add_(indices * added_tokens_mask)) @@ -329,22 +240,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if full_lora_a_embeddings.ndim == 3: full_lora_a_embeddings = full_lora_a_embeddings.view( full_lora_a_embeddings.shape[0] * - full_lora_a_embeddings.shape[1], -1) - bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + full_lora_a_embeddings.shape[1], + -1, + ) + + # Embedding layer only need expand op + self.punica_wrapper.add_expand(full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) return full_output.view_as(full_output_org) @classmethod - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: return type(source_layer) is VocabParallelEmbedding class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): """ LoRA on top of ColumnParallelLinear layer. - + LoRA B is sliced for tensor parallelism. """ @@ -357,10 +278,11 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None: self.device = _get_lora_device(self.base_layer) def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: self.lora_config = lora_config self.tp_size = get_tensor_model_parallel_world_size() lora_a_output_size_per_partition = ( @@ -384,10 +306,6 @@ def create_lora_weights( ) self.output_dim = self.lora_b_stacked.shape[2] - # lazily initialized. - self.indices: torch.Tensor - self.indices_len: List[int] - def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 @@ -423,28 +341,11 @@ def set_lora( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( lora_b.T, non_blocking=True) - def set_mapping( - self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], - ): - self.indices = base_indices - self.indices_len = indices_len - def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - _apply_lora( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - ) + self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, + self.lora_b_stacked, 1.0) return output def forward(self, input_): @@ -473,9 +374,13 @@ def forward(self, input_): @classmethod @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: return type(source_layer) is ColumnParallelLinear or ( type(source_layer) is MergedColumnParallelLinear and len(packed_modules_list) == 1) @@ -494,10 +399,11 @@ def __init__(self, base_layer: MergedColumnParallelLinear) -> None: super().__init__(base_layer) def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: self.lora_config = lora_config n_slices = 2 if not (len(self.base_layer.output_sizes) == n_slices @@ -533,8 +439,6 @@ def create_lora_weights( ) for _ in range(n_slices)) self.output_dim = self.lora_b_stacked[0].shape[2] - # Lazily initialized. - self.indices: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[0][index] = 0 @@ -556,7 +460,8 @@ def slice_lora_b( start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size lora_b = [ - lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx] + lora_b[0][:, start_idx:end_idx], + lora_b[1][:, start_idx:end_idx], ] return lora_b @@ -591,34 +496,33 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - _apply_lora_packed_nslice( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - (self.output_dim, self.output_dim), - ) + self.punica_wrapper.add_lora_packed_nslice( + output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, + (self.output_dim, self.output_dim)) return output @classmethod @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is MergedColumnParallelLinear and len( - packed_modules_list) == 2 + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2) class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): """ - ColumnParallelLinear layer that is specifically designed for - qkv_proj. Certain models, such as chtglm3 and baichuan-7b, - only contains a single LoRA within their qkv_proj layer. + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chtglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. - During inference with Tensor Parallel, the weights of lora_b + During inference with Tensor Parallel, the weights of lora_b must be accurately partitioned according to the respective ranks. - + Q slice may have different shape than K and V slices (which both have the same shape). """ @@ -696,10 +600,11 @@ def __init__(self, base_layer: QKVParallelLinear) -> None: super().__init__(base_layer) def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: self.lora_config = lora_config self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() @@ -767,11 +672,15 @@ def create_lora_weights( ), ) - self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size, - self.kv_proj_shard_size) + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) self.packed_indices: Optional[torch.Tensor] = None self.standard_indices: Optional[torch.Tensor] = None # lazily initialized. + self.indices: torch.Tensor self.indices_len: List[int] def reset_lora(self, index: int): @@ -794,15 +703,15 @@ def slice_lora_b( if lora_b[0] is not None: lora_b_q = lora_b[0][:, self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] + (self.q_shard_id + 1), ] if lora_b[1] is not None: lora_b_k = lora_b[1][:, self.kv_proj_shard_size * self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] + (self.kv_shard_id + 1), ] if lora_b[2] is not None: lora_b_v = lora_b[2][:, self.kv_proj_shard_size * self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] + (self.kv_shard_id + 1), ] lora_b = [lora_b_q, lora_b_k, lora_b_v] return lora_b @@ -851,23 +760,23 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - _apply_lora_packed_nslice( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - self.output_slices, - ) + self.punica_wrapper.add_lora_packed_nslice(output, x, + self.lora_a_stacked, + self.lora_b_stacked, 1.0, + self.output_slices) return output @classmethod @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is QKVParallelLinear and len( - packed_modules_list) == 3 + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3) class RowParallelLinearWithLoRA(BaseLayerWithLoRA): @@ -880,10 +789,11 @@ def __init__(self, base_layer: RowParallelLinear) -> None: self.device = _get_lora_device(self.base_layer) def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: self.lora_config = lora_config self.tp_rank = get_tensor_model_parallel_rank() self.lora_a_stacked = torch.zeros( @@ -911,9 +821,6 @@ def create_lora_weights( dtype=lora_config.lora_dtype, device=self.device, ) - # Lazily initialized - self.indices: torch.Tensor - self.indices_len: List[int] def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -950,27 +857,10 @@ def set_lora( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( lora_b.T, non_blocking=True) - def set_mapping( - self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], - ): - self.indices = base_indices - self.indices_len = indices_len - def apply(self, x: torch.Tensor) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x) - _apply_lora( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - ) + self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, + self.lora_b_stacked, 1.0) return output def forward(self, input_): @@ -1013,14 +903,18 @@ def forward(self, input_): @property def weight(self): - return self.base_layer.weight if hasattr( - self.base_layer, "weight") else self.base_layer.qweight + return (self.base_layer.weight if hasattr(self.base_layer, "weight") + else self.base_layer.qweight) @classmethod @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: return type(source_layer) is RowParallelLinear @@ -1121,10 +1015,6 @@ def create_lora_weights( dtype=torch.long) else: self.sharded_to_full_mapping_gpu = None - # Lazily initialized. - self.indices: torch.Tensor - self.indices_len: List[int] - self.indices_padded: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -1150,19 +1040,6 @@ def set_lora( index, :embeddings_tensor.shape[0], :embeddings_tensor. shape[1], ] = embeddings_tensor - def set_mapping( - self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], - ): - self.indices = sampler_indices - self.indices_padded = sampler_indices_padded - self.indices_len = indices_len - def _get_logits( self, hidden_states: torch.Tensor, @@ -1208,38 +1085,37 @@ def _get_logits( out=lora_logits[:-1]) lora_logits[-1] = float("-inf") lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], - ).index_select(0, - self.indices_padded[:self.indices_len[2]]).nan_to_num_( - nan=float("-inf"), - posinf=float("inf"), - neginf=float("-inf"))) + ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) logits[:, self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + - lora_logits.shape[1]] = lora_logits - - _apply_lora( - hidden_states, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[1]], - logits, - ) + lora_logits.shape[1], ] = lora_logits + + # LogitsProcessorWithLoRA always using bgmv + self.punica_wrapper.add_lora_logits(logits, hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, 1.0) # Remove paddings in vocab (if any). logits = logits[:, :self.base_layer.vocab_size] - return logits def forward(self, *args, **kwargs): return type(self.base_layer).forward(self, *args, **kwargs) @classmethod - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: # Special handling for the LogitsProcessor. return False @@ -1255,9 +1131,6 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): def __init__(self, base_layer: RotaryEmbedding) -> None: super().__init__() self.base_layer = base_layer - # Lazily initialized - self.long_lora_indices: torch.Tensor - self.indices_len: List[int] @property def scaling_factors(self): @@ -1273,9 +1146,8 @@ def create_lora_weights( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> None: - scaling_factors = list( - lora_config.long_lora_scaling_factors - ) if lora_config.long_lora_scaling_factors else [] + scaling_factors = (list(lora_config.long_lora_scaling_factors) + if lora_config.long_lora_scaling_factors else []) base_scaling_factor = (self.base_layer.scaling_factor if isinstance( self.base_layer, LinearScalingRotaryEmbedding) else 1.0) scaling_factors = sorted( @@ -1302,18 +1174,6 @@ def set_lora( ): ... - def set_mapping( - self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], - ): - self.long_lora_indices = long_lora_indices - self.indices_len = indices_len - def forward( self, positions: torch.Tensor, @@ -1324,19 +1184,24 @@ def forward( positions, query, key, - offsets=self.long_lora_indices[:self.indices_len[4]]) + offsets=self.punica_wrapper.long_lora_indices, + ) @property def scaling_factor_to_offset(self) -> Dict[float, int]: return self.base_layer.scaling_factor_to_offset @classmethod - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: """Returns True if the layer can be replaced by this LoRA layer.""" - return type(source_layer) is LinearScalingRotaryEmbedding or type( - source_layer) is RotaryEmbedding + return (type(source_layer) is LinearScalingRotaryEmbedding + or type(source_layer) is RotaryEmbedding) def extra_repr(self) -> str: return self.base_layer.extra_repr() diff --git a/vllm/lora/models.py b/vllm/lora/models.py index e1ede7d4d710a..017a1002bb9a7 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,7 +4,7 @@ import os import re from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type import safetensors.torch import torch @@ -21,6 +21,7 @@ LinearScalingRotaryEmbeddingWithLora, LoRAMapping) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.punica import PunicaWrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models.interfaces import SupportsLoRA @@ -43,115 +44,6 @@ class LongContextLoRAContext: offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) -def convert_mapping( - mapping: LoRAMapping, - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional[LongContextLoRAContext] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: - """Converts LoRAMapping to index tensors. - - Args: - mapping: LoRAMapping mapping rows in a batch to LoRA ids. - lora_index_to_id: List mapping LoRA ids to LoRA indices. - max_loras: Maximum number of LoRAs. - vocab_size: Model vocab size. - extra_vocab_size: Extra vocab size each LoRA can have. - long_lora_context: Passed if there are long context lora in a batch. - - Returns: - A tuple of tensors: - base_indices: Tensor of shape [batch_size] mapping batch rows to - LoRA indices. - sampler_indices: Tensor of shape [batch_size] mapping requests to - LoRA indices for sampler. For generation, this will be the - same as base_indicies. For prefill, this will map requests - to LoRA indices. - sampler_indices_padded: Tensor of shape [batch_size] mapping - requests to LoRA indices for sampler with padding. - Same as sampler_indicies, but -1 is replaced with - max_loras. - embeddings_indices: Tensor of shape [2, batch_size] mapping - requests to embedding indices. First row is for embeddings - added by the LoRAs, second row is for the LoRA.lora_a - embeddings. - long_lora_indices: Tensor of shape [batch_size] mapping - requests to RoPE offsets and rot dims for long LoRAs. - None if long context lora doesn't exist. - indices_len: List of lengths of the above tensors. - Used to index into each tensor. It contains length for - (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices). If long_lora doesn't - exist, it only contains first 4 entries. - """ - index_mapping_indices: List[int] = list(mapping.index_mapping).copy() - embedding_indices = index_mapping_indices.copy() - lora_indices = index_mapping_indices.copy() - long_lora_offsets: Optional[torch.Tensor] = None - if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device="cuda", - dtype=torch.long) - prompt_mapping: List[int] = [ - lora_index_to_id.index(x) if x > 0 else -1 - for x in mapping.prompt_mapping - ] - lora_idx = None - for i in range(len(index_mapping_indices)): - # TODO index can be slow. optimize - lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) - embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 - lora_indices[i] = lora_idx - if long_lora_context: - assert long_lora_offsets is not None - lora_offset: int = long_lora_context.offsets_by_lora_id.get( - index_mapping_indices[i], 0) - long_lora_offsets[i] = lora_offset - - indices_list: List[Union[List[int], torch.Tensor]] = [ - index_mapping_indices, lora_indices, embedding_indices - ] - if long_lora_context: - assert long_lora_offsets is not None - indices_list.append(long_lora_offsets) - indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") - prompt_mapping_tensor = torch.tensor(prompt_mapping, - device="cuda", - dtype=torch.long) - embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size) - ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 - base_indices = indices[1] - sampler_indices = prompt_mapping_tensor - sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 - sampler_indices_padded = ( - torch.arange( - 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + - (sampler_indices_padded * len(sampler_indices_padded))) - long_lora_indices = None - long_lora_indices_len: Optional[int] = None - if long_lora_context: - long_lora_indices = indices[3] - long_lora_indices_len = long_lora_indices.shape[-1] - # Contain length of indices tensors. Used to index into each tensor. - indices_len = [ - base_indices.shape[-1], sampler_indices.shape[-1], - sampler_indices_padded.shape[-1], embeddings_indices.shape[-1] - ] - if long_lora_indices_len is not None: - indices_len.append(long_lora_indices_len) - - return (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices, indices_len) - - def get_lora_id(): global _GLOBAL_LORA_ID _GLOBAL_LORA_ID += 1 @@ -422,29 +314,12 @@ def __init__( self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.long_lora_context: Optional[LongContextLoRAContext] = None - self.base_indices = torch.empty(self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") - self.sampler_indices = torch.empty(self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") - self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") - self.embeddings_indices = torch.empty(2, - self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") - self.long_lora_indices = torch.empty(self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") + self.punica_wrapper = PunicaWrapper(max_num_batched_tokens, + max_batches=self.max_num_seqs, + device="cuda") # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. self.scaling_factor_to_offset: Dict[float, int] = {} - # 4 is the number of indicies tensors defined above - # base_indices, sampler_indices, sampler_indices_padded, - # embeddings_indices - self.indices_len: List[Optional[int]] = [None] * 4 super().__init__(model) if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = copy.deepcopy( @@ -536,28 +411,16 @@ def pin_adapter(self, lora_id: int) -> bool: "Pinning is not supported in LoRAModelManager." "Use LRUCacheLoRAModelManager for pinning") # type: ignore - # TODO see if this can be vectorized def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: - (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_offsets_tensor, - indices_len) = convert_mapping(mapping, self.lora_index_to_id, - self.lora_slots + 1, self.vocab_size, - self.lora_config.lora_extra_vocab_size, - self.long_lora_context) - self.base_indices[:base_indices.shape[0]].copy_(base_indices) - self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) - self.embeddings_indices[:embeddings_indices. - shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) - if long_lora_offsets_tensor is not None: - self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor) - else: - self.long_lora_indices.zero_() - # Maintain the reference - self.indices_len[:] = indices_len + # update lora states + self.punica_wrapper.update_metadata( + mapping, + self.lora_index_to_id, + self.lora_slots + 1, + self.vocab_size, + self.lora_config.lora_extra_vocab_size, + self.long_lora_context, + ) def remove_all_adapters(self): """Remove all LoRAModels from the manager.""" @@ -595,10 +458,8 @@ def _create_lora_modules(self): self.model.config)) self.register_module(module_name, new_module) self._register_packed_modules(module_name) - new_module.set_mapping(self.base_indices, self.sampler_indices, - self.sampler_indices_padded, - self.embeddings_indices, - self.long_lora_indices, self.indices_len) + # All lora layers share the same punica_wrapper based on reference. + new_module.set_mapping(self.punica_wrapper) def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): assert isinstance(module, BaseLayerWithLoRA) diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py new file mode 100644 index 0000000000000..dcaf2e3d462cc --- /dev/null +++ b/vllm/lora/ops/bgmv_expand.py @@ -0,0 +1,169 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_N: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's + performance + """ + pid_sn = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + offset_k = tl.arange(0, BLOCK_K) + offset_n = tl.arange(0, BLOCK_N) + if EVEN_K: + tiled_a = tl.load(input_ptr + cur_batch * xm_stride + + offset_k * xk_stride, ) # [BLOCK_K] + else: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + mask=offset_k < K, + other=0, + ) # [BLOCK_K] + # N must be divisible by SPLIT_N + split_n_length = tl.cdiv(N, SPLIT_N) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + # sliding to next row-block + b_ptr = (lora_ptr + l0_stride * lora_index + + pid_sn * split_n_length * lora_k_stride) + c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + for n in range(0, split_n_length, BLOCK_N): + current_n = n + offset_n + current_n_c = tl.max_contiguous(current_n, BLOCK_N) + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] + < K) + c_mask = current_n < split_n_length + tiled_b = tl.load( + b_ptr + current_n_c[:, None] * lora_k_stride + + offset_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + if ADD_INPUTS: + tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out + else: + accumulator = tl.sum(tiled_a * tiled_b, 1) + + tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch, An index of -1 means no lora should be + applied. + batches (int): batch size + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output. + override_config (Optional[Dict[str, int]], optional): Defaults to None. + Triton grid config + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + batches = lora_indices_tensor.size(0) + if override_config: + config = override_config + else: + config = get_lora_op_configs("expand", batches, N) + grid = lambda META: ( + META["SPLIT_N"], + batches, + ) + _bgmv_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py new file mode 100644 index 0000000000000..fa6571074f3ab --- /dev/null +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -0,0 +1,182 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + slice_offset, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_N: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's + performance + """ + pid_sn = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + offset_k = tl.arange(0, BLOCK_K) + offset_n = tl.arange(0, BLOCK_N) + if EVEN_K: + tiled_a = tl.load(input_ptr + cur_batch * xm_stride + + offset_k * xk_stride, ) # [BLOCK_K] + else: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + mask=offset_k < K, + other=0, + ) # [BLOCK_K] + # N must be divisible by SPLIT_N + split_n_length = tl.cdiv(N, SPLIT_N) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + # sliding to next row-block + b_ptr = (lora_ptr + l0_stride * lora_index + + pid_sn * split_n_length * lora_k_stride) + c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + + slice_offset * cn_stride) + + for n in range(0, split_n_length, BLOCK_N): + current_n = n + offset_n + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] + < K) + c_mask = current_n < split_n_length + tiled_b = tl.load( + b_ptr + current_n[:, None] * lora_k_stride + + offset_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + + if ADD_INPUTS: + tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out + else: + accumulator = tl.sum(tiled_a * tiled_b, 1) + + tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'b weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch, An index of -1 means no lora should be + applied. + slice_offst (int): output_tensor's offst + slice_size (int): current output_tensor's size + batches (int): batch size + add_inputs (bool, optional): Defaults to False. + override_config (Optional[Dict[str, int]], optional): Defaults to None. + Triton grid config + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + batches = lora_indices_tensor.size(0) + + if override_config: + config = override_config + else: + config = get_lora_op_configs("expand", batches, N) + + grid = lambda META: ( + META["SPLIT_N"], + batches, + ) + _bgmv_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/bgmv_shrink.py new file mode 100644 index 0000000000000..e69d33078f5aa --- /dev/null +++ b/vllm/lora/ops/bgmv_shrink.py @@ -0,0 +1,150 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's + performance + """ + pid_sk = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + + offset_n = tl.arange(0, BLOCK_N) + offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K + a_ptr = input_ptr + cur_batch * xm_stride + b_ptr = lora_ptr + l0_stride * lora_index + accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32) + for k in range(0, K, BLOCK_K * SPLIT_K): + current_k = k + offset_k + current_k_c = tl.max_contiguous(current_k, BLOCK_K) + tiled_a = tl.load( + a_ptr + current_k_c, + mask=current_k < K, + other=0.0, + ) # [BLOCK_K] + b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K) + + tiled_b = tl.load( + b_ptr + offset_n[:, None] * lora_k_stride + + current_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + + accumulator += tl.sum(tiled_a * tiled_b, 1) + accumulator *= scaling + offset_cn = tl.arange(0, BLOCK_N) + c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride + c_mask = offset_cn < N + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + scaling (float): Scaling factor. + override_config (Optional[Dict[str, int]], optional): Defaults to None. + Triton grid config + """ + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + batches = lora_indices_tensor.size(0) + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + BLOCK_N = triton.next_power_of_2(N) + if override_config: + config = override_config + else: + # First try to load optimal config from the file + config = get_lora_op_configs("bgmv_shrink", batches, K) + + grid = lambda META: ( + META["SPLIT_K"], + batches, + ) + _bgmv_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_N=BLOCK_N, + **config, + ) + return diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py new file mode 100644 index 0000000000000..4590495469096 --- /dev/null +++ b/vllm/lora/ops/sgmv_expand.py @@ -0,0 +1,192 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from vllm.triton_utils import libentry + + +@libentry() +@triton.jit +def _sgmv_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + The sgmv's expand triton kernel is based on GroupGEMM. + """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride, ) + b_ptr = (lora_ptr + l0_stride * lora_index + + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < N) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + add_inputs: bool = False, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output. + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + _sgmv_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + return diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py new file mode 100644 index 0000000000000..ff3bcda071b80 --- /dev/null +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -0,0 +1,205 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from vllm.triton_utils import libentry + + +@libentry() +@triton.jit +def _sgmv_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + slice_offset, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + + Similar to the 'sgmv_expand' operator, but with an added parameter + 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator + might be that in the future, we could implement a fusion operator to + achieve the current functionality instead of having to call it multiple + times. + """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride, ) + b_ptr = (lora_ptr + l0_stride * lora_index + + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < + (slice_offset + N)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +): + """_summary_ + + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + slice_offst (int): output_tensor's offst + slice_size (int): current output_tensor's size + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output.. + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + _sgmv_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + return diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py new file mode 100644 index 0000000000000..8ab049989abef --- /dev/null +++ b/vllm/lora/ops/sgmv_shrink.py @@ -0,0 +1,189 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from vllm.triton_utils import libentry + + +@libentry() +@triton.jit +def _sgmv_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + scaling, + xm_stride, # hidden_size + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + """ + The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. + The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally, + introducing SPLIT-K can improve performance + """ + pid = tl.program_id(axis=0) + pid_sk = tl.program_id(axis=1) + cur_batch = tl.program_id(axis=2) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride) + b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride + + offset_k[:, None] * lora_n_stride) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < k_remaining, + other=0.0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < k_remaining, + other=0.0) + accumulator += tl.dot(tiled_a, tiled_b) + + a_ptr += BLOCK_K * SPLIT_K * xk_stride + b_ptr += BLOCK_K * SPLIT_K * lora_n_stride + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < N) + accumulator *= scaling + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + scaling: float, +): + """ + + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + scaling (float): Scaling factor. + """ + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_K = 32 + SPLIT_K = 8 + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + SPLIT_K, + batches, + ) + + _sgmv_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + ) + return diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/utils.py new file mode 100644 index 0000000000000..7c3e27313ad97 --- /dev/null +++ b/vllm/lora/ops/utils.py @@ -0,0 +1,46 @@ +import functools +from typing import Dict + + +@functools.lru_cache +def _get_op_configs(op_type: str, batch: int, hidden_size: int): + # TODO: add optimal configurations + return None + + +def _check_divisibility(hidden_size: int): + # The bgmv_expand kernel requires that the hidden_size be divisible by + # the number below. + divisibility = [2, 4, 8, 16, 32, 64] + divisibility.sort(reverse=True) + for div in divisibility: + if hidden_size % div == 0: + return div + # hidden_size is an odd number + return 1 + + +def _get_default_config(op_type: str, batch: int, hidden_size: int): + if op_type == "expand": + return { + "BLOCK_N": 256, + "SPLIT_N": _check_divisibility(hidden_size), + "num_warps": 8 + } + else: + return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8} + + +def get_lora_op_configs(op_type: str, batch: int, + hidden_size: int) -> Dict[str, int]: + """Inspired by `fused_moe_kernel` + The return value will be a dictionary mapping an irregular grid of batch + sizes and hidden_size to configurations of the bgmv-related kernel. + NOTE: It currently only supports the default configuration. We plan to + generate optimal configurations for different hardware in the future using + scripts similar to `benchmark_moe.py`. + """ + config = _get_op_configs(op_type, batch, hidden_size) + if not config: + config = _get_default_config(op_type, batch, hidden_size) + return config diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 64f87a4b2c69d..6d5c834299961 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -1,207 +1,604 @@ -# Based on code from https://github.com/punica-ai/punica +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" -from typing import Optional +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union import torch -from vllm import _custom_ops as ops -from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON +if HAS_TRITON: + from vllm.lora.ops.bgmv_expand import bgmv_expand + from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice + from vllm.lora.ops.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.sgmv_expand import sgmv_expand + from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.sgmv_shrink import sgmv_shrink -def _check_punica_support(): - if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"): - return +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext - if current_platform.get_device_capability() < (8, 0): - raise ImportError( - "punica LoRA kernels require compute capability >= 8.0") - else: - raise ImportError( - "punica LoRA kernels could not be imported. If you built vLLM " - "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var " - "was set.") - - -def bgmv( - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, - scale: float, -): - """ - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight - matrices. - indicies: Shape: `[B]`. Indices of the weight matrices. - layer_idx: Layer index of the weight matrices. - scale: Scaling factor. +def compute_meta( + token_lora_tensor: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]: + """ + Get the information required for the sgmv kernel. With the features: + 1. If consecutive requests in the batch use the same LoRA, this function + will combine them into a single request, improving sgmv kernel inference + performance. + 2. At the beginning of each prefill stage inference, recalculations are + needed based on the input, but only once. """ - _check_punica_support() - ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( + token_lora_tensor, return_counts=True) + cum_result = torch.cumsum(seq_length_tensor, dim=0) + b_seq_start_tensor = torch.zeros_like(seq_length_tensor) + b_seq_start_tensor[1:].copy_(cum_result[:-1]) + max_length = seq_length_tensor.max().item() + batch_size = lora_indices_tensor.size(0) + no_lora = False + # -1 means no lora should be applied. Use `no_lora` to determine whether + # the current step requires LoRA. If LoRA is not needed, the prefill stage + # does not need to launch the triton kernel, which can improve performance + if batch_size == 1 and lora_indices_tensor == -1: + no_lora = True + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, no_lora) -def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, indicies: torch.LongTensor, - layer_idx: int, scale: float, y_offset: int, - y_slice_size: int): - """ - Same as `bgmv` but you can operate on slices of y. - Pass whole y, define y_offset and y_slice_size. - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) +# TODO see if this can be vectorized +def convert_mapping( + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: + """Converts LoRAMapping to index tensors. Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of - all of the transposed LoRA matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. - y_offset: Offset to apply to the starting column of y. - y_slice_size: Size of the y column slice. + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. + indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). """ - _check_punica_support() - - ops.dispatch_bgmv_low_level( - y, - x, - w_t_all, - indicies, - layer_idx, - scale, - x.size(1), - y_slice_size, - y_offset, - ) + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device="cuda", + dtype=torch.long) + prompt_mapping: List[int] = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") + prompt_mapping_tensor = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.arange( + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded)) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. + indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) -def add_lora(y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, - scale: float, - *, - buffer: Optional[torch.Tensor] = None): - """ - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed - LoRA A matrices. - wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed - LoRA B matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. - buffer: Optional. Shape: `[B, R]`. Temporary buffer. + +class PunicaWrapper: """ - _check_punica_support() - - r = wb_t_all.size(-1) - if buffer is None: - # We set the buffer to be float32 by default to avoid - # numerical inaccuracies that would otherwise happen - # due to downcasting. - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) - ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale) - - -def add_lora_slice(y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, - scale: float, - y_offset: int, - y_slice_size: int, - *, - buffer: Optional[torch.Tensor] = None): + PunicaWrapper is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica kernel. """ - Same as `add_lora` but you can operate on slices of y. - Pass whole y, define y_offset and y_slice_size. - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: str): + self._token_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._embeddings_indices = torch.empty(2, + max_num_batched_tokens, + dtype=torch.long, + device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed - LoRA A matrices. - wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed - LoRA B matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. - y_offset: Offset to apply to the starting column of y. - y_slice_size: Size of the y column slice. - """ - _check_punica_support() - - r = wb_t_all.size(-1) - if buffer is None: - # We set the buffer to be float32 by default to avoid - # numerical inaccuracies that would otherwise happen - # due to downcasting. - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - ops.dispatch_bgmv_low_level( - buffer, - x, - wa_t_all, - indicies, - layer_idx, - 1.0, - x.size(1), - buffer.size(1), - 0, - ) - ops.dispatch_bgmv_low_level( - y, - buffer, - wb_t_all, - indicies, - layer_idx, - scale, - buffer.size(1), - y_slice_size, - y_offset, - ) + # 5 is the number of indicies tensors. + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices,long_lora_indices + self.indices_len: List[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._seq_lengths = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._lora_indices_per_batch = torch.empty(max_batches, + dtype=torch.long, + device=device) + self.max_length: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metada(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. + shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, no_lora) = compute_meta(token_lora_tensor) + + self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + b_seq_start_tensor) + self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.no_lora = no_lora + + @property + def prefill_metadata( + self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions + 2. seq_lengths: Tensor of sequence lengths + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: batch size after clustering identical lora indices + 5. max_length: The maximum sequence length in the batch + """ + return (self._seq_start_locs[:self.batch_size], + self._seq_lengths[:self.batch_size], + self._lora_indices_per_batch[:self.batch_size], + self.batch_size, self.max_length) + + @property + def token_lora_indices(self) -> torch.Tensor: + """ + This property provides the lora indices corresponding to each token + in the batch. An index of -1 means no lora should be applied. + """ + token_lora_len = self.indices_len[0] + return self._token_lora_indices[:token_lora_len] + + @property + def sampler_indices(self) -> torch.Tensor: + """ + This property is used to access the lora indices specifically for + LogitsProcessorWithLoRA + """ + sampler_indices_len = self.indices_len[1] + return self._sampler_indices[:sampler_indices_len] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices + """ + indices_padded_len = self.indices_len[2] + return self._sampler_indices_padded[:indices_padded_len] + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA + """ + embeddings_indices_len = self.indices_len[3] + return self._embeddings_indices[:, :embeddings_indices_len] + + @property + def long_lora_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for long context + lora, specifically for LinearScalingRotaryEmbeddingWithLora + """ + long_lora_len = self.indices_len[4] + return self._long_lora_indices[:long_lora_len] + + def shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_input, + ) + + def expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) + + def expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_input, + ) + + def expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_input) + + def add_shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the shrink_decode function + should be called. + """ + shrink_fun: Callable = (self.shrink_prefill + if self.is_prefill else self.shrink_decode) + shrink_fun(y, x, w_t_all, scale) + + def add_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool = True, + ): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'b. + When `is_prefill` is true, it indicates that it is currently the + prefill stage, and the `expand_prefill` function should be called. + Otherwise, it is the decode stage, and the expand_decode function + should be called. + """ + + expand_fun: Callable = (self.expand_prefill + if self.is_prefill else self.expand_decode) + expand_fun(y, x, w_t_all, add_input) + + def add_expand_slice(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True): + """ + Similar to `add_expand` + """ + + expand_slice_fun: Callable = (self.expand_slice_prefill + if self.is_prefill else + self.expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + + def add_lora(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale: float, + y_offset: Optional[int] = None, + y_slice_size: Optional[int] = None, + *, + buffer: Optional[torch.Tensor] = None) -> None: + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + wa_t_all (torch.Tensor): lora_a's weight + wb_t_all (torch.Tensor): lora_b's weight + scale (float): Scaling factor. + y_offset (Optional[int], optional): Offset to apply to the starting + column of y. + y_slice_size (Optional[int], optional): Size of the y column slice.. + buffer (Optional[torch.Tensor], optional): Defaults to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + self.add_shrink(buffer, x, wa_t_all, scale) + if y_offset is None and y_slice_size is None: + self.add_expand(y, buffer, wb_t_all, add_input=True) + else: + self.add_expand_slice(y, + buffer, + wb_t_all, + y_offset, + y_slice_size, + add_input=True) + y = y.view_as(y_org) + + def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + scale: float, + output_slices: Tuple[int, ...]) -> None: + """ + Applies lora to each input. Similar to add_lora, This method is + used for layers that are composed of multiple sublayers + (slices) packed together. + """ + y_org = y + x = x.view(-1, x.shape[-1]) + y = y.view(-1, y.shape[-1]) + offset_left = 0 + # TODO fuse these kernels + for slice_idx in range(len(output_slices)): + self.add_lora(y, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], scale, offset_left, + output_slices[slice_idx]) + offset_left += output_slices[slice_idx] + + y = y.view_as(y_org) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None) -> None: + """ + LogitsProcessorWithLoRA always using bgmv + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 09843e5d1f30b..3f57c22e1f2e4 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -1,6 +1,11 @@ -from vllm.triton_utils.custom_cache_manager import ( - maybe_set_triton_cache_manager) +from vllm.triton_utils.importing import HAS_TRITON -__all__ = [ - "maybe_set_triton_cache_manager", -] +__all__ = ["HAS_TRITON"] + +if HAS_TRITON: + + from vllm.triton_utils.custom_cache_manager import ( + maybe_set_triton_cache_manager) + from vllm.triton_utils.libentry import libentry + + __all__ += ["maybe_set_triton_cache_manager", "libentry"] diff --git a/vllm/triton_utils/libentry.py b/vllm/triton_utils/libentry.py new file mode 100644 index 0000000000000..ae00af44a048a --- /dev/null +++ b/vllm/triton_utils/libentry.py @@ -0,0 +1,167 @@ +# Copied From https://github.com/FlagOpen/FlagGems + +import inspect + +import triton + + +class LibEntry(triton.KernelInterface): + + def __init__( + self, + fn, + ): + self.fn = fn + self.arg_names = fn.arg_names + self.divisibility = 16 + self.kernel_cache = dict() + fn = self.fn + while not isinstance(fn, triton.runtime.JITFunction): + fn = fn.fn + self.jit_function: triton.runtime.JITFunction = fn + self.specialize_indices = [ + p.num for p in self.jit_function.params + if not p.is_constexpr and not p.do_not_specialize + ] + self.do_not_specialize_indices = [ + p.num for p in self.jit_function.params + if not p.is_constexpr and p.do_not_specialize + ] + + def key(self, spec_args, dns_args, const_args): + spec_key = [(arg.dtype, arg.data_ptr() % + self.divisibility == 0) if hasattr(arg, "data_ptr") else + (type(arg), arg) for arg in spec_args] + dns_key = [ + arg.dtype if hasattr( + arg, "data_ptr") else type(arg) if not isinstance(arg, int) + else "i32" if -(2**31) <= arg and arg <= 2**31 - + 1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64" + for arg in dns_args + ] + # const args passed by position + return tuple(spec_key + dns_key + const_args) + + def run(self, *args, **kwargs): + grid = kwargs["grid"] + # collect all the arguments + spec_args = [] # specialize arguments + dns_args = [] # do not specialize arguments + const_args = [] # constexpr arguments + k_args = [] # kernel arguments + for i, arg in enumerate(args): + if i in self.specialize_indices: + k_args.append(arg) + spec_args.append(arg) + elif i in self.do_not_specialize_indices: + k_args.append(arg) + dns_args.append(arg) + else: + const_args.append(arg) + for p in self.jit_function.params[len(args):]: + if p.name in kwargs: + val = kwargs[p.name] + elif p.default is inspect._empty: + continue + else: + val = p.default + + if p.is_constexpr: + const_args.append(val) + elif p.do_not_specialize: + dns_args.append(val) + k_args.append(val) + else: + spec_args.append(val) + k_args.append(val) + + entry_key = self.key(spec_args, dns_args, const_args) + + if entry_key not in self.kernel_cache: + # compile the kernel also completes the related computations + kernel = self.fn.run(*args, **kwargs) + fn = self.fn + # collect constexpr arguments for grid computation + constexprs = {} + while not isinstance(fn, triton.runtime.JITFunction): + if isinstance(fn, triton.runtime.Autotuner): + config = fn.best_config + constexprs["num_warps"] = config.num_warps + constexprs["num_stages"] = config.num_stages + constexprs["num_ctas"] = config.num_ctas + constexprs = {**constexprs, **config.kwargs} + elif isinstance(fn, triton.runtime.Heuristics): + for v, heur in fn.values.items(): + constexprs[v] = heur({ + **dict(zip(fn.arg_names, args)), + **kwargs, + **constexprs, + }) + else: + raise RuntimeError("Invalid Runtime Function") + fn = fn.fn + # In vLLM, certain kernels like fused_moe_kernel get the + # best_config(as kwargs) from a configuration json file, rather + # than using Autotuner & Heuristics. Therefore, all their constexprs + # (tl.constexpr) are assigned values through the following loop. + for p in self.jit_function.params: + if p.is_constexpr and p.name not in constexprs: + constexprs[p.name] = p.default #default=inspect._empty + self.kernel_cache[entry_key] = (kernel, constexprs) + else: + # load kernel from cache directly + kernel, constexprs = self.kernel_cache[entry_key] + + if callable(grid): + # collect all arguments to the grid fn,ie: + # 1. args, + # 2. kwargs, + # 3. all all other captured arguments in CompiledKernel from + # Autotunner & Heuristics when kwargs & captured args conflict, + # captured args have higher priority + # 4. We must filter out captured args with default value firstly + constexprs = { + k: v + for k, v in constexprs.items() if v is not inspect._empty + } + meta = { + **dict(zip(self.arg_names, args)), + **kwargs, + **constexprs, + } + grid = grid(meta) + if isinstance(grid, tuple): + grid = grid + (1, 1) + elif isinstance(grid, list): + grid = grid + [1, 1] + kernel[grid[0:3]](*k_args) + # maintaining the same return type as the JITFunction.run + return kernel + + +def libentry(): + """ + Decorator for triton library entries. + Motivation: + The runtime overhead of Triton kernels is the reason for the lower + performance of small kernels, particularly evident with smaller models. + Using this decorator can reduce Triton runtime overhead. + How: + The `run` function of JITFunction needs to accomplish: + - Parameter binding using inspect + - KernelArg type wrapping + - Cache key calculation + When dealing with small size, these steps can become bottlenecks in + Triton runtime. Libentry simplifies these steps to reduce runtime + overhead, thereby improving the runtime expenses of small kernels. + NOTE: + When Triton is upgraded to version 3.0.0, libentry can be removed, + see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245 + + + """ + + def decorator(fn): + return LibEntry(fn) + + return decorator diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e63be184af16a..26aa2a11382ec 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -539,9 +539,9 @@ def build(self) -> ModelInputForGPU: for inter_data in self.inter_data_list ]) lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) + **dict(index_mapping=lora_index_mapping, + prompt_mapping=lora_prompt_mapping, + is_prefill=not self.decode_only)) # Prompt adapter data. prompt_adapter_requests: Set[PromptAdapterRequest] = set() @@ -1114,9 +1114,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: if self.lora_config: lora_mapping = LoRAMapping( - [0] * batch_size, - [0] * batch_size, - ) + **dict(index_mapping=[0] * batch_size, + prompt_mapping=[0] * batch_size, + is_prefill=False)) self.set_active_loras(set(), lora_mapping) if self.prompt_adapter_config: