From d292f5c1b044bfed74faa5aebe35364f3603e2af Mon Sep 17 00:00:00 2001
From: laiwen <80147768+laiwenzh@users.noreply.github.com>
Date: Mon, 16 Dec 2024 16:14:24 +0800
Subject: [PATCH] update docs, build scripts, dockerfiles, workflow yaml, and
some bugfixes (#45)
---
.github/workflows/release_packages.yml | 198 ++--
.gitignore | 28 +-
CMakeLists.txt | 77 +-
README.md | 120 +-
README_CN.md | 110 +-
build.sh | 9 +-
cmake/FindNCCL.cmake | 27 +-
cmake/flash-attention.cmake | 2 +-
cmake/install.cmake | 52 +-
csrc/service/CMakeLists.txt | 3 +-
docs/sphinx/devel/source_code_build_en.rst | 92 +-
.../sphinx/get_started/env_var_options_en.rst | 2 +-
docs/sphinx/get_started/install_en.md | 12 +-
docs/sphinx/llm/prefix_caching.rst | 48 +-
docs/sphinx/llm/runtime_config.rst | 5 +-
examples/cpp/0_basic/example_qwen.cpp | 4 +-
examples/cpp/1_apiserver/apiserver.cpp | 4 +-
examples/cpp/CMakeLists.txt | 11 +-
python/CMakeLists.txt | 4 +-
scripts/docker/dev_arm_centos8.Dockerfile | 9 +
scripts/docker/dev_cuda_124.Dockerfile | 40 +-
scripts/docker/dev_x86_centos7.Dockerfile | 30 +-
scripts/docker/test_cuda_ubuntu.Dockerfile | 45 +
scripts/release/cpp_build_cuda.sh | 11 +
scripts/release/python_manylinux_build.sh | 13 +-
.../release/python_manylinux_build_cuda.sh | 18 +-
.../conv/collective/builders/sm90_common.inl | 96 ++
.../collective/builders/sm90_gmma_builder.inl | 257 +++++
.../collective/builders/sm90_builder.inl | 797 +++++++++++++
.../collective/builders/sm90_common.inl | 80 ++
.../gemm/collective/builders/sm90_common.inl | 364 ++++++
.../collective/builders/sm90_gmma_builder.inl | 1003 +++++++++++++++++
.../building_in_windows_with_visual_studio.md | 90 ++
.../building_with_clang_as_host_compiler.md | 59 +
34 files changed, 3207 insertions(+), 513 deletions(-)
create mode 100644 scripts/docker/test_cuda_ubuntu.Dockerfile
create mode 100644 scripts/release/cpp_build_cuda.sh
create mode 100644 span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_common.inl
create mode 100644 span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl
create mode 100644 span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl
create mode 100644 span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_common.inl
create mode 100644 span-attention/thirdparty/cutlass/include/cutlass/gemm/collective/builders/sm90_common.inl
create mode 100644 span-attention/thirdparty/cutlass/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl
create mode 100644 span-attention/thirdparty/cutlass/media/docs/build/building_in_windows_with_visual_studio.md
create mode 100644 span-attention/thirdparty/cutlass/media/docs/build/building_with_clang_as_host_compiler.md
diff --git a/.github/workflows/release_packages.yml b/.github/workflows/release_packages.yml
index ba97d161..54f7d856 100644
--- a/.github/workflows/release_packages.yml
+++ b/.github/workflows/release_packages.yml
@@ -1,57 +1,28 @@
name: Release Packages
on:
- push:
- tags:
- - 'v[0-9]+.[0-9]+.[0-9]+'
+ push:
+ tags:
+ - 'v[0-9]+.[0-9]+.[0-9]+'
+ workflow_dispatch:
# Needed to create release and upload assets
permissions:
contents: write
jobs:
- build-deb:
- runs-on: [self-hosted, Linux, X64]
- container:
- image: registry-1.docker.io/dashinfer/dev-ubuntu-22.04-x86:v1
- defaults:
- run:
- shell: bash -l {0}
- steps:
- - name: Check out code
- uses: actions/checkout@v3
-
- - name: Pull LFS
- run: |
- git lfs pull
-
- - name: Build deb package
- run: |
- git fetch --tags
- TAG_NAME=$(git describe --tags $(git rev-list --tags --max-count=1))
- VERSION_NUMBER=$(echo "$TAG_NAME" | sed 's/^v//')
- source activate ds_py
- AS_RELEASE_VERSION=$VERSION_NUMBER \
- AS_PLATFORM="x86" \
- AS_BUILD_PACKAGE=ON \
- bash build.sh
-
- - name: Upload deb package
- uses: actions/upload-artifact@v3
- with:
- name: dashinfer-deb
- path: build/*.deb
-
- build-rpm:
+ build-tgz:
strategy:
matrix:
arch: [X64, ARM64]
- image: ["dev-centos7-x86:v1", "dev-alinux-arm:v1"]
+ image: ["dev-centos7-x86:v2", "dev-centos7-cu124:v1", "dev-centos8-arm:v2"]
exclude:
- arch: X64
- image: "dev-alinux-arm:v1"
+ image: "dev-centos8-arm:v2"
+ - arch: ARM64
+ image: "dev-centos7-x86:v2"
- arch: ARM64
- image: "dev-centos7-x86:v1"
+ image: "dev-centos7-cu124:v1"
runs-on: [self-hosted, Linux, "${{ matrix.arch }}"]
container:
image: registry-1.docker.io/dashinfer/${{ matrix.image }}
@@ -67,40 +38,61 @@ jobs:
uses: actions/checkout@v3
with:
lfs: true
-
+
- name: Pull LFS
run: |
+ git lfs install --force
git lfs pull
-
- - name: Build rpm package
+
+ - name: Init submodule
+ run: |
+ git submodule init
+ git submodule update
+
+ - name: Build tgz package
run: |
git fetch --tags
TAG_NAME=$(git describe --tags $(git rev-list --tags --max-count=1))
VERSION_NUMBER=$(echo "$TAG_NAME" | sed 's/^v//')
- source /opt/rh/devtoolset-7/enable
- source activate ds_py
- AS_RELEASE_VERSION=$VERSION_NUMBER \
- AS_PLATFORM=$( [[ "${{ matrix.arch }}" = "X64" ]] && echo "x86" || echo "armclang" ) \
- AS_BUILD_PACKAGE=ON \
- bash build.sh
+ source /root/.bashrc
+
+ export AS_RELEASE_VERSION=$VERSION_NUMBER
+ export AS_BUILD_PACKAGE=ON
+
+ if command -v nvcc &> /dev/null
+ then
+ export AS_PLATFORM="cuda"
+ export AS_CUDA_SM="'70;75;80;86;89;90a'"
+ bash scripts/release/cpp_build_cuda.sh
+ else
+ # export ENABLE_MULTINUMA="ON"
+ if [[ "${{ matrix.arch }}" == "ARM64" ]]; then
+ export AS_PLATFORM="armclang"
+ else
+ export AS_PLATFORM="x86"
+ fi
+ bash build.sh
+ fi
- - name: Upload rpm package
+ - name: Upload tgz package
uses: actions/upload-artifact@v3
with:
- name: dashinfer-rpm-${{ matrix.arch }}
- path: build/*.rpm
-
+ name: dashinfer-tgz-${{ matrix.arch }}
+ path: build/*.tar.gz
+
build-wheels:
strategy:
matrix:
arch: [X64, ARM64]
- image: ["dev-manylinux-x86:v1", "dev-manylinux-arm:v1"]
+ image: ["dev-centos7-x86:v2", "dev-centos7-cu124:v1", "dev-centos8-arm:v2"]
exclude:
- arch: X64
- image: "dev-manylinux-arm:v1"
+ image: "dev-centos8-arm:v2"
- arch: ARM64
- image: "dev-manylinux-x86:v1"
+ image: "dev-centos7-x86:v2"
+ - arch: ARM64
+ image: "dev-centos7-cu124:v1"
runs-on: [self-hosted, Linux, "${{ matrix.arch }}"]
container:
image: registry-1.docker.io/dashinfer/${{ matrix.image }}
@@ -114,12 +106,31 @@ jobs:
with:
lfs: true
+ - name: Pull LFS
+ run: |
+ git lfs install --force
+ git lfs pull
+
+ - name: Init submodule
+ run: |
+ git submodule init
+ git submodule update
+
- name: Build manylinux wheels
run: |
git fetch --tags
TAG_NAME=$(git describe --tags $(git rev-list --tags --max-count=1))
+ source /root/.bashrc
VERSION_NUMBER=$(echo "$TAG_NAME" | sed 's/^v//')
- AS_RELEASE_VERSION=$VERSION_NUMBER bash scripts/release/python_manylinux_build.sh
+ export AS_RELEASE_VERSION=$VERSION_NUMBER
+
+ if command -v nvcc &> /dev/null
+ then
+ export AS_CUDA_SM="'70;75;80;86;89;90a'"
+ bash scripts/release/python_manylinux_build_cuda.sh
+ else
+ bash scripts/release/python_manylinux_build.sh
+ fi
- name: Upload wheels
uses: actions/upload-artifact@v3
@@ -127,56 +138,50 @@ jobs:
name: python-manylinux-wheels-${{ matrix.arch }}
path: python/wheelhouse/*-manylinux*.whl
- test:
- strategy:
- matrix:
- arch: [X64, ARM64]
- image: ["test-ubuntu-x86:v1", "test-centos-arm:v1"]
- exclude:
- - arch: X64
- image: "test-centos-arm:v1"
- - arch: ARM64
- image: "test-ubuntu-x86:v1"
- runs-on: [self-hosted, Linux, "${{ matrix.arch }}"]
- container:
- image: registry-1.docker.io/dashinfer/${{ matrix.image }}
- volumes:
- - /mnt/data0/models/modelscope:/github/home/.cache/modelscope
- options: "--ipc=host --cap-add SYS_NICE --cap-add SYS_PTRACE"
- needs: build-wheels
- steps:
- - name: Check out code
- uses: actions/checkout@v3
+ # test:
+ # strategy:
+ # matrix:
+ # arch: [X64, ARM64]
+ # image: ["test-ubuntu-x86:v1", "test-centos-arm:v1"]
+ # exclude:
+ # - arch: X64
+ # image: "test-centos-arm:v1"
+ # - arch: ARM64
+ # image: "test-ubuntu-x86:v1"
+ # runs-on: [self-hosted, Linux, "${{ matrix.arch }}"]
+ # container:
+ # image: registry-1.docker.io/dashinfer/${{ matrix.image }}
+ # volumes:
+ # - /mnt/data0/models/modelscope:/github/home/.cache/modelscope
+ # options: "--ipc=host --cap-add SYS_NICE --cap-add SYS_PTRACE"
+ # needs: build-wheels
+ # steps:
+ # - name: Check out code
+ # uses: actions/checkout@v3
- - name: Download wheels
- uses: actions/download-artifact@v3
- with:
- name: python-manylinux-wheels-${{ matrix.arch }}
- path: python/wheelhouse
+ # - name: Download wheels
+ # uses: actions/download-artifact@v3
+ # with:
+ # name: python-manylinux-wheels-${{ matrix.arch }}
+ # path: python/wheelhouse
- - name: Test manylinux wheels
- run: |
- TAG_NAME=$(git describe --tags $(git rev-list --tags --max-count=1))
- VERSION_NUMBER=$(echo "$TAG_NAME" | sed 's/^v//')
- AS_RELEASE_VERSION=$VERSION_NUMBER bash scripts/release/python_manylinux_test.sh
+ # - name: Test manylinux wheels
+ # run: |
+ # TAG_NAME=$(git describe --tags $(git rev-list --tags --max-count=1))
+ # VERSION_NUMBER=$(echo "$TAG_NAME" | sed 's/^v//')
+ # AS_RELEASE_VERSION=$VERSION_NUMBER bash scripts/release/python_manylinux_test.sh
publish:
runs-on: [self-hosted, Linux]
- needs: [build-deb, build-rpm, test]
+ needs: [build-tgz, build-wheels]
strategy:
matrix:
arch: [X64, ARM64]
steps:
- - name: Download deb packages
+ - name: Download tgz packages
uses: actions/download-artifact@v3
with:
- name: dashinfer-deb
- path: release/
-
- - name: Download rpm packages
- uses: actions/download-artifact@v3
- with:
- name: dashinfer-rpm-${{ matrix.arch }}
+ name: dashinfer-tgz-${{ matrix.arch }}
path: release/
- name: Download python wheels
@@ -189,6 +194,3 @@ jobs:
uses: softprops/action-gh-release@v2
with:
files: release/*
-
-
-
diff --git a/.gitignore b/.gitignore
index 21cb64f2..91d613ee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,29 @@
+tests/cpp/model/testcase
+tests/cpp/operator/testcase
+tests/python/custom_model
+tests/testcase
build/
python/build/
+ossutil_output/
+__pycache__/
+.ccls
+*.qdrep
+*.qdstrm
+*.h5
+.ccls-cache/
+*.log
+compile_commands.json
python/dist/
-python/dashinfer.egg-info/
-python/dashinfer.egg-info
+python/pyhie.egg-info/
+python/pyhie_allspark.egg-info
+*.ascache
+*.lock
third_party/from_source/*.o
-__pycache__/
+third_party/from_source/openssl/*
+.idea/
+.vscode/
+*.nsys-rep
+log*
+*.csv
+#*.sh
+*.as*
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5369d32b..e330c148 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -9,13 +9,18 @@ endif()
string(REGEX REPLACE "-rc[0-9]+" "" STRIPED_VERSION_STRING ${project_version_in_env})
set(project_version_in_env ${STRIPED_VERSION_STRING})
-message("Build AllSpark with version:${project_version_in_env}")
+message("Build DashInfer with version: ${project_version_in_env}")
project(DashInfer LANGUAGES C CXX VERSION ${project_version_in_env})
include(GNUInstallDirs)
set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}-${PROJECT_VERSION} CACHE STRING "Force modify install dir" FORCE)
-message(STATUS "CMAKE_INSTALL_PREFIX:${CMAKE_INSTALL_PREFIX} CPACK_PACKAGE_DEVICE_NAME:${CPACK_PACKAGE_DEVICE_NAME}")
+message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
+if (BUILD_PYTHON)
+  # building the manylinux package needs this setting to find the local libflash-attn.so
+ set(CMAKE_INSTALL_RPATH "$ORIGIN")
+ set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
+endif()
if (NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build,
@@ -63,7 +68,7 @@ option(ENABLE_CUSPARSELT "build with CUSPARSELT lib" OFF)
option(BUILD_UTEST "build with unit test" ON)
option(BUILD_EXAMPLE "build with examples" ON)
option(BUILD_PYTHON "build with python api" ON)
-option(PACKAGE_RPM "package with rpm " ON)
+option(BUILD_PACKAGE "build cpp package" OFF)
option(MEM_CHECK "check memory" OFF)
option(LOCK_CHECK "check deadlock" OFF)
option(ALWAYS_READ_LOAD_MODEL "load and parse model via every read" OFF)
@@ -212,42 +217,44 @@ if (BUILD_PYTHON)
add_subdirectory(python)
endif()
-
-
-if (PACKAGE_RPM)
-set(CPACK_SYSTEM_NAME "alios7")
-if(CONFIG_HOST_CPU_TYPE STREQUAL "ARM")
- set(CPACK_SYSTEM_ARCHITECTURE "aarch64")
-else()
- set(CPACK_SYSTEM_ARCHITECTURE "x86_64")
-endif()
-
-if (ENABLE_CUDA)
- if(ENABLE_NV_STATIC_LIB)
- set(CPACK_PACKAGE_DEVICE_NAME "cuda-${CUDA_VERSION}-static")
+if (BUILD_PACKAGE)
+ # config system arch
+ if(CONFIG_HOST_CPU_TYPE STREQUAL "ARM")
+ set(CPACK_SYSTEM_ARCHITECTURE "aarch64")
else()
- set(CPACK_PACKAGE_DEVICE_NAME "cuda-${CUDA_VERSION}-shared")
+ set(CPACK_SYSTEM_ARCHITECTURE "x86_64")
endif()
-else()
- set(CPACK_PACKAGE_DEVICE_NAME "cpu")
-endif()
-
-set(CPACK_PACKAGE_NAME "DashInfer")
-set(CPACK_PACKAGE_VERSION ${project_version_in_env})
-set(CPACK_PACKAGE_VENDOR "Alibaba Tongyi")
-set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "DashInfer AllSpark is a LLM inference engine.")
-set(CPACK_PACKAGE_VERSION_MAJOR ${PROJECT_VERSION_MAJOR})
-set(CPACK_PACKAGE_VERSION_MINOR ${PROJECT_VERSION_MINOR})
-set(CPACK_PACKAGE_VERSION_PATCH ${PROJECT_VERSION_PATCH})
-set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
-set(CPACK_RESOURCE_FILE_README "${CMAKE_CURRENT_SOURCE_DIR}/README.md")
+ if (ENABLE_CUDA)
+ if(ENABLE_NV_STATIC_LIB)
+ set(CPACK_PACKAGE_DEVICE_NAME "cuda-${CUDA_VERSION}-static")
+ else()
+ set(CPACK_PACKAGE_DEVICE_NAME "cuda-${CUDA_VERSION}-shared")
+ endif()
+ else()
+ if (ENABLE_MULTINUMA)
+ set(CPACK_PACKAGE_DEVICE_NAME "cpu-multinuma")
+ else()
+ set(CPACK_PACKAGE_DEVICE_NAME "cpu")
+ endif()
+ endif()
-set(CPACK_PACKAGING_INSTALL_PREFIX "")
-set(CPACK_RPM_PACKAGE_RELOCATABLE ON)
-
-set(CPACK_PACKAGE_FILE_NAME "${CPACK_PACKAGE_NAME}-${CPACK_PACKAGE_VERSION}.${CPACK_PACKAGE_DEVICE_NAME}.${CPACK_SYSTEM_NAME}.${CPACK_SYSTEM_ARCHITECTURE}")
-include(CPack)
+ set(CPACK_PACKAGE_NAME "DashInfer")
+ set(CPACK_PACKAGE_VENDOR "Alibaba Tongyi")
+  set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "DashInfer AllSpark is an LLM inference engine.")
+ set(CPACK_PACKAGE_VERSION ${project_version_in_env})
+ set(CPACK_PACKAGE_VERSION_MAJOR ${PROJECT_VERSION_MAJOR})
+ set(CPACK_PACKAGE_VERSION_MINOR ${PROJECT_VERSION_MINOR})
+ set(CPACK_PACKAGE_VERSION_PATCH ${PROJECT_VERSION_PATCH})
+ set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
+ set(CPACK_RESOURCE_FILE_README "${CMAKE_CURRENT_SOURCE_DIR}/README.md")
+ set(CPACK_PACKAGING_INSTALL_PREFIX "")
+ set(CPACK_GENERATOR "TGZ")
+ set(CPACK_THREADS 16)
+
+ set(CPACK_PACKAGE_FILE_NAME "${CPACK_PACKAGE_NAME}-${CPACK_PACKAGE_VERSION}.${CPACK_PACKAGE_DEVICE_NAME}.${CPACK_SYSTEM_ARCHITECTURE}")
+
+ INCLUDE(CPack)
endif()
#install
diff --git a/README.md b/README.md
index 84968a22..e93e568f 100644
--- a/README.md
+++ b/README.md
@@ -47,15 +47,6 @@ DashInfer is a highly optimized LLM inference engine with the following core fea
- **Multi-Programming-Language API**: Both C++ and Python interfaces are provided. It is possible to extend C++ interface to Java, Rust and other programming languages, via standard cross-language interfaces.
-
-## Documentation
-- [Release Note](https://dashinfer.readthedocs.io/en/latest/#release-note)
-- [User Manual](https://dashinfer.readthedocs.io/en/latest/)
-- [Installation](docs/EN/installation.md)
-- [C++ Examples](docs/EN/examples_cpp.md)
-- [Python Examples](docs/EN/examples_python.md)
-- [Performance](docs/EN/performance.md)
-
# Supported Hardware and Data Types
## Hardware
@@ -94,86 +85,6 @@ In terms of quantization granularity, there are two types:
- **Per-Channel**: AllSpark's quantization techniques at least adopt the Per-Channel (also known as Per-Token) quantization granularity, and some also provide Sub-Channel quantization granularity. Generally speaking, Per-Channel quantization can meet most accuracy requirements due to its simple implementation and optimal performance. Only when the accuracy of Per-Channel quantization is insufficient should the Sub-Channel quantization strategy be considered.
- **Sub-Channel**: Compared to Per-Channel quantization, Sub-Channel refers to dividing a channel into N groups, and calculating quantization parameters within each group. This quantization granularity typically provides better accuracy, but due to increased implementation complexity, it comes with many limitations. For example, performance may be slightly slower than Per-Channel quantization, and Activation quantization is difficult to implement Sub-Channel quantization due to computational formula constraints (AllSpark's Activation quantization is all Per-Channel).
-# Supported Models
-
-DashInfer support two kind of model load method:
-1. HF format: directly load model from Hugging Face, which provides most convenient method, the model can be downloaded from huggingface or modelscope.
-2. DashInfer format: serialized model file by DashInfer, which provided less python dependency and can be loaded by c++ library.
-
-| Architecture | Models | HuggingFace Models | ModelScope Models |
-|:------------:|:---------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|
-| QWenLMHeadModel | Qwen | [Qwen/Qwen-1_8B-Chat](https://huggingface.co/Qwen/Qwen-1_8B-Chat),<br>[Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat),<br>[Qwen/Qwen-14B-Chat](https://huggingface.co/Qwen/Qwen-14B-Chat), etc. | [qwen/Qwen-1_8B-Chat](https://modelscope.cn/models/qwen/Qwen-1_8B-Chat/summary),<br>[qwen/Qwen-7B-Chat](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary),<br>[qwen/Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary), etc. |
-| Qwen2ForCausalLM | Qwen1.5-Qwen2.5 | [Qwen/Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat),<br>[Qwen/Qwen1.5-1.8B-Chat](https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat),<br>[Qwen/Qwen1.5-4B-Chat](https://huggingface.co/Qwen/Qwen1.5-4B-Chat),<br>[Qwen/Qwen1.5-7B-Chat](https://huggingface.co/Qwen/Qwen1.5-7B-Chat),<br>[Qwen/Qwen1.5-14B-Chat](https://huggingface.co/Qwen/Qwen1.5-14B-Chat), etc. | [qwen/Qwen1.5-0.5B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary),<br>[qwen/Qwen1.5-1.8B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-1.8B-Chat/summary),<br>[qwen/Qwen1.5-4B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-4B-Chat/summary),<br>[qwen/Qwen1.5-7B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-7B-Chat/summary),<br>[qwen/Qwen1.5-14B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-14B-Chat/summary), etc. |
-| Qwen2VLForConditionalGeneration | QwenVL | [Qwen/Qwen-1_8B-Chat](https://huggingface.co/Qwen/Qwen-1_8B-Chat),<br>[Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat),<br>[Qwen/Qwen-14B-Chat](https://huggingface.co/Qwen/Qwen-14B-Chat), etc. | [qwen/Qwen-1_8B-Chat](https://modelscope.cn/models/qwen/Qwen-1_8B-Chat/summary),<br>[qwen/Qwen-7B-Chat](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary),<br>[qwen/Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary), etc. |
-| ChatGLMModel | ChatGLM | [THUDM/glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat/summary) |
-| LlamaForCausalLM | LLaMA-2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) |
-| LlamaForCausalLM | LLaMA-3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | [modelscope/Meta-Llama-3-8B-Instruct](https://modelscope.cn/models/modelscope/Meta-Llama-3-8B-Instruct/summary) |
-| BaichuanForCausalLM | Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat),<br>[baichuan-inc/Baichuan2-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat) | [baichuan-inc/Baichuan2-7B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-7B-Chat),<br>[baichuan-inc/Baichuan2-13B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-13B-Chat) |
-
-# Software Architecture
-
-## Workflow
-
-![Workflow and Dependency](docs/resources/image/workflow-deps.jpg?row=true)
-
-1. **Model Loading**: This procedure involves loading model weights, setting up transformation parameters, and quantization settings. Based on this information, the model is serialized and converted into the DashInfer format (.dimodel, .ditensors, or .asparams, .asmodel). This functionality is accessible exclusively through a Python interface and relies on the PyTorch and transformers libraries to access the weights. The version requirements for PyTorch and transformers may vary from model to model. DashInfer itself does not impose any specific version constraints.
-
-2. **Model Inference**: This step is responsible for executing the model inference using the serialized model with DashInfer, without depending on components like PyTorch. DashInfer employs [DLPack](https://github.com/dmlc/dlpack) format tensors to facilitate interaction with external frameworks, such as PyTorch. Tensors in DLPack format can be manually created or generated through tensor conversion functions provided by deep learning frameworks. Regarding the C++ interface, since most dependencies have been statically linked, it primarily relies on the OpenMP runtime library and C++ system libraries. We applied [control over symbol exports](https://anadoxin.org/blog/control-over-symbol-exports-in-gcc.html/) to ensure that only DashInfer's API interface symbols are visible, thereby preventing version conflicts with existing libraries in the user's system, such as protobuf.
-
-> Note:
-> - After 2.0 version, user rarely needs to care about the model type, which will detected by DashInfer Runtime automatically.
-> - ~~.dimodel, .ditensors is a special model format defined by DashInfer kernel.~~
-> - When utilizing the Python interface, you can combine the code from steps 1 and 2. However, due to the lack of functionality for loading Huggingface models at the C++ level, the C++ interface is limited to conducting inferences with models in the DashInfer format. Therefore, it's essential to serialize the model first using the Python interface before proceeding with the C++ interface.
-
-## GPU and Single-NUMA Architecture
-
-![Single-NUMA Arch](docs/resources/image/arch-single-numa.jpg?row=true)
-
-GPU and Single NUMA CPU Inference share same interface and architecture, in the model inference phase, an inference request can be initiated by passing in input tokens and generation parameters via `StartRequest`, and when the request is successful, the DashInfer engine will return an output queue `ResultQueue` and a control handle `RequestHandle`.
-
-- The `ResultQueue` is used to get output tokens and the status of the generation. DashInfer will **asynchronously** put the generated token into the queue, and tokens in the queue can be fetched either in a blocking (`ResultQueue.Get()`) or non-blocking (`ResultQueue.GetNoWait()`) way.
-
-- The `RequestHandle` is the handle used to manage the request. DashInfer `engine` provides Sync, Stop, and Release primitives for the request specified by the `RequestHandle`. The `SyncRequest` primitive, which returns at the end of generation (when the number of generated tokens reaches the limit, or when an EOS has been generated), is used to simulate the behavior of the synchronous interface.
-
-In GPU and single-NUMA mode, DashInfer Runtime uses multi-threading and a thread pool for scheduling.
-
-## Multi-NUMA Architecture
-
-![Multi-NUMA Arch](docs/resources/image/arch-multi-numa.jpg?row=true)
-
-Due to the inability of some Linux kernels to control CPU affinity at the thread level, running engine on multi-NUMA CPUs may result in remote memory node access, thereby causing a decline in performance. To enable precise control of a thread's CPU affinity, DashInfer multi-NUMA solution employs a multi-process client-server architecture to achieve tensor parallel model inference. On each NUMA node, an independent process runs the server, with each server handling a part of the tensor parallel inference, and the processes use OpenMPI to collaborate (e.g., via the allreduce operation). The client interacts with the servers via gRPC, providing a unique external interface to avoid the need to manage multiple processes when invoking the DashInfer interface.
-
-In terms of API, multi-NUMA and single-NUMA inference need to use different header files and .so libraries (or call different python interfaces). Except for the header and the library, the rest of the interface is consistent and no code changes are required. For details, you can refer to the examples.
-
-- Single-NUMA
- - header: allspark/allspark.h
- - .so library: liballspark_framework.so
- - python API: allspark.Engine()
-- MultiNUMA
- - header: allspark/allspark_client.h
- - .so library: liballspark_client.so
- - python API: allspark.ClientEngine()
-
-> Note: C++ liballspark_framework.so (called for single-NUMA inference) and liballspark_client.so (called for multi-NUMA inference) are mutually exclusive, you cannot link both libraries.
-
-# Performance Test
-
-Please refer to [documentation](docs/EN/performance.md) for detailed performance test results.
-
-The results of this performance test can be reproduced with the scripts in `/examples/python/1_performance`.
-
-# Inference Accuracy
-
-Tested model: [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
-
-| Engine | DataType | MMLU | C-Eval | GSM8K | HumanEval |
-|:------:|:--------:|:----:|:------:|:-----:|:---------:|
-| transformers | BF16 | 55.8 | 59.7 | 50.3 | 37.2 |
-| DashInfer | A16W8 | 55.78 | 61.10 | 51.25 | 37.19 |
-
-- A16W8: The model weight is quantized to 8-bit and is recovered as bfloat16 for matrix multiplication during inference.
-- The results of this accuracy evaluation can be reproduced with the scripts in `/examples/python/2_evaluation`.
-
# Examples
In `/examples` there are examples for C++ and Python interfaces, and please refer to the documentation in `/documents/EN` to run the examples.
@@ -182,36 +93,9 @@ In `/examples` there are examples for C++ and Python interfac
- [Documentation for All Python Examples](docs/EN/examples_python.md)
- [Documentation for C++ Examples](docs/EN/examples_cpp.md)
-## Multi-Modal Model(VLMs)) Support
-
-The VLM Support in [multimodal](multimodal/) folder,
-it's a toolkit to support Vision Language Models (VLMs) inference based on the DashInfer engine. It's compatible with the OpenAI Chat Completion API, supporting text and image/video inputs.
-
-
-# Third-party Dependencies
-
-This subsection lists the third-party dependencies for the different stages of DashInfer.
-
-> Note: These dependency packages are managed through conan and are automatically downloaded when compiling DashInfer.
-
-## Code Compilation Phase
-
-- [conan](https://conan.io/) (1.60.0): For managing C++ third-party dependencies.
-- [cmake](https://cmake.org/) (3.18+): Build system.
-
-## Model Conversion Phase
-
-- [PyTorch](https://pytorch.org/) (CPU): For loading model files, no special version requirements.
-- [transformers](https://github.com/huggingface/transformers): For loading model parameters and tokenizer.
-
-## Model Inference Phase
+## Multi-Modal Models (VLMs) Support
-- [protobuf](https://protobuf.dev/)(3.18.3): For parsing model files.
-- [pybind11](https://github.com/pybind/pybind11)(2.8): For binding python interfaces.
-- [onednn](https://github.com/oneapi-src/oneDNN), [mkl](https://www.intel.com/content/www/us/en/docs/onemkl/get-started-guide/2023-0/overview.html): BLAS libraries, for accelerating GEMM calculations.
-- [openmp](https://www.openmp.org/): A standard parallel programming library.
-- [openmpi](https://www.open-mpi.org/): For implementing multi-NUMA service architecture.
-- [grpc](https://grpc.io/): For implementing multi-NUMA service architecture.
+VLM support lives in the [multimodal](multimodal/) folder. It is a toolkit for Vision Language Model (VLM) inference built on the DashInfer engine, compatible with the OpenAI Chat Completion API and supporting text and image/video inputs.
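+
+As a quick illustration (the server address, API key, and model name below are placeholders for your own deployment, not values shipped with DashInfer), a multimodal request against such an OpenAI-compatible endpoint might look like this:
+
+```python
+from openai import OpenAI
+
+# Placeholder endpoint and model name; point these at your own running server.
+client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+    model="qwen-vl",  # placeholder model name
+    messages=[{
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe this image."},
+            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
+        ],
+    }],
+)
+print(response.choices[0].message.content)
+```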
# Future Plans
- [x] GPU Support
diff --git a/README_CN.md b/README_CN.md
index 1cbabf9d..459b2439 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -35,14 +35,6 @@ DashInfer 是一个高度优化的 LLM 推理引擎,具有以下核心特性
- **OpenAI API 服务器**: DashInfer 可以轻松与 fastChat 配合使用,实现兼容 OpenAI 的 API 服务器。
- **多编程语言 API**: 提供 C++ 和 Python 接口。通过标准的跨语言接口,可以将 C++ 接口扩展到 Java、Rust 等编程语言。
-## 文档
-- [Release Note](https://dashinfer.readthedocs.io/en/latest/#release-note)
-- [User Manual](https://dashinfer.readthedocs.io/en/latest/)
-- [安装](docs/CN/installation.md)
-- [C++示例](docs/CN/examples_cpp.md)
-- [Python示例](docs/CN/examples_python.md)
-- [性能测试](docs/EN/performance.md)
-- [使用魔搭notebook部署](docs/CN/modelscope_notebook.md)
# 硬件支持和数据类型
@@ -71,83 +63,6 @@ DashInfer 为 LLM 权重提供了多种量化技术,例如 int{8,4} 仅权重
- **每通道量化**: AllSpark 的量化技术至少采用了每通道(也称为每 Token)量化粒度,有些还提供了子通道量化粒度。一般而言,每通道量化由于实现简单且性能最佳,通常能满足大多数准确性需求。只有当每通道量化的准确性不足时,才应考虑子通道量化策略。
- **子通道量化**: 与每通道量化相比,子通道量化是指将一个通道划分为 N 组,并在每组内计算量化参数。这种量化粒度通常能提供更好的准确性,但由于实现复杂度增加,带来了许多限制。例如,性能可能比每通道量化稍慢,并且由于计算公式限制,激活量化难以实现子通道量化(AllSpark 的激活量化都是每通道量化)。
-# 模型支持
-DashInfer 支持两种模型加载方式:
-1. **HF 格式**:直接从 Hugging Face 加载模型,这是最方便的方法,模型可以从 Hugging Face 或 ModelScope 下载。
-2. **DashInfer 格式**:由 DashInfer 序列化的模型文件,依赖更少的 Python 组件,可以通过 C++ 库加载。
-
-| Architecture | Models | HuggingFace Models | ModelScope Models |
-|:------------:|:---------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|
-| QWenLMHeadModel | Qwen | [Qwen/Qwen-1_8B-Chat](https://huggingface.co/Qwen/Qwen-1_8B-Chat),<br>[Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat),<br>[Qwen/Qwen-14B-Chat](https://huggingface.co/Qwen/Qwen-14B-Chat), etc. | [qwen/Qwen-1_8B-Chat](https://modelscope.cn/models/qwen/Qwen-1_8B-Chat/summary),<br>[qwen/Qwen-7B-Chat](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary),<br>[qwen/Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary), etc. |
-| Qwen2ForCausalLM | Qwen1.5-Qwen2.5 | [Qwen/Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat),<br>[Qwen/Qwen1.5-1.8B-Chat](https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat),<br>[Qwen/Qwen1.5-4B-Chat](https://huggingface.co/Qwen/Qwen1.5-4B-Chat),<br>[Qwen/Qwen1.5-7B-Chat](https://huggingface.co/Qwen/Qwen1.5-7B-Chat),<br>[Qwen/Qwen1.5-14B-Chat](https://huggingface.co/Qwen/Qwen1.5-14B-Chat), etc. | [qwen/Qwen1.5-0.5B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary),<br>[qwen/Qwen1.5-1.8B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-1.8B-Chat/summary),<br>[qwen/Qwen1.5-4B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-4B-Chat/summary),<br>[qwen/Qwen1.5-7B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-7B-Chat/summary),<br>[qwen/Qwen1.5-14B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-14B-Chat/summary), etc. |
-| Qwen2VLForConditionalGeneration | QwenVL | [Qwen/Qwen-1_8B-Chat](https://huggingface.co/Qwen/Qwen-1_8B-Chat),<br>[Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat),<br>[Qwen/Qwen-14B-Chat](https://huggingface.co/Qwen/Qwen-14B-Chat), etc. | [qwen/Qwen-1_8B-Chat](https://modelscope.cn/models/qwen/Qwen-1_8B-Chat/summary),<br>[qwen/Qwen-7B-Chat](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary),<br>[qwen/Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary), etc. |
-| ChatGLMModel | ChatGLM | [THUDM/glm-4-9b-chat](https://huggingface.co/THUDM/glm-4-9b-chat) | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat/summary) |
-| LlamaForCausalLM | LLaMA-2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf),<br>[meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [modelscope/Llama-2-7b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-7b-chat-ms/summary),<br>[modelscope/Llama-2-13b-chat-ms](https://modelscope.cn/models/modelscope/Llama-2-13b-chat-ms/summary) |
-| LlamaForCausalLM | LLaMA-3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | [modelscope/Meta-Llama-3-8B-Instruct](https://modelscope.cn/models/modelscope/Meta-Llama-3-8B-Instruct/summary) |
-| BaichuanForCausalLM | Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat),<br>[baichuan-inc/Baichuan2-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat) | [baichuan-inc/Baichuan2-7B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-7B-Chat),<br>[baichuan-inc/Baichuan2-13B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-13B-Chat) |
-
-# 软件框架
-
-## 推理流程
-
-![Workflow and Dependency](documents/resources/image/workflow-deps.jpg?row=true)
-
-1. **模型加载**:该过程包括加载模型权重、设置转换参数和量化设置。基于这些信息,模型会被序列化并转换成 DashInfer 格式(.dimodel, .ditensors 或 .asparams, .asmodel) 。此功能仅通过 Python 接口访问,并依赖于 PyTorch 和 transformers 库来访问权重。PyTorch 和 transformers 的版本要求可能因模型而异。DashInfer 本身没有具体的版本限制。
-2. **模型推理**:此步骤负责使用 DashInfer 执行序列化模型的推理,而不依赖于 PyTorch 等组件。DashInfer 采用 [DLPack](https://github.com/dmlc/dlpack) 格式的张量,以便与外部框架(如 PyTorch)进行交互。DLPack 格式的张量可以手动创建,也可以通过深度学习框架提供的张量转换函数生成。对于 C++ 接口,由于大多数依赖项已经被静态链接,它主要依赖于 OpenMP 运行时库和 C++ 系统库。我们应用了 [控制符号导出](https://anadoxin.org/blog/control-over-symbol-exports-in-gcc.html/) 技术,以确保只有 DashInfer 的 API 接口符号是可见的,从而防止与用户系统中的现有库(如 protobuf)发生版本冲突。
-
-> 注意:
-> - 版本 2.0 之后,用户很少需要关心模型类型(在1.0中),它会被 DashInfer Runtime 自动检测。
-> - ~~.dimodel, .ditensors 是 DashInfer 内核定义的一种特殊模型格式。~~
-> - 使用 Python 接口时,可以将步骤 1 和步骤 2 的代码结合起来。然而,由于在 C++ 层面缺乏加载 Huggingface 模型的功能,C++ 接口仅限于使用 DashInfer 格式的模型进行推理。因此,必须先使用 Python 接口序列化模型,然后再进行 C++ 接口的推理。
-## GPU 和 CPU 单NUMA架构图
-
-![Single-NUMA Arch](docs/resources/image/arch-single-numa.jpg?row=true)
-
-GPU 和单 NUMA CPU 推理共享相同的接口和架构。在模型推理阶段,可以通过 `StartRequest` 传入输入标记和生成参数来启动推理请求,当请求成功时,DashInfer 引擎将返回一个输出队列 `ResultQueue` 和一个控制句柄 `RequestHandle`。
-
-- `ResultQueue`用来获取输出token以及生成的状态,推理引擎会**异步**地把生成的token放到该队列中,可以阻塞(`ResultQueue.Get()`)或非阻塞(`ResultQueue.GetNoWait()`)地获取队列中的token。
-
-- `RequestHandle`是用来管理请求的句柄,DashInfer `engine`根据传入的`RequestHandle`实现对指定request的同步(Sync)、停止(Stop)和释放(Release)操作。其中`SyncRequest`操作,会在生成结束(生成的token数达到上限,或产生结束符)后返回,用来模拟同步接口的行为。
-
-在GPU 和 单NUMA的模式下,DashInfer Runtime采用多线程和线程池的结构做调度。
-
-## 多NUMA架构图
-
-![Multi-NUMA Arch](docs/resources/image/arch-multi-numa.jpg?row=true)
-
-由于部分Linux内核无法在线程级别控制CPU亲和性,在多NUMA的CPU上采用单进程推理可能会出现跨NUMA访问内存访问,从而导致性能下降。为了能够精确地控制程序的CPU亲和性,DashInfer的多NUMA方案采用了多进程的client-server架构,实现tensor parallel的模型推理。在每个NUMA节点上,都有一个独立的进程运行DashInfer server,每个server负责一部分的tensor parallel推理,进程间使用OpenMPI进行协同(例如allreduce操作)。DashInfer client通过gRPC与server交互,提供唯一的对外接口,避免在调用DashInfer接口时,需要对多进程进行管理。
-
-在API使用上,多NUMA和单NUMA的推理需要引用不同的头文件、.so库(或调用不同的python接口)。除了引用阶段外,其余接口一致,无需修改代码。具体可以参考examples中的示例。
-
-- 单NUMA
- - 头文件:allspark/allspark.h
- - .so库:liballspark_framework.so
- - python接口:allspark.Engine()
-- 多NUMA
- - 头文件:allspark/allspark_client.h
- - .so库:liballspark_client.so
- - python接口:allspark.ClientEngine()
-
-> 注意:C++的liballspark_framework.so(单NUMA推理时调用)和liballspark_client.so(多NUMA推理时调用)是互斥的,不能同时链接两个库。
-
-# 性能测试
-
-详细的性能测试结果请参考[文档](docs/EN/performance.md)。
-
-该性能测试结果可用`/examples/python/1_performance`中的脚本复现。
-
-# 精度测试
-
-测试模型:[Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
-
-| Engine | DataType | MMLU | C-Eval | GSM8K | HumanEval |
-|:------:|:--------:|:----:|:------:|:-----:|:---------:|
-| transformers | BF16 | 55.8 | 59.7 | 50.3 | 37.2 |
-| DashInfer | A16W8 | 55.78 | 61.10 | 51.25 | 37.19 |
-
-- A16W8:指weight采用8-bit量化,在推理过程中恢复为bfloat16进行矩阵乘法计算;
-- 该精度评测结果,可用`/examples/python/2_evaluation`中的脚本复现。
-
# 示例代码
在`/examples`下提供了C++、python接口的调用示例,请参考`/documents/CN`目录下的文档运行示例。
@@ -156,30 +71,9 @@ GPU 和单 NUMA CPU 推理共享相同的接口和架构。在模型推理阶段
- [所有Python示例文档](docs/CN/examples_python.md)
- [C++示例文档](docs/CN/examples_cpp.md)
-# 依赖库
-
-本小节列出了DashInfer不同阶段的第三方依赖。
-
-> 注:这些依赖包通过conan管理,在编译DashInfer时自动下载。
-
-## 代码编译阶段
-
-- [conan](https://conan.io/) (1.60.0): For managing C++ third-party dependencies.
-- [cmake](https://cmake.org/) (3.18+): Build system.
-
-## 模型转换阶段
-
-- [PyTorch](https://pytorch.org/) (CPU): For reading model files, no special version requirements.
-- [transformers](https://github.com/huggingface/transformers): For loading model parameters and tokenizer.
-
-## 模型推理阶段
+## 多模态模型支持
-- [protobuf](https://protobuf.dev/)(3.18.3): For parsing model files.
-- [pybind11](https://github.com/pybind/pybind11)(2.8): For binding python interfaces.
-- [onednn](https://github.com/oneapi-src/oneDNN), [mkl](https://www.intel.com/content/www/us/en/docs/onemkl/get-started-guide/2023-0/overview.html): BLAS libraries, for accelerating GEMM calculations.
-- [openmp](https://www.openmp.org/): A standard parallel programming library.
-- [openmpi](https://www.open-mpi.org/): For implementing multi-NUMA service architecture.
-- [grpc](https://grpc.io/): For implementing multi-NUMA service architecture.
+[multimodal](multimodal/) 目录下是基于DashInfer实现的多模态模型推理工具,兼容OpenAI Chat Completion API,支持文字、图片、视频输入。
# 未来规划
diff --git a/build.sh b/build.sh
index 08346009..e9a02502 100755
--- a/build.sh
+++ b/build.sh
@@ -20,7 +20,7 @@ NCCL_VERSION="${AS_NCCL_VERSION:-2.23.4}"
system_nv_lib="${AS_SYSTEM_NV_LIB:-OFF}"
build_type="${AS_BUILD_TYPE:-Release}"
cuda_static="${AS_CUDA_STATIC:-OFF}"
-rpm_package="${AS_RPM_PACKAGE:-OFF}"
+build_package="${AS_BUILD_PACKAGE:-ON}"
enable_glibcxx11_abi="${AS_CXX11_ABI:-OFF}"
enable_span_attn="${ENABLE_SPAN_ATTENTION:-ON}"
enable_multinuma="${ENABLE_MULTINUMA:-OFF}"
@@ -81,6 +81,7 @@ export PATH=`pwd`/bin:$PATH
if [ "${with_platform,,}" == "cuda" ]; then
cmake .. \
-DCMAKE_BUILD_TYPE=${build_type} \
+ -DBUILD_PACKAGE=${build_package} \
-DCONFIG_ACCELERATOR_TYPE=CUDA \
-DCONFIG_HOST_CPU_TYPE=X86 \
-DNCCL_VERSION=${NCCL_VERSION} \
@@ -97,9 +98,11 @@ if [ "${with_platform,,}" == "cuda" ]; then
elif [ "${with_platform,,}" == "x86" ]; then
cmake .. \
-DCMAKE_BUILD_TYPE=${build_type} \
+ -DBUILD_PACKAGE=${build_package} \
-DCONFIG_ACCELERATOR_TYPE=NONE \
-DCONFIG_HOST_CPU_TYPE=X86 \
-DENABLE_GLIBCXX11_ABI=${enable_glibcxx11_abi} \
+ -DBUILD_PYTHON=OFF \
-DALLSPARK_CBLAS=MKL \
-DENABLE_CUDA=OFF \
-DENABLE_SPAN_ATTENTION=OFF \
@@ -108,10 +111,12 @@ elif [ "${with_platform,,}" == "x86" ]; then
elif [ "${with_platform,,}" == "armclang" ]; then
cmake .. \
-DCMAKE_BUILD_TYPE=${build_type} \
+ -DBUILD_PACKAGE=${build_package} \
-DCONFIG_ACCELERATOR_TYPE=NONE \
-DCONFIG_HOST_CPU_TYPE=ARM \
-DENABLE_BLADE_AUTH=${enable_blade_auth} \
-DENABLE_GLIBCXX11_ABI=${enable_glibcxx11_abi} \
+ -DBUILD_PYTHON=OFF \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DENABLE_ARMCL=ON \
-DALLSPARK_CBLAS=BLIS \
@@ -134,7 +139,7 @@ make -j16 && make install
if [ $? -eq 0 ]; then
- if [ ${rpm_package} == "ON" ]; then
+ if [ ${build_package} == "ON" ]; then
make package
fi
else
diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake
index fb6bf875..85044546 100644
--- a/cmake/FindNCCL.cmake
+++ b/cmake/FindNCCL.cmake
@@ -5,24 +5,9 @@ if (USE_SYSTEM_NV_LIB)
return()
endif()
include(FindPackageHandleStandardArgs)
-include(FetchContent)
-set(NCCL_VERSION
- "2.11.4"
- CACHE STRING "NCCL VERSION")
-set(NCCL_URL https://github.com/NVIDIA/nccl/archive/refs/tags/v${NCCL_VERSION}-1.tar.gz)
-set(NCCL_PROJECT "extern_nccl")
-FetchContent_Declare(${NCCL_PROJECT} URL ${NCCL_URL})
-message(STATUS "Fetch NCCL from ${NCCL_URL}")
-FetchContent_MakeAvailable(${NCCL_PROJECT})
-
-set(NCCL_ROOT_DIR
- "${${NCCL_PROJECT}_SOURCE_DIR}"
- CACHE PATH "NVIDIA NCCL")
-message(STATUS "NCCL_ROOT_DIR : ${NCCL_ROOT_DIR}")
find_path(
NCCL_INCLUDE_DIR nccl.h
- HINTS ${NCCL_ROOT_DIR}
PATH_SUFFIXES cuda/include include
nccl-${NCCL_VERSION}-cuda-${CUDA_VERSION}/include)
@@ -35,7 +20,6 @@ endif()
message("find nccl with ${NCCL_LIBNAME}")
find_library(
AS_NCCL_LIBRARY ${NCCL_LIBNAME}
- HINTS ${NCCL_ROOT_DIR}
PATH_SUFFIXES lib lib64 nccl-${NCCL_VERSION}-cuda-${CUDA_VERSION}/lib64)
if(ENABLE_NV_STATIC_LIB)
@@ -51,14 +35,11 @@ set_property(TARGET CUDA::${NCCL_LIBNAME} PROPERTY INTERFACE_INCLUDE_DIRECTORIES
${NCCL_INCLUDE_DIR})
# install nccl
-
if(NOT ENABLE_NV_STATIC_LIB)
get_filename_component(NCCL_LIB_DIR ${AS_NCCL_LIBRARY} DIRECTORY)
-install(DIRECTORY ${NCCL_LIB_DIR}/
- DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}
- USE_SOURCE_PERMISSIONS FILES_MATCHING
- PATTERN "*nccl.so*"
-)
+file(GLOB NCCL_LIBS ${NCCL_LIB_DIR}/*nccl.so*)
+install(FILES ${NCCL_LIBS}
+ DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif()
@@ -66,5 +47,5 @@ find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR
AS_NCCL_LIBRARY)
if(NCCL_FOUND)
- message(STATUS "Found NCCL: success , library path : ${AS_NCCL_LIBRARY}")
+ message(STATUS "Found NCCL: success, library path : ${AS_NCCL_LIBRARY}")
endif()
diff --git a/cmake/flash-attention.cmake b/cmake/flash-attention.cmake
index 01dd532d..db646521 100644
--- a/cmake/flash-attention.cmake
+++ b/cmake/flash-attention.cmake
@@ -90,7 +90,7 @@ if (FLASHATTN_USE_STATIC_LIB)
else()
add_library(flash-attention::flash-attn SHARED IMPORTED)
install(FILES ${FLASHATTN_LIBRARY_PATH}/libflash-attn.so
- DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
+ DESTINATION ${CMAKE_INSTALL_LIBDIR})
message(STATUS "libflash-attn.so installing path: ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}")
endif()
diff --git a/cmake/install.cmake b/cmake/install.cmake
index 1c2552e2..d24c330b 100644
--- a/cmake/install.cmake
+++ b/cmake/install.cmake
@@ -1,44 +1,38 @@
# add install target
-SET_TARGET_PROPERTIES(allspark_framework PROPERTIES INSTALL_RPATH "$ORIGIN")
install(DIRECTORY ${PROJECT_SOURCE_DIR}/csrc/interface/
- DESTINATION include/allspark/
+ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/allspark
USE_SOURCE_PERMISSIONS FILES_MATCHING
PATTERN "*.h"
)
-if (NOT BUILD_PYTHON)
- install(TARGETS allspark_framework_static DESTINATION ${CMAKE_INSTALL_DIR})
-endif()
+install(TARGETS allspark_framework DESTINATION ${CMAKE_INSTALL_LIBDIR})
+install(TARGETS allspark_framework_static DESTINATION ${CMAKE_INSTALL_LIBDIR})
if (ENABLE_MULTINUMA)
- install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/bin/orterun
- DESTINATION bin
- RENAME mpirun)
- install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/bin/allspark_daemon
- DESTINATION bin
- RENAME allspark_daemon)
- SET_TARGET_PROPERTIES(allspark_client PROPERTIES INSTALL_RPATH "$ORIGIN")
- install(TARGETS allspark_client DESTINATION ${CMAKE_INSTALL_DIR})
install(DIRECTORY ${PROJECT_SOURCE_DIR}/csrc/service/
- DESTINATION include/allspark/
+ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/allspark
USE_SOURCE_PERMISSIONS FILES_MATCHING
PATTERN "allspark_client.h")
+ install(TARGETS allspark_client DESTINATION ${CMAKE_INSTALL_LIBDIR})
+ install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/bin/orterun
+ DESTINATION ${CMAKE_INSTALL_BINDIR}
+ RENAME mpirun)
+ install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/bin/allspark_daemon
+ DESTINATION ${CMAKE_INSTALL_BINDIR}
+ RENAME allspark_daemon)
endif()
if (BUILD_PYTHON)
- if (PYTHON_LIB_DIRS)
- if(NOT ENABLE_NV_STATIC_LIB)
- install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} DESTINATION ${PYTHON_LIB_DIRS} FILES_MATCHING PATTERN "*" PATTERN "libnccl.*" EXCLUDE)
- else()
- install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} DESTINATION ${PYTHON_LIB_DIRS} FILES_MATCHING PATTERN "*")
- endif()
- if (ENABLE_MULTINUMA)
- install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/bin DESTINATION ${PYTHON_LIB_DIRS} USE_SOURCE_PERMISSIONS FILES_MATCHING PATTERN "*")
- SET_TARGET_PROPERTIES(_allspark_client PROPERTIES INSTALL_RPATH "$ORIGIN/${CMAKE_INSTALL_LIBDIR}")
- install(TARGETS _allspark_client DESTINATION ${PYTHON_LIB_DIRS})
- endif()
- SET_TARGET_PROPERTIES(_allspark PROPERTIES INSTALL_RPATH "$ORIGIN/${CMAKE_INSTALL_LIBDIR}")
- install(TARGETS _allspark DESTINATION ${PYTHON_LIB_DIRS})
+if (PYTHON_LIB_DIRS)
+ if(NOT ENABLE_NV_STATIC_LIB)
+ install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} DESTINATION ${PYTHON_LIB_DIRS} FILES_MATCHING PATTERN "*.so" PATTERN "libnccl.*" EXCLUDE)
+ else()
+ install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} DESTINATION ${PYTHON_LIB_DIRS} FILES_MATCHING PATTERN "*.so")
+ endif()
+
+ if (ENABLE_MULTINUMA)
+ install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR} DESTINATION ${PYTHON_LIB_DIRS} USE_SOURCE_PERMISSIONS FILES_MATCHING PATTERN "*")
+ install(TARGETS _allspark_client DESTINATION ${PYTHON_LIB_DIRS})
endif()
-else()
- install(TARGETS allspark_framework DESTINATION ${CMAKE_INSTALL_DIR})
+ install(TARGETS _allspark DESTINATION ${PYTHON_LIB_DIRS})
+endif()
endif()
diff --git a/csrc/service/CMakeLists.txt b/csrc/service/CMakeLists.txt
index 3c64975e..72efc073 100644
--- a/csrc/service/CMakeLists.txt
+++ b/csrc/service/CMakeLists.txt
@@ -32,9 +32,10 @@ target_include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/..
${CMAKE_CURRENT_SOURCE_DIR}/../common
${CMAKE_CURRENT_SOURCE_DIR}/../interface)
+set_target_properties(allspark_daemon PROPERTIES INSTALL_RPATH "$ORIGIN:$ORIGIN/../${CMAKE_INSTALL_LIBDIR}")
add_library(allspark_client STATIC ${PROTO_SVC_SRCS} ${PROTO_SVC_GRPC_SRC} allspark_client.cpp allspark_client_impl.cpp allspark_service_parallel.cpp)
-target_link_libraries(allspark_client CONAN_PKG::grpc CONAN_PKG::protobuf CONAN_PKG::glog ${THREAD_LIB})
+target_link_libraries(allspark_client allspark_framework CONAN_PKG::grpc CONAN_PKG::protobuf CONAN_PKG::glog ${THREAD_LIB})
if (MEM_CHECK)
target_link_options(allspark_client PUBLIC "-fsanitize=address")
diff --git a/docs/sphinx/devel/source_code_build_en.rst b/docs/sphinx/devel/source_code_build_en.rst
index 9689f844..b5d2a7d4 100644
--- a/docs/sphinx/devel/source_code_build_en.rst
+++ b/docs/sphinx/devel/source_code_build_en.rst
@@ -30,10 +30,10 @@ CUDA
- CUDA sdk version >= 11.4
- cuBLAS: CUDA sdk provided
-conan
+Conan
,,,,,
- + **conan**: C++ package management tools, can be installed by : `pip install conan==1.60.0`, only 1.60.0 is supported.
+ + **conan**: C++ package management tool; it can be installed with ``pip install conan==1.60.0`` (only version 1.60.0 is supported).
.. note:: if there is any package-not-found issue, please make sure your conan center is available. Reset it with this command: `conan remote add conancenter https://center.conan.io`
@@ -51,7 +51,7 @@ Leak check tool
CPU
,,,
-For multi-NUMA inference, `numactl`, `openmpi` are required:
+For multi-NUMA inference, ``numactl``, ``openmpi`` are required:
- for Ubuntu:
@@ -77,76 +77,96 @@ We have build some Docker image for easier development setup.
.. code-block:: shell
docker run -d --name="dashinfer-dev-cu124-${USER}" \
- --shm-size=8g \
+ --shm-size=8g --gpus all \
--network=host \
- --gpus all \
- -v $(pwd):/root/workspace/HIE-AllSpark \
+ -v $(pwd):/root/workspace/DashInfer \
-w /root/workspace \
-it registry-1.docker.io/dashinfer/dev-centos7-cu124
docker exec -it "dashinfer-dev-cu124-${USER}" /bin/bash
-- YiTian 710 Develoment
+- CPU-only (Linux x86 server)
.. code-block:: shell
docker run -d --name="dashinfer-dev-${USER}" \
--network=host \
- -v $(pwd):/root/workspace/HIE-AllSpark \
+ -v $(pwd):/root/workspace/DashInfer \
+ -w /root/workspace \
+ -it registry-1.docker.io/dashinfer/dev-centos7-x86
+ docker exec -it "dashinfer-dev-${USER}" /bin/bash
+
+- CPU-only (Linux ARM server)
+
+.. code-block:: shell
+
+ docker run -d --name="dashinfer-dev-${USER}" \
+ --network=host \
+ -v $(pwd):/root/workspace/DashInfer \
-w /root/workspace \
-it registry-1.docker.io/dashinfer/dev-centos8-arm
docker exec -it "dashinfer-dev-${USER}" /bin/bash
+.. note:: When creating a container for multi-NUMA inference, the ``--cap-add SYS_NICE --cap-add SYS_PTRACE --ipc=host`` arguments are required, because components such as numactl and openmpi need the corresponding permissions to run. If you only use the single-NUMA API, these permissions are not needed.
+
Build from Source Code
======================
-.. tip:: Here we use CUDA 12.4 as the default CUDA version. If you want to change to a different version, you can use enviroment variable to control CUDA dependency version.
+Build Python Package
+,,,,,,,,,,,,,,,,,,,,
+1. Build python package for CUDA:
-Python package build
-,,,,,,,,,,,,,,,,,,,,
+.. code-block:: bash
-CUDA normal build:
+ cd python
+ AS_CUDA_VERSION="12.4" AS_NCCL_VERSION="2.23.4" AS_CUDA_SM="'80;86;89;90a'" AS_PLATFORM="cuda" \
+ python3 setup.py bdist_wheel
+
+2. Build python package for x86:
.. code-block:: bash
cd python
- AS_CUDA_VERSION="12.4" AS_NCCL_VERSION="2.23.4" AS_CUDA_SM="'80;86;89;90a'" AS_PLATFORM="cuda" python3 setup.py bdist_wheel
+ AS_PLATFORM="x86" python3 setup.py bdist_wheel
-.. note:: The Python build only performs the `conan install` operation at the first time; subsequent builds will not conduct `conan install`. If you encounter issues, consider using `rm -rf ./python/build/temp.*` to re-run the entire process.
+3. Build python package for arm:
-.. note:: Change `AS_RELEASE_VERSION` enviroment var to change package version.
+.. code-block:: bash
-.. note:: To build an x86 or arm CPU only Python package, it's similar to CUDA build, but change the `AS_PLATFORM` environment variable to `x86` or `arm`.
+ cd python
+ AS_PLATFORM="armclang" python3 setup.py bdist_wheel
+.. note::
+   - We use CUDA 12.4 as the default CUDA version. To build with a different CUDA version, set ``AS_CUDA_VERSION`` to the target version.
+   - Set the ``AS_RELEASE_VERSION`` environment variable to change the package version.
+   - Set the ``ENABLE_MULTINUMA=ON`` environment variable to enable multi-NUMA inference in the CPU-only build.
-C++ package build
+Build C++ Libraries
,,,,,,,,,,,,,,,,,,,
-1. C++ lib build for CUDA
+1. Build C++ libraries for CUDA
.. code-block:: bash
- mkdir build;
- AS_CUDA_VERSION="12.4" AS_NCCL_VERSION="2.23.4" AS_CUDA_SM="'80;86;89;90a'" ./build.sh
+ AS_CUDA_VERSION="12.4" AS_NCCL_VERSION="2.23.4" AS_CUDA_SM="'80;86;89;90a'" AS_PLATFORM="cuda" AS_BUILD_PACKAGE="ON" ./build.sh
-2. C++ lib build for x86
+2. Build C++ libraries for x86
.. code-block:: bash
- AS_PLATFORM="x86" ./build.sh
-
-3. C++ lib build for armclang
+ AS_PLATFORM="x86" AS_BUILD_PACKAGE="ON" ./build.sh
-ARM Compile require armcc to archive best performance, setup the compiler in enviroment var.
+3. Build C++ libraries for arm
.. code-block:: bash
export ARM_COMPILER_ROOT=/opt/arm/arm-linux-compiler-24.04_RHEL-8/ # change this path to your own
export PATH=$PATH:$ARM_COMPILER_ROOT/bin
- AS_PLATFORM="armclang" ./build.sh
+
+ AS_PLATFORM="armclang" AS_BUILD_PACKAGE="ON" ./build.sh
Profiling
---------
@@ -156,9 +176,9 @@ Operator Profiling
This section describes how to enable and utilize the operator profiling functionality.
-1. Enable OP profile data collection
+1. Enable OP profiling data collection
-To enable OP profiling, set the environment variable as follows:
+To enable OP profiling, set the environment variable ``AS_PROFILE=ON`` before running DashInfer.
.. code-block:: bash
@@ -166,9 +186,9 @@ To enable OP profiling, set the environment variable as follows:
# Then, run any Python program utilizing the DashInfer Engine.
-2. Print OP profile data
+2. Print OP profiling data
- To view the profiling information, insert the following function call before deinitializing the engine, replacing model_name with your actual model's name:
+To view the profiling information, call the following function before deinitializing the engine:
.. code-block:: bash
@@ -177,15 +197,14 @@ To enable OP profiling, set the environment variable as follows:
.. tip:: Replace *model_name* with the name of your model.
-3. Analyze OP profile data
+3. Analyze OP profiling data
- An OP profile data report begins with a section header marked by ***** ***** followed by a detailed table. The report consists of three main sections:
+ An OP profiling data report begins with a section header marked by \*\*\* \*\*\* followed by a detailed table. The report consists of three main sections:
- reshape: Statistics on the cost of reshaping inputs for operators.
- alloc: Measures the cost of memory allocation for paged KV cache.
- forward: Focuses on the execution time of operators' forward passes; developers should closely examine this section.
-
Below is an illustration of the table structure and the meaning of each column:
1. **opname**: The name of the operator.
@@ -193,7 +212,6 @@ To enable OP profiling, set the environment variable as follows:
3. **(min/max/ave)**: Minimum, maximum, and average execution times in milliseconds.
4. **total_ms**: The cumulative time spent on this operator.
5. **percentage**: The operator's total time as a percentage of the overall profiling duration.
- 6. **rank**: This column is deprecated.
An example snippet of the profiling output is shown below:
@@ -243,10 +261,10 @@ This section describes how to use controlled Nsys profiling to obtain decoder an
**Steps:**
0. **Disable Warm-up:** Set the environment variable `ALLSPARK_DISABLE_WARMUP=1` to disable the warm-up phase.
-1. **Enable Nsys Profiling Call:** In the file `cuda_context.cpp`, uncomment line 14 to enable the Nsys profiling call.
+1. **Enable Nsys Profiling Call:** Set ``#define ENABLE_NSYS_PROFILE 1`` in file `cuda_context.cpp`.
2. **Model.cpp Configuration:**
- - **Context Phase Profiling:** To profile the context phase, set the variable `PROFILE_CONTEXT_TIME_GPU` to `1`. This will initiate Nsys profiling on the 10th request and terminate the process after one context loop completes.
- - **Generation Phase Profiling:** To profile the generation phase, set the variable `PROFILE_GENERATION_TIME_GPU` to `1`. Profiling will commence after reaching a concurrency (or batch size) specified by `PROFILE_GENERATION_TIME_BS` (adjust this value according to your needs). This allows you to profile the system under a fixed concurrency level.
+ - **Context Phase Profiling:** To profile the context phase, set ``#define PROFILE_CONTEXT_TIME_GPU 1`` in file `model.cpp`. This will initiate Nsys profiling on the 10th request and terminate the process after one context loop completes.
+ - **Generation Phase Profiling:** To profile the generation phase, set ``#define PROFILE_GENERATION_TIME_GPU 1`` in file `model.cpp`. Profiling will commence after reaching a concurrency (or batch size) specified by `PROFILE_GENERATION_TIME_BS` (adjust this value according to your needs). This allows you to profile the system under a fixed concurrency level.
3. **ReCompile:** Recompile your package and install
4. **Start Profiling:** Execute your benchmark or server using the following command:
diff --git a/docs/sphinx/get_started/env_var_options_en.rst b/docs/sphinx/get_started/env_var_options_en.rst
index 49fe8336..9b3c8ced 100644
--- a/docs/sphinx/get_started/env_var_options_en.rst
+++ b/docs/sphinx/get_started/env_var_options_en.rst
@@ -56,7 +56,7 @@ Memory Mangament
store kv cache.
- float
- ``0.0``
- - float value between (0.0,1.0]
+ - float value between [0.0, 1.0]
Logging
=======
diff --git a/docs/sphinx/get_started/install_en.md b/docs/sphinx/get_started/install_en.md
index 346d39d4..a1ec8dd2 100644
--- a/docs/sphinx/get_started/install_en.md
+++ b/docs/sphinx/get_started/install_en.md
@@ -33,14 +33,6 @@ Install python package by following command:
- Install local package: `pip install dashinfer-allspark--xxx.whl`
- Uninstall: `pip uninstall dashinfer-allspark -y`
-## Install C++ Pacakge
-
-for Ubuntu:
-
-- Install: `dpkg -i DashInfer--ubuntu.deb`
-- Uninstall: `dpkg -r DashInfer`
-
-for CentOS:
-
-- Install: `rpm -i DashInfer--centos.rpm`
+## C++ Library
+Download the *.tar.gz package, extract it, and add its include and library directories to your compiler and linker search paths.
diff --git a/docs/sphinx/llm/prefix_caching.rst b/docs/sphinx/llm/prefix_caching.rst
index e346fe0b..842c6fec 100644
--- a/docs/sphinx/llm/prefix_caching.rst
+++ b/docs/sphinx/llm/prefix_caching.rst
@@ -2,4 +2,50 @@
Prefix Caching
=====================
-TODO
+What is Prefix Caching
+**********************
+
+Prefix caching stores kv-caches in GPU or CPU memory for extended periods to reduce redundant calculations. When a new prompt shares the same prefix as a previous one, it can directly use the cached kv-caches, avoiding unnecessary computation and improving performance.
+
+Enable Prefix Caching
+*********************
+
+Runtime Configuration
+---------------------
+
+- ``prefill_cache(enable=True)``: Enables or disables the prefix cache; the default value is True.
+- ``prefix_cache_ttl(ttl: int)``: Time to live of cached prefixes; the default value is 300 seconds (see the sketch below).
+
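+A minimal usage sketch (``runtime_cfg`` is a placeholder for the runtime-config builder object that exposes the methods above; it is illustrative rather than a verbatim API):
+
+.. code-block:: python
+
+   runtime_cfg.prefill_cache(enable=True)  # prefix caching is enabled by default
+   runtime_cfg.prefix_cache_ttl(ttl=600)   # keep cached prefixes for 600 s (default: 300 s)
+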
+Environment Variable
+--------------------
+
+- ``CPU_CACHE_RATIO``
+ - Description: DashInfer reserves CPU_CACHE_RATIO * 100% of the currently free CPU memory for kv-cache storage; when CPU_CACHE_RATIO=0, no CPU memory is used for the kv cache. See the example below.
+ - Data type: float
+ - Default value: ``0.0``
+ - Range: float value between [0.0, 1.0]
+
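+For example, to allow the kv cache to use up to 20% of the currently free CPU memory (an illustrative value; tune it to your host):
+
+.. code-block:: shell
+
+   export CPU_CACHE_RATIO=0.2
+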
+
+Performance
+***********
+
+Run ``benchmark_throughput.py`` in ``examples/benchmark`` with the following command:
+
+.. code-block:: shell
+
+ model=qwen/Qwen2-7B-Instruct && \
+ python3 benchmark_throughput.py --model_path=${model} --modelscope \
+ --engine_max_batch=1 --engine_max_length=4003 --device_ids=0 \
+ --test_qps=250 --test_random_input --test_sample_size=20 --test_max_output=3 \
+ --engine_enable_prefix_cache --prefix_cache_rate_list 0.99,0.9,0.6,0.3
+
+On an NVIDIA A100 GPU we obtained the following results; at a 96% prefix-cache hit rate the average context time drops from 0.254 s (no cache hits) to 0.030 s, roughly 8.5x lower, while generation time is unchanged:
+
+.. csv-table::
+   :header-rows: 1
+
+ Batch_size,Request_num,In_tokens,Out_tokens,Avg_context_time(s),Avg_generate_time(s),Prefix_Cache(hit rate)
+ 1,20,4000,3,0.030,0.040,96.0%
+ 1,20,4000,3,0.044,0.040,89.6%
+ 1,20,4000,3,0.121,0.040,57.6%
+ 1,20,4000,3,0.185,0.040,28.8%
+ 1,20,4000,3,0.254,0.040,0.0%
diff --git a/docs/sphinx/llm/runtime_config.rst b/docs/sphinx/llm/runtime_config.rst
index a8540871..59f1fd12 100644
--- a/docs/sphinx/llm/runtime_config.rst
+++ b/docs/sphinx/llm/runtime_config.rst
@@ -78,11 +78,10 @@ Sequence Length and Batch Size
 - ``max_prefill_length(length: int)``: Sets the maximum prefill length that will be processed in one context inference; if the input length is greater than
   this length, it will be processed in multiple context inference steps.
-Prefix Cache Configuration
---------------------------
+Prefix Caching Configuration
+----------------------------
-- ``prefill_cache(enable=True)``: Enables or disables the prefix cache.
-- ``prefix_cache_ttl(ttl: int)``: Prefix cache time to live, default value is 300s.
+See :doc:`Prefix Caching <../llm/prefix_caching>`.
KV Cache Quantization Configuration
-----------------------------------
diff --git a/examples/cpp/0_basic/example_qwen.cpp b/examples/cpp/0_basic/example_qwen.cpp
index 3813b93f..a0b81c17 100644
--- a/examples/cpp/0_basic/example_qwen.cpp
+++ b/examples/cpp/0_basic/example_qwen.cpp
@@ -76,8 +76,8 @@ int main(int argc, char** argv) {
auto all_exists = check_model_file_exists(model_path, tiktoken_file);
if (!all_exists) return 1;
- std::string dimodel_file = model_path + ".dimodel";
- std::string ditensors_file = model_path + ".ditensors";
+ std::string dimodel_file = model_path + ".asgraph";
+ std::string ditensors_file = model_path + ".asparam";
// create an inference engine instance.
std::unique_ptr as_engine = std::make_unique();
diff --git a/examples/cpp/1_apiserver/apiserver.cpp b/examples/cpp/1_apiserver/apiserver.cpp
index 27e6e43e..1adeeb4f 100644
--- a/examples/cpp/1_apiserver/apiserver.cpp
+++ b/examples/cpp/1_apiserver/apiserver.cpp
@@ -394,8 +394,8 @@ int main(int argc, const char** argv) {
auto all_exists = check_model_file_exists(model_path, tiktoken_file);
if (!all_exists) return 1;
- std::string dimodel_file = model_path + ".dimodel";
- std::string ditensors_file = model_path + ".ditensors";
+ std::string dimodel_file = model_path + ".asgraph";
+ std::string ditensors_file = model_path + ".asparam";
// create an inference engine instance.
setup_tiktoken_tokenizer(tiktoken_file, tokenizer);
diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt
index dde0f332..e639f944 100755
--- a/examples/cpp/CMakeLists.txt
+++ b/examples/cpp/CMakeLists.txt
@@ -12,6 +12,14 @@ set(CMAKE_CXX_EXTENSIONS OFF)
# std::string crash.
# add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
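+# Optionally pick up headers and libraries from a pre-built DashInfer C++ package via environment variables.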
+if (DEFINED ENV{DASHINFER_INCLUDE_PATH})
+ include_directories($ENV{DASHINFER_INCLUDE_PATH})
+endif()
+
+if (DEFINED ENV{DASHINFER_LIBRARY_PATH})
+ link_directories($ENV{DASHINFER_LIBRARY_PATH})
+endif()
+
###########################################
# Example 1: Single NUMA or GPU qwen v1 example.
###########################################
@@ -19,8 +27,7 @@ add_executable(
example_qwen 0_basic/example_qwen.cpp tokenizer/tokenizer.cpp
tokenizer/base64.cpp)
-target_link_libraries(example_qwen PRIVATE allspark_framework
- )
+target_link_libraries(example_qwen PRIVATE allspark_framework)
target_include_directories(example_qwen PRIVATE tokenizer utils)
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index f3b749e9..ec5f016d 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -33,7 +33,7 @@ target_link_libraries(_allspark PRIVATE
allspark_framework_static
CONAN_PKG::protobuf
CONAN_PKG::zlib)
-set_target_properties(_allspark PROPERTIES INSTALL_RPATH "$ORIGIN")
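+# $ORIGIN/${CMAKE_INSTALL_LIBDIR} in the RPATH lets the extension module resolve shared libraries installed in a lib/ directory next to it.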
+set_target_properties(_allspark PROPERTIES INSTALL_RPATH "$ORIGIN:$ORIGIN/${CMAKE_INSTALL_LIBDIR}")
set_target_properties(_allspark PROPERTIES CXX_STANDARD 17)
if(UNIX AND NOT APPLE)
set(ALLSPARK_LINK_MAP ${PROJECT_SOURCE_DIR}/link_python.map)
@@ -58,7 +58,7 @@ if (ENABLE_MULTINUMA)
-Wl,--no-whole-archive
CONAN_PKG::protobuf)
# target_link_libraries(_allspark_client PRIVATE allspark_client)
- set_target_properties(_allspark_client PROPERTIES INSTALL_RPATH "$ORIGIN")
+ set_target_properties(_allspark_client PROPERTIES INSTALL_RPATH "$ORIGIN:$ORIGIN/${CMAKE_INSTALL_LIBDIR}")
set_target_properties(_allspark_client PROPERTIES CXX_STANDARD 17)
if(UNIX AND NOT APPLE)
set(ALLSPARK_LINK_MAP ${PROJECT_SOURCE_DIR}/link_python.map)
diff --git a/scripts/docker/dev_arm_centos8.Dockerfile b/scripts/docker/dev_arm_centos8.Dockerfile
index 17af8d08..04a0f749 100644
--- a/scripts/docker/dev_arm_centos8.Dockerfile
+++ b/scripts/docker/dev_arm_centos8.Dockerfile
@@ -66,6 +66,15 @@ RUN wget "ftp://ftp.gnu.org/gnu/automake/automake-1.15.1.tar.gz" && \
cd automake-1.15.1 && ./configure --prefix=/usr/ && make -j && make install && \
cd .. && rm -rf automake-1.15.1.tar.gz automake-1.15.1
+RUN curl -LO https://github.com/NixOS/patchelf/archive/refs/tags/0.14.5.tar.gz && \
+ tar -xzf 0.14.5.tar.gz && \
+ cd patchelf-0.14.5 && \
+ ./bootstrap.sh && \
+ ./configure && \
+ make install && \
+ cd .. && rm -rf patchelf-0.14.5 0.14.5.tar.gz
+RUN pip3 install auditwheel==6.1.0
+
RUN wget "https://xxxxxx/conan_allspark_source_arm_20241119.tar" && \
tar -xvf conan_allspark_source_arm_20241119.tar && \
mv conan_allspark_source_arm_20241119 /root/.conan && \
diff --git a/scripts/docker/dev_cuda_124.Dockerfile b/scripts/docker/dev_cuda_124.Dockerfile
index 25b5047f..e624fda9 100644
--- a/scripts/docker/dev_cuda_124.Dockerfile
+++ b/scripts/docker/dev_cuda_124.Dockerfile
@@ -72,11 +72,8 @@ RUN conda config --set ssl_verify false
RUN curl -LO https://github.com/Kitware/CMake/releases/download/v3.27.9/cmake-3.27.9-linux-x86_64.sh \
&& bash ./cmake-3.27.9-linux-x86_64.sh --skip-license --prefix=/usr
RUN pip3 install pytest
-RUN curl https://gosspublic.alicdn.com/ossutil/install.sh | bash
RUN conda install -y pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
-RUN ossutil config
-
RUN yum install -y epel-release && yum install -y dnf
RUN dnf makecache && dnf -y install ccache
RUN pip3 install jsonlines GitPython editdistance sacrebleu nltk rouge-score
@@ -91,26 +88,33 @@ RUN yum install -y bash-completion tig
RUN yum install -y build-essential autoconf automake libtool ca-certificates
-RUN curl -LO https://github.com/NixOS/patchelf/archive/refs/tags/0.14.5.tar.gz
-RUN tar -xzf 0.14.5.tar.gz && \
- cd patchelf-0.14.5 && \
- ./bootstrap.sh && \
- ./configure && \
- source /opt/rh/devtoolset-10/enable && make install && \
- rm -rf patchelf-0.14.5 0.14.5.tar.gz && rm -rf patchelf-0.14.5
-
-RUN pip3 install auditwheel==6.1.0
-
RUN yum install -y libtool flex
-
RUN wget "ftp://ftp.gnu.org/gnu/automake/automake-1.15.1.tar.gz" && \
tar -xvf automake-1.15.1.tar.gz && \
cd automake-1.15.1 && ./configure --prefix=/usr/ && make -j && make install && \
cd .. && rm -rf automake-1.15.1.tar.gz automake-1.15.1
-RUN wget "https://xxxxxx/conan_allspark_source_cuda124_20241121.tar" && \
- tar -xvf conan_allspark_source_cuda124_20241121.tar && \
- mv conan_allspark_source_cuda124_20241121 /root/.conan && \
- rm -rf conan_allspark_source_cuda124_20241121.tar
+# git version required by github actions
+RUN yum install -y gettext
+RUN source /root/.bashrc && \
+ wget "https://github.com/git/git/archive/refs/tags/v2.47.0.tar.gz" && \
+ tar -xvf v2.47.0.tar.gz && cd git-2.47.0 && \
+ make configure && ./configure --prefix=/usr && \
+ make -j && make install &&\
+ cd .. && rm -rf v2.47.0.tar.gz git-2.47.0
+
+RUN curl -LO https://github.com/NixOS/patchelf/archive/refs/tags/0.14.5.tar.gz && \
+ tar -xzf 0.14.5.tar.gz && \
+ cd patchelf-0.14.5 && \
+ ./bootstrap.sh && \
+ ./configure && \
+ source /opt/rh/devtoolset-10/enable && make install && \
+ cd .. && rm -rf patchelf-0.14.5 0.14.5.tar.gz
+RUN pip3 install auditwheel==6.1.0
+
+RUN wget "https://xxxxxx/conan_allspark_source_cuda124_20241203_verbose.tar" && \
+ tar -xvf conan_allspark_source_cuda124_20241203_verbose.tar && \
+ mv conan_allspark_source_cuda124_20241203_verbose /root/.conan && \
+ rm -rf conan_allspark_source_cuda124_20241203_verbose.tar
WORKDIR /root/
diff --git a/scripts/docker/dev_x86_centos7.Dockerfile b/scripts/docker/dev_x86_centos7.Dockerfile
index f785f7ee..74bccdd1 100644
--- a/scripts/docker/dev_x86_centos7.Dockerfile
+++ b/scripts/docker/dev_x86_centos7.Dockerfile
@@ -21,9 +21,6 @@ RUN yum install devtoolset-7 -y --nogpgcheck
RUN echo "source /opt/rh/devtoolset-7/enable" >> /root/.bashrc && source /root/.bashrc
ARG PY_VER=3.8
-RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | bash \
- && yum install git-lfs -y
-
RUN curl -LO https://github.com/Kitware/CMake/releases/download/v3.27.9/cmake-3.27.9-linux-x86_64.sh \
&& bash ./cmake-3.27.9-linux-x86_64.sh --skip-license --prefix=/usr
@@ -57,6 +54,7 @@ custom_channels:\n\
RUN conda clean -i -y && conda config --show channels && conda create -y --name ds_py python==${PY_VER} && conda update -n base conda
# RUN conda run python --version && pip3 install --upgrade pip pyOpenSSL==22.0.0 && conda env list
RUN conda run python --version && pip3 install --upgrade pip pyOpenSSL==22.0.0 -i https://mirrors.aliyun.com/pypi/simple && conda env list
+
SHELL ["conda", "run", "-n", "ds_py", "/bin/bash", "-c"]
RUN echo "source activate ds_py" >> /root/.bashrc && source /root/.bashrc
@@ -71,15 +69,37 @@ RUN echo -e "[global]\ntrusted-host=mirrors.aliyun.com\nindex-url = http://mirro
# engine requirements
RUN conda install -y pytorch-cpu -c pytorch
-RUN pip3 install modelscope transformers==4.41.0 protobuf==3.18.3 conan==1.60.0 pytest tokenizers scons wheel pandas tabulate
+RUN pip3 install modelscope transformers protobuf==3.18.3 conan==1.60.0 pytest scons wheel pandas tabulate
-RUN yum install -y libtool flex
+SHELL ["/bin/bash", "-c"]
+RUN yum install -y libtool flex
RUN wget "ftp://ftp.gnu.org/gnu/automake/automake-1.15.1.tar.gz" && \
tar -xvf automake-1.15.1.tar.gz && \
cd automake-1.15.1 && ./configure --prefix=/usr/ && make -j && make install && \
cd .. && rm -rf automake-1.15.1.tar.gz automake-1.15.1
+# git version required by github actions
+RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | bash \
+ && yum install git-lfs -y
+
+RUN yum install -y gettext
+RUN source /root/.bashrc && \
+ wget "https://github.com/git/git/archive/refs/tags/v2.47.0.tar.gz" && \
+ tar -xvf v2.47.0.tar.gz && cd git-2.47.0 && \
+ make configure && ./configure --prefix=/usr && \
+ make -j && make install &&\
+ cd .. && rm -rf v2.47.0.tar.gz git-2.47.0
+
+RUN curl -LO https://github.com/NixOS/patchelf/archive/refs/tags/0.14.5.tar.gz && \
+ tar -xzf 0.14.5.tar.gz && \
+ cd patchelf-0.14.5 && \
+ ./bootstrap.sh && \
+ ./configure && \
+ source /opt/rh/devtoolset-7/enable && make install && \
+ cd .. && rm -rf patchelf-0.14.5 0.14.5.tar.gz
+RUN pip3 install auditwheel==6.1.0
+
RUN wget "https://xxxxxx/conan_allspark_source_x86_20241119.tar" && \
tar -xvf conan_allspark_source_x86_20241119.tar && \
mv conan_allspark_source_x86_20241119 /root/.conan && \
diff --git a/scripts/docker/test_cuda_ubuntu.Dockerfile b/scripts/docker/test_cuda_ubuntu.Dockerfile
new file mode 100644
index 00000000..ef6bc177
--- /dev/null
+++ b/scripts/docker/test_cuda_ubuntu.Dockerfile
@@ -0,0 +1,45 @@
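+# Test image for DashInfer CUDA packages: CUDA 12.4 devel on Ubuntu 22.04 with a conda-managed Python environment.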
+FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
+
+RUN apt-get update && \
+ apt-get install curl -y
+
+ARG PY_VER=3.10
+
+RUN curl -LO https://repo.anaconda.com/miniconda/Miniconda3-py38_23.11.0-2-Linux-x86_64.sh \
+ && bash Miniconda3-py38_23.11.0-2-Linux-x86_64.sh -p /miniconda -b \
+ && rm -f Miniconda3-py38_23.11.0-2-Linux-x86_64.sh
+ENV PATH=/miniconda/bin:${PATH}
+
+##########################################################################
+# anaconda mirror configuration (remove this block if you do not want to use a mirror)
+##########################################################################
+RUN echo -e "\
+channels:\n\
+ - defaults\n\
+show_channel_urls: true\n\
+default_channels:\n\
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main\n\
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r\n\
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2\n\
+custom_channels:\n\
+ conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\
+ msys2: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\
+ bioconda: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\
+ menpo: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\
+ pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\
+ pytorch-lts: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\
+ simpleitk: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\
+ deepmodeling: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\
+" > /root/.condarc
+
+RUN conda clean -i && conda config --show channels && conda create -y --name test_py python==${PY_VER} && conda update -n base conda
+SHELL ["conda", "run", "-n", "test_py", "/bin/bash", "-c"]
+RUN echo "source activate test_py" >> /root/.bashrc && source /root/.bashrc
+
+##########################################################################
+# pip mirror configuration (remove this block if you do not want to use a mirror)
+##########################################################################
+RUN mkdir -p /root/.pip/
+RUN echo -e "[global]\ntrusted-host=mirrors.aliyun.com\nindex-url = http://mirrors.aliyun.com/pypi/simple\n\n[install]\nuse-wheel=yes" > /root/.pip/pip.conf
+
+WORKDIR /root/
diff --git a/scripts/release/cpp_build_cuda.sh b/scripts/release/cpp_build_cuda.sh
new file mode 100644
index 00000000..3b9232d9
--- /dev/null
+++ b/scripts/release/cpp_build_cuda.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+set -e -x
+
+CUDA_VERSION=$(nvcc --version | grep -oP 'release \K[\d.]+')
+
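+# Link CUDA driver stub libraries locally so the build can link against libcuda/libnvidia-ml on machines without a GPU driver (paths follow the standard CUDA toolkit layout).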
+mkdir -p local_cuda_libs
+ln -sf /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libnvidia-ml.so local_cuda_libs/libnvidia-ml.so.1
+ln -sf /usr/local/cuda-${CUDA_VERSION}/compat/libcuda.so.1 local_cuda_libs/libcuda.so.1
+export LD_LIBRARY_PATH=${PWD}/local_cuda_libs:${LD_LIBRARY_PATH}
+
+bash build.sh
diff --git a/scripts/release/python_manylinux_build.sh b/scripts/release/python_manylinux_build.sh
index 03347c7f..fad21ba9 100755
--- a/scripts/release/python_manylinux_build.sh
+++ b/scripts/release/python_manylinux_build.sh
@@ -15,13 +15,17 @@ pushd $SCRIPT_DIR
 # Capture the output of the arch command
architecture=$(arch)
+export AS_PYTHON_PKG_NAME="dashinfer-cpu"
+
 # Use an if-else structure for the conditional branch
if [ "${architecture}" == "aarch64" ]; then
export PLAT=manylinux_2_28_aarch64
export AS_PLATFORM=armclang
+ # export ENABLE_MULTINUMA="ON"
else
export PLAT=manylinux2014_x86_64
export AS_PLATFORM=x86
+ # export ENABLE_MULTINUMA="ON"
fi
if [ -z "$PLAT" ] || [ -z "$AS_PLATFORM" ];
@@ -30,8 +34,6 @@ then
exit 1
fi
-export AS_PYTHON_MANYLINUX=ON
-
function repair_wheel {
wheel="$1"
if ! auditwheel show "$wheel"; then
@@ -57,8 +59,9 @@ build_wheel_for_python() {
conda install pybind11 -y
pip install -r ${REPO_ROOT}/python/requirements_dev_cpu.txt -i https://mirrors.aliyun.com/pypi/simple/
- python ${REPO_ROOT}/python/setup.py bdist_wheel
- pip wheel ${REPO_ROOT}/python --no-deps -w ${REPO_ROOT}/python/wheelhouse/ --log wheel_log.txt
+ ln -sf ${REPO_ROOT}/python/dashinfer .
+ # python ${REPO_ROOT}/python/setup.py bdist_wheel
+ pip wheel ${REPO_ROOT}/python --no-deps -w ${REPO_ROOT}/python/wheelhouse/ --verbose
conda deactivate
# conda remove --name "$env_name" --all -y
@@ -69,7 +72,7 @@ build_wheel_for_python() {
mkdir -p ${REPO_ROOT}/python/wheelhouse/
for python_version in $BUILD_VERSION; do
- build_wheel_for_python ${python_version} 2>&1 | tee whl_build_log_py${python_version//.}.txt
+ build_wheel_for_python ${python_version} 2>&1 | tee wheel_build_log_py${python_version//.}.txt
done
diff --git a/scripts/release/python_manylinux_build_cuda.sh b/scripts/release/python_manylinux_build_cuda.sh
index e5b9d69a..ee65d190 100755
--- a/scripts/release/python_manylinux_build_cuda.sh
+++ b/scripts/release/python_manylinux_build_cuda.sh
@@ -1,9 +1,9 @@
#!/bin/bash
set -e -x
-# ALL_VERSION="3.8 3.9 3.10 3.11"
-ALL_VERSION="3.10"
+ALL_VERSION="3.8 3.9 3.10 3.11"
BUILD_VERSION=${@:-$ALL_VERSION}
+CUDA_VERSION=$(nvcc --version | grep -oP 'release \K[\d.]+')
echo " going to build python wheels with version: ${BUILD_VERSION}"
@@ -19,14 +19,17 @@ architecture=$(arch)
export PLAT=manylinux2014_x86_64
export AS_PLATFORM=cuda
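+# Use the CUDA driver stub libraries so wheel builds can link without a GPU driver installed (assumes the standard CUDA toolkit layout).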
+mkdir -p local_cuda_libs
+ln -sf /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libnvidia-ml.so local_cuda_libs/libnvidia-ml.so.1
+ln -sf /usr/local/cuda-${CUDA_VERSION}/compat/libcuda.so.1 local_cuda_libs/libcuda.so.1
+export LD_LIBRARY_PATH=${PWD}/local_cuda_libs:${LD_LIBRARY_PATH}
+
if [ -z "$PLAT" ] || [ -z "$AS_PLATFORM" ];
then
echo " please set PLAT and AS_PLATFORM env, PLAT can be manylinux_2_28_aarch64 or manylinux2014_x86_64"
exit 1
fi
-export AS_PYTHON_MANYLINUX=ON
-
function repair_wheel {
wheel="$1"
if ! auditwheel show "$wheel"; then
@@ -52,8 +55,9 @@ build_wheel_for_python() {
conda install pybind11 -y
pip install -r ${REPO_ROOT}/python/requirements_dev_cuda.txt -i https://mirrors.aliyun.com/pypi/simple/
- python ${REPO_ROOT}/python/setup.py bdist_wheel
- pip wheel ${REPO_ROOT}/python --no-deps -w ${REPO_ROOT}/python/wheelhouse/ --log wheel_log.txt
+ ln -sf ${REPO_ROOT}/python/dashinfer .
+ # python ${REPO_ROOT}/python/setup.py bdist_wheel
+ pip wheel ${REPO_ROOT}/python --no-deps -w ${REPO_ROOT}/python/wheelhouse/ --verbose
conda deactivate
# conda remove --name "$env_name" --all -y
@@ -64,7 +68,7 @@ build_wheel_for_python() {
mkdir -p ${REPO_ROOT}/python/wheelhouse/
for python_version in $BUILD_VERSION; do
- build_wheel_for_python ${python_version} 2>&1 | tee whl_build_log_py${python_version//.}.txt
+ build_wheel_for_python ${python_version} 2>&1 | tee wheel_build_log_py${python_version//.}.txt
done
diff --git a/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_common.inl b/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_common.inl
new file mode 100644
index 00000000..526db83e
--- /dev/null
+++ b/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_common.inl
@@ -0,0 +1,96 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include "cutlass/layout/tensor.h"
+#include "cutlass/arch/mma.h"
+#include "cutlass/conv/convolution.h"
+#include "cutlass/conv/dispatch_policy.hpp"
+#include "cutlass/detail/layout.hpp"
+#include "cutlass/gemm/collective/builders/sm90_common.inl"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::conv::collective::detail {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Maps a rank-1 cute::Shape<> representing the cluster shape on to the IM2COL TMA atom that should be used with it
+template
+constexpr auto
+sm90_cluster_shape_to_im2col_tma_atom(UnimodalClusterShape unimodal_cluster_shape) {
+ static_assert(cute::rank(unimodal_cluster_shape) == 1,
+ "Use this function to figure out TMA for each mode individually.");
+
+ if constexpr (cute::size(unimodal_cluster_shape) == 1) {
+ return cute::SM90_TMA_LOAD_IM2COL{};
+ }
+ else {
+ return cute::SM90_TMA_LOAD_IM2COL_MULTICAST{};
+ }
+}
+
+// Collective tile traits struct that serves as a type list containing a tensor's mem layouts and atoms for the
+template<
+ class GmemTiledCopy_,
+ class SmemLayout_,
+ class SmemCopyAtom_ = void
+>
+struct Sm90ImplicitGemmTileTraits {
+ using GmemTiledCopy = GmemTiledCopy_;
+ using SmemLayout = SmemLayout_;
+ using SmemCopyAtom = SmemCopyAtom_;
+};
+
+// Accepts a cutlass::layout::Tensor tag and computes the corresponding spatial dimension count
+template
+constexpr int
+gmem_layout_tags_to_spatial_dims() {
+ static_assert(cute::is_same_v);
+ if constexpr (cute::is_same_v) {
+ return 1;
+ }
+ else if constexpr (cute::is_same_v) {
+ return 2;
+ }
+ else if constexpr (cute::is_same_v) {
+ return 3;
+ }
+ else {
+ static_assert(cutlass::detail::dependent_false);
+ }
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass::conv::collective::detail
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl b/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl
new file mode 100644
index 00000000..a08209ef
--- /dev/null
+++ b/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl
@@ -0,0 +1,257 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include "cutlass/conv/collective/builders/sm90_common.inl"
+
+// SM90 Collective Builders should be used only starting CUDA 12.0
+#if (__CUDACC_VER_MAJOR__ >= 12)
+#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
+#endif
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::conv::collective {
+using namespace cute;
+
+namespace detail {
+
+// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
+template
+constexpr int
+compute_stage_count_or_override(StageCount stage_count) {
+ return stages;
+}
+
+// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
+template
+constexpr int
+compute_stage_count_or_override(cute::Int stage_count) {
+ return stages;
+}
+
+// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
+template
+constexpr int
+compute_stage_count_or_override(StageCountAutoCarveout stage_count) {
+ constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
+ constexpr auto a_bits = cute::sizeof_bits_v;
+ constexpr auto b_bits = cute::sizeof_bits_v;
+ constexpr int stage_bytes =
+ cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
+ cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
+ static_cast(mainloop_pipeline_bytes);
+
+ return (CapacityBytes - carveout_bytes) / stage_bytes;
+}
+
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA_TMA_WS_SS_FPROP
+template <
+ conv::Operator ConvOp,
+ class ElementA,
+ class GmemLayoutA,
+ int AlignmentA,
+ class ElementB,
+ class GmemLayoutB,
+ int AlignmentB,
+ class ElementAccumulator,
+ class TileShape_MNK,
+ class ClusterShape_MNK,
+ class StageCountType,
+ class KernelScheduleType
+>
+struct CollectiveBuilder<
+ arch::Sm90,
+ arch::OpClassTensorOp,
+ ConvOp,
+ ElementA,
+ GmemLayoutA,
+ AlignmentA,
+ ElementB,
+ GmemLayoutB,
+ AlignmentB,
+ ElementAccumulator,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ StageCountType,
+ KernelScheduleType,
+ cute::enable_if_t ||
+ cute::is_same_v ||
+ cute::is_same_v>
+> {
+ static_assert(is_static::value);
+ static_assert(is_static::value);
+#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
+ static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n");
+#endif
+ static_assert(cutlass::gemm::collective::detail::is_aligned(),
+ "Should meet TMA alignment requirement\n");
+
+ // For fp32 types, map to tf32 MMA value type
+ using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>;
+ using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>;
+
+ // For fprop, majorA = K, major B = K;
+ // For wgrad, majorA = MN, major B = MN;
+ // For dgrad, majorA = K, major B = MN;
+ static constexpr cute::GMMA::Major GmmaMajorA =
+ (ConvOp == conv::Operator::kWgrad) ? cute::GMMA::Major::MN : cute::GMMA::Major::K;
+ static constexpr cute::GMMA::Major GmmaMajorB =
+ (ConvOp == conv::Operator::kFprop) ? cute::GMMA::Major::K : cute::GMMA::Major::MN;
+
+ using AtomLayoutMNK = cute::conditional_t,
+ Layout>, Layout>>;
+
+ using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
+ ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
+
+ // For wgrad kernel, tensor A uses tma tiled mode and tensor B uses tma im2col mode.
+ using GmemTiledCopyA = cute::conditional_t(ClusterShape_MNK{}))),
+ decltype(cutlass::conv::collective::detail::sm90_cluster_shape_to_im2col_tma_atom(cute::shape<1>(ClusterShape_MNK{})))>;
+ using GmemTiledCopyB = cute::conditional_t(ClusterShape_MNK{}))),
+ decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(cute::shape<0>(ClusterShape_MNK{})))>;
+
+ using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
+ GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
+ using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
+ GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
+
+ static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{});
+
+ using SmemLayoutA = decltype(tile_to_shape(
+ SmemLayoutAtomA{},
+ make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}),
+ Step<_2,_1,_3>{}));
+ using SmemLayoutB = decltype(tile_to_shape(
+ SmemLayoutAtomB{},
+ make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}),
+ Step<_2,_1,_3>{}));
+
+ constexpr static int NumSpatialDimensions = cutlass::conv::collective::detail::gmem_layout_tags_to_spatial_dims();
+
+ using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedImplicitGemm<
+ ConvOp, PipelineStages, NumSpatialDimensions, ClusterShape_MNK, KernelScheduleType>;
+
+ using CollectiveOp = CollectiveConv<
+ DispatchPolicy,
+ TileShape_MNK,
+ ElementA,
+ ElementB,
+ TiledMma,
+ detail::Sm90ImplicitGemmTileTraits,
+ detail::Sm90ImplicitGemmTileTraits
+ >;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA auto kernel schedule
+template <
+ conv::Operator ConvOp,
+ class ElementA,
+ class GmemLayoutA,
+ int AlignmentA,
+ class ElementB,
+ class GmemLayoutB,
+ int AlignmentB,
+ class ElementAccumulator,
+ class TileShape_MNK,
+ class ClusterShape_MNK,
+ class StageCountType,
+ class KernelScheduleType
+>
+struct CollectiveBuilder<
+ arch::Sm90,
+ arch::OpClassTensorOp,
+ ConvOp,
+ ElementA,
+ GmemLayoutA,
+ AlignmentA,
+ ElementB,
+ GmemLayoutB,
+ AlignmentB,
+ ElementAccumulator,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ StageCountType,
+ KernelScheduleType,
+ cute::enable_if_t>
+> {
+ static_assert(is_static::value);
+ static_assert(is_static::value);
+#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
+ static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n");
+#endif
+
+/*
+#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1)))
+ // Cooperative schedule performs best for CUDA Toolkits with version >= 12.1
+
+ // For TileShape_M == 64, choosing KernelTmaWarpSpecialized as the KernelSchedule
+ // Since KernelTmaWarpSpecializedCooperative requires TileShape_M to be at least 128
+ using KernelWarpSpecializedSchedule = cute::conditional_t(TileShape_MNK{}) == Int<64>{},
+ KernelImplicitTmaWarpSpecializedSm90PingPong, KernelImplicitTmaWarpSpecializedSm90Cooperative>;
+#else
+ using KernelWarpSpecializedSchedule = KernelImplicitTmaWarpSpecializedSm90;
+#endif
+*/
+ using KernelWarpSpecializedSchedule = KernelImplicitTmaWarpSpecializedSm90;
+
+ using CollectiveOp = typename CollectiveBuilder<
+ arch::Sm90,
+ arch::OpClassTensorOp,
+ ConvOp,
+ ElementA,
+ GmemLayoutA,
+ AlignmentA,
+ ElementB,
+ GmemLayoutB,
+ AlignmentB,
+ ElementAccumulator,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ StageCountType,
+ KernelWarpSpecializedSchedule
+ >::CollectiveOp;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass::conv::collective
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl
new file mode 100644
index 00000000..2ca62c97
--- /dev/null
+++ b/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl
@@ -0,0 +1,797 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include "cute/atom/mma_traits_sm90.hpp"
+#include "cute/atom/mma_traits_sm90_gmma.hpp"
+#include "cute/atom/copy_traits_sm90.hpp"
+
+#include "cutlass/detail/dependent_false.hpp"
+#include "cutlass/detail/layout.hpp"
+#include "cutlass/gemm/collective/builders/sm90_common.inl"
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_epilogue.hpp"
+#include "cutlass/epilogue/collective/builders/sm90_common.inl"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/epilogue/thread/linear_combination_generic.h"
+#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
+#include "cutlass/epilogue/fusion/callbacks.hpp"
+#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
+
+#if defined(__CUDACC_RTC__)
+#include
+#else
+#include
+#endif
+
+///////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::epilogue::collective {
+
+///////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+// Returns the parameterized dispatch policy for the TMA epilogue
+template
+constexpr auto
+sm90_get_tma_dispatch_policy() {
+ using namespace cute;
+
+ constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{}));
+ constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v ? 256 : 128);
+ // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation
+ constexpr bool ReuseSmem = (sizeof_bits_v == sizeof_bits_v) && (sizeof_bits_v > 8);
+ // TMA store delay performs worse with residual loads and complicates tensormap updates for Ptr-Array GEMMs
+ constexpr bool DelayTmaStore = is_void_v && !detail::sm90_is_tma_ptr_array_v;
+ constexpr int StagesD = cute::min(EpiTiles, 2);
+ constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1)
+ : cute::min(EpiTiles, 4);
+
+ return cute::conditional_t,
+ Sm90PtrArrayTmaWarpSpecialized,
+ Sm90TmaWarpSpecialized>{};
+}
+
+// Returns the smem layout atom to be used for C or D matrix
+template
+constexpr auto
+sm90_get_epilogue_smem_swizzle_layout_atom() {
+ using namespace cute;
+
+ // ColMajor C/D (M-major)
+ if constexpr (cutlass::gemm::detail::is_major<0>(GmemStrideType{})) {
+ return cutlass::gemm::collective::detail::ss_smem_selector<
+ cute::GMMA::Major::MN, Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{}))
+ >();
+ }
+ // RowMajor C/D (N-major)
+ else if constexpr (cutlass::gemm::detail::is_major<1>(GmemStrideType{})) {
+ return cutlass::gemm::collective::detail::ss_smem_selector<
+ cute::GMMA::Major::K , Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{}))
+ >();
+ }
+ else {
+ static_assert(cutlass::detail::dependent_false, "Unsupported gmem layout.");
+ }
+}
+
+// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one.
+template
+constexpr auto
+sm90_compute_tile_shape_or_override() {
+ if constexpr (cute::is_same_v) {
+ auto epi_tile = [&] () {
+ if constexpr (detail::sm90_is_cooperative_v) {
+ auto tile_m = cute::min(_128{}, size<0>(TileShape_MNK{}));
+ auto tile_n = cute::min(_32{}, size<1>(TileShape_MNK{}));
+ return make_shape(tile_m, tile_n);
+ }
+ else if constexpr (detail::sm90_is_warp_specialized_v) {
+ constexpr int N_perf = sizeof_bits_v == 8 ? 64 : 32;
+ auto tile_m = cute::min(_64{}, size<0>(TileShape_MNK{}));
+ auto tile_n = cute::min(Int{}, size<1>(TileShape_MNK{}));
+ return make_shape(tile_m, tile_n);
+ }
+ else {
+ static_assert(cutlass::detail::dependent_false, "Unsupported schedule.");
+ }
+ }();
+
+ return cute::transform(epi_tile, seq<0,1>{},
+ [] (auto epi_tiler, auto I) {
+ auto cta_tiler = make_layout(get(TileShape_MNK{}));
+ // This is a multimodal CTA tiler, transform before returning
+ if constexpr (depth(cta_tiler) > 0) {
+ // This is an implicit multimodal tiler, match profile and return
+ if constexpr (tuple_size_v == 1) {
+ return make_tile(epi_tiler);
+ }
+ // This is an explicit multimodal tiler, compose out epi tiler
+ else {
+ return composition(cta_tiler, epi_tiler);
+ }
+ }
+ // This is a flat CTA tiler, no need for transformation
+ else {
+ return epi_tiler;
+ }
+ });
+ }
+ else if constexpr (cute::is_tuple::value) {
+ EpilogueTileType epi_tile;
+ constexpr int M = size<0>(shape(epi_tile));
+ constexpr int N = size<1>(shape(epi_tile));
+
+ static_assert(!is_layout::value, "EpilogueTile must be a cute::Tile or cute::Shape");
+ static_assert(M == 64 && detail::sm90_is_warp_specialized_v ||
+ M == 128 && detail::sm90_is_cooperative_v, "Unsupported tile shape");
+ static_assert(N % 16 == 0, "Unsupported tile shape");
+
+ return epi_tile;
+ }
+ else {
+ static_assert(cutlass::detail::dependent_false, "Invalid type for EpilogueTileType.");
+ }
+}
+
+// callbacks builder with TMA aux out
+template <
+ int StagesC,
+ int StagesD,
+ int FragmentSize,
+ bool ReuseSmemC,
+ bool DelayTmaStore,
+ class FusionOp,
+ class TileShape_MNK,
+ class EpilogueTile_MN,
+ class ElementAccumulator
+>
+struct CallbacksBuilder<
+ Sm90TmaWarpSpecialized,
+ FusionOp,
+ TileShape_MNK,
+ EpilogueTile_MN,
+ ElementAccumulator,
+ cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor
+ && not cute::is_subbyte_v>
+> {
+ using GmemStrideTypeAux = gemm::TagToStrideC_t;
+ using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<
+ GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>());
+ using CopyOpR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator<
+ GmemStrideTypeAux, typename FusionOp::ElementAux>());
+ using CopyOpS2R = decltype(detail::sm90_get_smem_load_op_for_source<
+ GmemStrideTypeAux, typename FusionOp::ElementAux>());
+ using SmemCopyOpAux = cute::conditional_t;
+
+ using Callbacks = fusion::FusionCallbacks<
+ Sm90TmaWarpSpecialized,
+ FusionOp, TileShape_MNK, EpilogueTile_MN,
+ SmemLayoutAtomAux, SmemCopyOpAux
+ >;
+};
+
+template <
+ int StagesC,
+ int StagesD,
+ int FragmentSize,
+ bool ReuseSmemC,
+ bool DelayTmaStore,
+ class FusionOp,
+ class TileShape_MNK,
+ class EpilogueTile_MN,
+ class ElementAccumulator
+>
+struct CallbacksBuilder<
+ Sm90TmaWarpSpecialized,
+ FusionOp,
+ TileShape_MNK,
+ EpilogueTile_MN,
+ ElementAccumulator,
+ cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor
+ && sizeof_bits_v == 1>
+> {
+ using Callbacks = fusion::FusionCallbacks<
+ Sm90TmaWarpSpecialized,
+ FusionOp, TileShape_MNK, EpilogueTile_MN,
+ Layout<_1,_0>, DefaultCopy // aux bit tensor doesn't use smem
+ >;
+};
+
+// Helper for building TMA warp-specialized collective epilogues, specialized by
+// the fusion operation performed and the dispatch policy to use.
+template <
+ class TileShape_MNK,
+ class EpilogueTile_MN,
+ class ElementAccumulator,
+ class ElementCompute,
+ class ElementC_,
+ class GmemLayoutTagC_,
+ int AlignmentC,
+ class ElementD_,
+ class GmemLayoutTagD,
+ int AlignmentD,
+ class FusionOpOrCallbacks,
+ class DispatchPolicy
+>
+struct Sm90TmaBuilderImpl {
+ // Passing void D disables destination store + smem allocation
+ using ElementD = cute::conditional_t,
+ fusion::get_element_aux_t, ElementD_>;
+
+ // Passing void C disables source load + smem allocation
+ using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages
+ using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>;
+
+ using GmemStrideTypeC = cutlass::detail::TagToStrideC_t;
+ using GmemStrideTypeD = cutlass::detail::TagToStrideC_t;
+
+ using CopyOpS2G = cute::conditional_t,
+ SM90_TMA_STORE_IM2COL,
+ SM90_TMA_STORE
+ >;
+ using CopyOpG2S = cute::conditional_t,
+ SM90_TMA_LOAD_IM2COL,
+ SM90_TMA_LOAD
+ >;
+
+ // Get the smallest tiled copy we can use to retile the accumulators
+ using CopyAtomC = Copy_Atom;
+
+ using FusionDispatchPolicy = Sm90TmaWarpSpecialized;
+
+ // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks
+ // instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination
+ using FusionCallbacks =
+ typename CallbacksBuilder<
+ FusionDispatchPolicy,
+ FusionOpOrCallbacks,
+ TileShape_MNK,
+ EpilogueTile_MN,
+ ElementAccumulator
+ >::Callbacks;
+
+ using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue<
+ DispatchPolicy,
+ TileShape_MNK,
+ EpilogueTile_MN,
+ ElementC_, // Need to pass void through to expose via GemmUniversal
+ GmemStrideTypeC,
+ ElementD_,
+ GmemStrideTypeD,
+ FusionCallbacks,
+ CopyOpG2S,
+ decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()),
+ decltype(detail::sm90_get_smem_load_op_for_source()),
+ CopyOpS2G,
+ decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()),
+ decltype(detail::sm90_get_smem_store_op_for_accumulator()),
+ CopyAtomC
+ >;
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// Descriptor classes for defining EVT nodes
+// Some of the epilogue visitor nodes require non-intuitive template arguments
+// such as CopyOpS2R for AuxLoad node. Traditionally, these are resolved by the
+// builder classes. Here we provide a set of descriptor classes that resolve
+// these template arguments from more intuitive types such as Stride, Layout
+
+// Get TileShape, EpilogueTile, Dispatch Policy, StagesC, and STagesD
+template<
+ typename TileShape_MNK,
+ typename EpilogueTileType,
+ typename ElementC,
+ typename ElementD,
+ typename Schedule
+>
+struct EpilogueDescriptor {
+ using TileShape = TileShape_MNK;
+ using EpilogueTile =
+ decltype(
+ detail::sm90_compute_tile_shape_or_override<
+ ElementD, EpilogueTileType, Schedule, TileShape_MNK
+ >()
+ );
+ using DispatchPolicy =
+ decltype(
+ detail::sm90_get_tma_dispatch_policy<
+ TileShape_MNK, EpilogueTile,
+ ElementC, ElementD, Schedule
+ >()
+ );
+ constexpr static int StagesC = DispatchPolicy::StagesC;
+ constexpr static int StagesD = DispatchPolicy::StagesD;
+};
+
+// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node
+template<
+ typename EpilogueDescriptor,
+ typename StrideOrLayoutTag,
+ typename ElementAux
+>
+struct AuxLoadDescriptor {
+ constexpr static int Stages = EpilogueDescriptor::StagesC;
+ using EpilogueTile = typename EpilogueDescriptor::EpilogueTile;
+ using Element = ElementAux;
+ using Stride = cutlass::detail::TagToStrideC_t;
+ using SmemLayoutAtom =
+ decltype(
+ detail::sm90_get_epilogue_smem_swizzle_layout_atom<
+ Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile
+ >()
+ );
+ using CopyOpS2R =
+ decltype(detail::sm90_get_smem_load_op_for_source());
+};
+
+// Get Stride, SmemLayout, and CopyOpS2R for AuxStore node
+template<
+ typename EpilogueDescriptor,
+ typename StrideOrLayoutTag,
+ typename ElementAux
+>
+struct AuxStoreDescriptor {
+ constexpr static int Stages = EpilogueDescriptor::StagesD;
+ using EpilogueTile = typename EpilogueDescriptor::EpilogueTile;
+ using Element = ElementAux;
+ using Stride = cutlass::detail::TagToStrideC_t;
+ using SmemLayoutAtom =
+ decltype(
+ detail::sm90_get_epilogue_smem_swizzle_layout_atom<
+ Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile
+ >()
+ );
+ using CopyOpR2S =
+ decltype(detail::sm90_get_smem_store_op_for_accumulator());
+};
+
+} // namespace detail
+
+///////////////////////////////////////////////////////////////////////////////
+
+// No-smem builder
+template <
+ class TileShape_MNK,
+ class ClusterShape_MNK,
+ class EpilogueTileType,
+ class ElementAccumulator,
+ class ElementCompute,
+ class ElementC_,
+ class GmemLayoutTagC_,
+ int AlignmentC,
+ class ElementD,
+ class GmemLayoutTagD,
+ int AlignmentD,
+ class Schedule,
+ FloatRoundStyle RoundStyle
+>
+struct CollectiveBuilder<
+ arch::Sm90,
+ arch::OpClassTensorOp,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ EpilogueTileType,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC_,
+ GmemLayoutTagC_,
+ AlignmentC,
+ ElementD,
+ GmemLayoutTagD,
+ AlignmentD,
+ Schedule,
+ fusion::LinearCombination,
+ cute::enable_if_t ||
+ cute::is_same_v >> {
+
+ // Passing void C disables source load
+ using ElementC = cute::conditional_t,
+ ElementD, ElementC_>; // prevents cute breakages
+ using GmemLayoutTagC = cute::conditional_t,
+ GmemLayoutTagD, GmemLayoutTagC_>;
+ static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ?
+ thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
+
+ static constexpr int FragmentSize = 1;
+ using ThreadOp = thread::LinearCombination<
+ ElementD, FragmentSize, ElementAccumulator, ElementCompute,
+ ScaleType, RoundStyle, ElementC>;
+
+ using CollectiveOp = cute::conditional_t<
+ cute::is_same_v,
+ cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
+ cutlass::epilogue::collective::DefaultEpilogue<
+ cutlass::detail::TagToStrideC_t,
+ cutlass::detail::TagToStrideC_t,
+ ThreadOp,
+ cutlass::gemm::EpilogueDefault>>,
+ // Epilogue for Ptr-Array and Grouped Gemm
+ cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
+ cutlass::epilogue::collective::DefaultEpilogueArray<
+ cutlass::detail::TagToStrideC_t,
+ cutlass::detail::TagToStrideC_t,
+ ThreadOp,
+ Schedule>>
+ >;
+};
+
+// Tma warp-specialized builder
+template <
+ class TileShape_MNK,
+ class ClusterShape_MNK,
+ class EpilogueTileType,
+ class ElementAccumulator,
+ class ElementCompute,
+ class ElementC,
+ class GmemLayoutTagC,
+ int AlignmentC,
+ class ElementD_,
+ class GmemLayoutTagD,
+ int AlignmentD,
+ class Schedule,
+ class FusionOperation
+>
+struct CollectiveBuilder<
+ arch::Sm90,
+ arch::OpClassTensorOp,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ EpilogueTileType,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC,
+ GmemLayoutTagC,
+ AlignmentC,
+ ElementD_,
+ GmemLayoutTagD,
+ AlignmentD,
+ Schedule,
+ FusionOperation,
+ cute::enable_if_t ||
+ cute::is_same_v ||
+ cute::is_same_v >> {
+private:
+ using ElementD = cute::conditional_t,
+ fusion::get_element_aux_t, ElementD_>;
+ using EpilogueTile_MN =
+ decltype(detail::sm90_compute_tile_shape_or_override());
+ using DispatchPolicy =
+ decltype(detail::sm90_get_tma_dispatch_policy());
+
+public:
+ using CollectiveOp =
+ typename detail::Sm90TmaBuilderImpl<
+ TileShape_MNK,
+ EpilogueTile_MN,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC,
+ GmemLayoutTagC,
+ AlignmentC,
+ ElementD_,
+ GmemLayoutTagD,
+ AlignmentD,
+ FusionOperation,
+ DispatchPolicy
+ >::CollectiveOp;
+};
+
+// Auto builder
+template <
+ class TileShape_MNK,
+ class ClusterShape_MNK,
+ class EpilogueTileType,
+ class ElementAccumulator,
+ class ElementCompute,
+ class ElementC,
+ class GmemLayoutTagC,
+ int AlignmentC,
+ class ElementD,
+ class GmemLayoutTagD,
+ int AlignmentD,
+ class FusionOperation
+>
+struct CollectiveBuilder<
+ arch::Sm90,
+ arch::OpClassTensorOp,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ EpilogueTileType,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC,
+ GmemLayoutTagC,
+ AlignmentC,
+ ElementD,
+ GmemLayoutTagD,
+ AlignmentD,
+ EpilogueScheduleAuto,
+ FusionOperation,
+ void> {
+private:
+ static_assert(cute::is_same_v>,
+ "Auto schedule doesn't support fusion. Use one of the TmaWarpSpecialized schedules instead.");
+
+ // Pick No-Smem epilogue as the Auto Epilogue Schedule (Auto schedules do not guarantee best performance)
+ // since TMA epilogues are not compatible with non-TMA non-WS mainloops
+ using EpilogueSchedule = NoSmemWarpSpecialized;
+ using _CollectiveBuilder = CollectiveBuilder<
+ arch::Sm90,
+ arch::OpClassTensorOp,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ EpilogueTileType,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC,
+ GmemLayoutTagC,
+ AlignmentC,
+ ElementD,
+ GmemLayoutTagD,
+ AlignmentD,
+ EpilogueSchedule,
+ FusionOperation
+ >;
+
+public:
+ using CollectiveOp = typename _CollectiveBuilder::CollectiveOp;
+};
+
+// DEPRECATED Tma warp-specialized builder for elementwise fusion
+template <
+ class TileShape_MNK,
+ class ClusterShape_MNK,
+ class EpilogueTileType,
+ class ElementAccumulator,
+ class ElementCompute,
+ class ElementC,
+ class GmemLayoutTagC,
+ int AlignmentC,
+ class ElementD,
+ class GmemLayoutTagD,
+ int AlignmentD,
+ class Schedule,
+ class UnusedFusionOp
+>
+struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]]
+CollectiveBuilder<
+ arch::Sm90,
+ arch::OpClassTensorOp,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ EpilogueTileType,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC,
+ GmemLayoutTagC,
+ AlignmentC,
+ ElementD,
+ GmemLayoutTagD,
+ AlignmentD,
+ Schedule,
+ UnusedFusionOp,
+ cute::enable_if_t ||
+ cute::is_base_of_v >> {
+private:
+ using FusionOp =
+ fusion::LinCombEltAct;
+ using ImplSchedule =
+ cute::conditional_t,
+ TmaWarpSpecialized, TmaWarpSpecializedCooperative>;
+
+public:
+ using CollectiveOp =
+ typename CollectiveBuilder<
+ arch::Sm90,
+ arch::OpClassTensorOp,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ EpilogueTileType,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC,
+ GmemLayoutTagC,
+ AlignmentC,
+ ElementD,
+ GmemLayoutTagD,
+ AlignmentD,
+ ImplSchedule,
+ FusionOp
+ >::CollectiveOp;
+};
+
+// DEPRECATED Tma warp-specialized builder for bias + elementwise fusion
+template <
+ class TileShape_MNK,
+ class ClusterShape_MNK,
+ class EpilogueTileType,
+ class ElementAccumulator,
+ class ElementCompute,
+ class ElementC_,
+ class GmemLayoutTagC_,
+ int AlignmentC,
+ class ElementD,
+ class GmemLayoutTagD,
+ int AlignmentD,
+ class Schedule,
+ class UnusedFusionOp
+>
+struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltAct or fusion::LinCombPerRowBiasEltActAux instead")]]
+CollectiveBuilder<
+ arch::Sm90,
+ arch::OpClassTensorOp,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ EpilogueTileType,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC_,
+ GmemLayoutTagC_,
+ AlignmentC,
+ ElementD,
+ GmemLayoutTagD,
+ AlignmentD,
+ Schedule,
+ UnusedFusionOp,
+ cute::enable_if_t ||
+ cute::is_base_of_v >> {
+private:
+ using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override<
+ ElementD, EpilogueTileType, Schedule, TileShape_MNK>());
+ // MSVC doesn't seem to be able to deduce DispatchPolicy correctly if it's
+ // defined as decltype of a detail::sm90_get_tma_dispatch_policy call.
+ // Instead, we paste in the contents of that function. A natural refactoring
+ // would be to create a type alias in the detail namespace.
+ using DispatchPolicy = Sm90TmaWarpSpecialized<
+ /* StagesC = */ size(shape_div(take<0, 2>(TileShape_MNK{}), EpilogueTile_MN{})),
+ /* StagesD = */ 2,
+ /* FragmentSize = */ size(EpilogueTile_MN{}) / (detail::sm90_is_cooperative_v ? 256 : 128),
+ /* ReuseSmemC = */ sizeof_bits_v == sizeof_bits_v,
+ false
+ >;
+
+ using GmemStrideTypeAux = gemm::TagToStrideC_t;
+ using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<
+ GmemStrideTypeAux, typename Schedule::ElementT, EpilogueTile_MN>());
+ using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator<
+ GmemStrideTypeAux, typename Schedule::ElementT>());
+ using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux<
+ GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute,
+ typename Schedule::ElementT, typename Schedule::ElementBias, ElementC_, ElementCompute
+ >;
+ using FusionCallbacksAux = fusion::FusionCallbacks<
+ DispatchPolicy, FusionOperationAux, TileShape_MNK, EpilogueTile_MN, SmemLayoutAtomAux, SmemCopyOpAux
+ >;
+
+ using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct<
+ Schedule::template ActivationFunctor, ElementD, ElementCompute,
+ typename Schedule::ElementBias, ElementC_, ElementCompute
+ >;
+ using FusionCallbacksNoAux = fusion::FusionCallbacks<
+ DispatchPolicy, FusionOperationNoAux, TileShape_MNK, EpilogueTile_MN
+ >;
+
+ using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages
+ using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>;
+
+ using GmemStrideTypeC = gemm::TagToStrideC_t;
+ using GmemStrideTypeD = gemm::TagToStrideC_t;
+
+ // Get the smallest tiled copy we can use to retile the accumulators
+ using CopyAtomC = Copy_Atom;
+
+public:
+ using CollectiveOp = cutlass::epilogue::collective::Sm90EpilogueTmaWarpSpecializedBiasElementwise<
+ DispatchPolicy::StagesC,
+ DispatchPolicy::StagesD,
+ DispatchPolicy::FragmentSize,
+ TileShape_MNK,
+ EpilogueTile_MN,
+ ElementC_, // Need to pass void through to expose via GemmUniversal
+ GmemStrideTypeC,
+ ElementD,
+ GmemStrideTypeD,
+ cute::conditional_t,
+ SM90_TMA_LOAD,
+ decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()),
+ decltype(detail::sm90_get_smem_load_op_for_source()),
+ SM90_TMA_STORE,
+ decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()),
+ decltype(detail::sm90_get_smem_store_op_for_accumulator()),
+ CopyAtomC
+ >;
+};
+
+// The transposed-epilogue CollectiveBuilder below is used for SM90 GMMA RS TT kernels:
+// by swapping the input matrices of an NNN kernel and transposing its output at the same time,
+// we obtain a TTN kernel.
+template <
+ class TileShape_MNK,
+ class ClusterShape_MNK,
+ class EpilogueTileType,
+ class ElementAccumulator,
+ class ElementCompute,
+ class ElementC_,
+ class GmemLayoutTagC_,
+ int AlignmentC,
+ class ElementD,
+ class GmemLayoutTagD,
+ int AlignmentD,
+ FloatRoundStyle RoundStyle
+>
+struct CollectiveBuilder<
+ arch::Sm90,
+ arch::OpClassTensorOp,
+ TileShape_MNK,
+ ClusterShape_MNK,
+ EpilogueTileType,
+ ElementAccumulator,
+ ElementCompute,
+ ElementC_,
+ GmemLayoutTagC_,
+ AlignmentC,
+ ElementD,
+ GmemLayoutTagD,
+ AlignmentD,
+ cutlass::gemm::EpilogueTransposed,
+ fusion::LinearCombination,
+ void> {
+ // Passing void C disables source load
+ using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,
+ ElementD, ElementC_>; // prevents cute breakages
+ using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,
+ GmemLayoutTagD, GmemLayoutTagC_>;
+ static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v<ElementC_> ?
+ thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
+
+ static constexpr int FragmentSize = 1;
+ using ThreadOp = thread::LinearCombination<
+ ElementD, FragmentSize, ElementAccumulator, ElementCompute,
+ ScaleType, RoundStyle, ElementC>;
+
+ using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
+ cutlass::epilogue::collective::DefaultEpilogue<
+ cutlass::detail::TagToStrideC_t<GmemLayoutTagC>,
+ cutlass::detail::TagToStrideC_t<GmemLayoutTagD>,
+ ThreadOp,
+ cutlass::gemm::EpilogueTransposed>
+ >;
+};
+
+///////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass::epilogue::collective
diff --git a/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_common.inl b/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_common.inl
new file mode 100644
index 00000000..cd2639c5
--- /dev/null
+++ b/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_common.inl
@@ -0,0 +1,80 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+///////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::epilogue::collective::detail {
+
+///////////////////////////////////////////////////////////////////////////////
+
+// Selects the largest vectorized smem store atom available
+template <class GmemStrideTypeD, class ElementD>
+constexpr auto
+sm90_get_smem_store_op_for_accumulator() {
+ using namespace cute;
+
+ if constexpr (sizeof(ElementD) == 2 && size<0>(GmemStrideTypeD{}) == 1) {
+ return SM90_U16x8_STSM_T{};
+ }
+ else if constexpr (sizeof(ElementD) == 2 && size<1>(GmemStrideTypeD{}) == 1) {
+ return SM90_U32x4_STSM_N{};
+ }
+ else {
+ // auto-vectorizing store
+ return AutoVectorizingCopyWithAssumedAlignment<128>{};
+ }
+}
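+// For example, a 16-bit ElementD with an M-major (column-major) stride selects the transposed
+// SM90_U16x8_STSM_T stmatrix atom, an N-major (row-major) stride selects SM90_U32x4_STSM_N, and
+// every other case falls back to the auto-vectorizing copy.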
+
+// Selects the largest vectorized smem load atom available
+template <class GmemStrideTypeC, class ElementC>
+constexpr auto
+sm90_get_smem_load_op_for_source() {
+ using namespace cute;
+
+ // Reuse the logic from smem store selector
+ using SmemStoreOp = decltype(sm90_get_smem_store_op_for_accumulator<GmemStrideTypeC, ElementC>());
+
+ if constexpr (cute::is_same_v<SmemStoreOp, SM90_U16x8_STSM_T>) {
+ return SM75_U16x8_LDSM_T{};
+ }
+ else if constexpr (cute::is_same_v<SmemStoreOp, SM90_U32x4_STSM_N>) {
+ return SM75_U32x4_LDSM_N{};
+ }
+ else {
+ // auto-vectorizing load
+ return AutoVectorizingCopyWithAssumedAlignment<128>{};
+ }
+}
+
+///////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass::epilogue::collective::detail
diff --git a/span-attention/thirdparty/cutlass/include/cutlass/gemm/collective/builders/sm90_common.inl b/span-attention/thirdparty/cutlass/include/cutlass/gemm/collective/builders/sm90_common.inl
new file mode 100644
index 00000000..298793e8
--- /dev/null
+++ b/span-attention/thirdparty/cutlass/include/cutlass/gemm/collective/builders/sm90_common.inl
@@ -0,0 +1,364 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include "cutlass/arch/mma.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/detail/layout.hpp"
+#include "cutlass/detail/collective.hpp"
+#include "cutlass/detail/dependent_false.hpp"
+
+#include "cute/atom/mma_traits_sm90_gmma.hpp"
+#include "cute/atom/copy_traits_sm90_tma.hpp"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::gemm::collective {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+//
+// Some named constants
+//
+constexpr int tma_alignment_bytes = 16;
+constexpr int cp_async_min_alignment_bytes = 4;
+constexpr int sm90_smem_capacity_bytes = 232448;
+
+// Maps 2.x A matrix layout tag to respective GMMA major mode enum
+template <class ElementA, class LayoutA>
+constexpr cute::GMMA::Major
+gmma_ss_tag_to_major_A() {
+ // MN major mode is only valid for non-TF32, non-int and non-fp8 MMAs
+ if constexpr (cutlass::gemm::detail::is_mn_major_A<LayoutA>() &&
+ not cute::is_same_v<ElementA, tfloat32_t> &&
+ sizeof(ElementA) != 1) {
+ return cute::GMMA::Major::MN;
+ }
+ else {
+ return cute::GMMA::Major::K;
+ }
+}
+
+// Maps 2.x B matrix layout tag to respective GMMA major mode enum
+template <class ElementB, class LayoutB>
+constexpr cute::GMMA::Major
+gmma_ss_tag_to_major_B() {
+ // MN major mode is only valid for non-TF32, non-int and non-fp8 MMAs
+ if constexpr (cutlass::gemm::detail::is_mn_major_B<LayoutB>() &&
+ not cute::is_same_v<ElementB, tfloat32_t> &&
+ sizeof(ElementB) != 1) {
+ return cute::GMMA::Major::MN;
+ }
+ else {
+ return cute::GMMA::Major::K;
+ }
+}
+
+template <class LayoutA>
+constexpr cute::GMMA::Major
+gmma_rs_tag_to_major_A() {
+ // MN major mode is only valid for non-TF32 and non-int MMAs
+ if constexpr (cutlass::gemm::detail::is_mn_major_A<LayoutA>()) {
+ return cute::GMMA::Major::MN;
+ }
+ else {
+ return cute::GMMA::Major::K;
+ }
+}
+
+template <class LayoutB>
+constexpr cute::GMMA::Major
+gmma_rs_tag_to_major_B() {
+ // MN major mode is only valid for non-TF32 and non-int MMAs
+ if constexpr (cutlass::gemm::detail::is_mn_major_B<LayoutB>()) {
+ return cute::GMMA::Major::MN;
+ }
+ else {
+ return cute::GMMA::Major::K;
+ }
+}
+// Maps a rank-1 cute::Shape<> representing the cluster shape onto the TMA atom that should be used with it
+template <class UnimodalClusterShape>
+constexpr auto
+sm90_cluster_shape_to_tma_atom(UnimodalClusterShape) {
+ static_assert(cute::rank(UnimodalClusterShape{}) == 1,
+ "Use this function to figure out TMA for each mode individually.");
+
+ if constexpr (cute::size(UnimodalClusterShape{}) == 1) {
+ return cute::SM90_TMA_LOAD{};
+ }
+ else {
+ return cute::SM90_TMA_LOAD_MULTICAST{};
+ }
+}
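+// A cluster extent of 1 along a mode means no other CTA shares that load, so plain SM90_TMA_LOAD
+// suffices; any larger extent uses the multicast variant so a single TMA transaction can feed every
+// CTA along that cluster mode.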
+
+// Generates the most efficient possible TiledCopy with cp.async copy atom given a set of parameters.
+template <int ThreadCount, class Element, int Alignment, class StrideType, class TileMN, class TileK>
+constexpr auto
+make_cp_async_gmem_tiled_copy() {
+ using namespace cute;
+
+ using AlignmentType = cute::uint_byte_t<static_cast<int>(sizeof(Element)) * Alignment>;
+ constexpr int TileSizeMN = cute::size(TileMN{});
+ constexpr int TileSizeK = cute::size(TileK{});
+
+ // Maximize the number of threads along the gmem major mode to promote coalesced reads
+ // While making sure our thread layout tiles the threadblock tile evenly
+
+ if constexpr (cutlass::gemm::detail::is_k_major<StrideType>()) {
+ // K major thread layout for K major gmem
+ constexpr int threads_major = (ThreadCount >= TileSizeK / Alignment) ? (TileSizeK / Alignment) : ThreadCount;
+ constexpr int threads_minor = ThreadCount / threads_major;
+ static_assert(threads_major > 0);
+ static_assert(ThreadCount % threads_major == 0);
+ static_assert(threads_minor == 0 || (TileSizeMN % threads_minor == 0));
+ return make_tiled_copy(
+ Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentType>, Element>{},
+ Layout<Shape <Int<threads_minor>,Int<threads_major>>,
+ Stride<Int<threads_major>, _1>>{},
+ Layout<Shape<_1,Int<Alignment>>>{});
+ }
+ else if constexpr (cutlass::gemm::detail::is_mn_major<StrideType>()) {
+ // MN major thread layout for MN major gmem
+ constexpr int threads_major = (ThreadCount >= TileSizeMN / Alignment) ? (TileSizeMN / Alignment) : ThreadCount;
+ constexpr int threads_minor = ThreadCount / threads_major;
+ static_assert(threads_major > 0);
+ static_assert(ThreadCount % threads_major == 0);
+ static_assert(threads_minor == 0 || (TileSizeK % threads_minor == 0));
+ return make_tiled_copy(
+ Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentType>, Element>{},
+ Layout<Shape <Int<threads_major>,Int<threads_minor>>,
+ Stride< _1,Int<threads_major>>>{},
+ Layout<Shape<Int<Alignment>,_1>>{});
+ }
+ else {
+ static_assert(cute::is_void_v<Element>, "Unsupported gmem layout for automatic gmem tiled copy builder.");
+ }
+}
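+// For example (illustrative), ThreadCount = 128 with an 8-element alignment and a K-major tile of
+// K-extent 64 gives threads_major = 64 / 8 = 8 and threads_minor = 128 / 8 = 16, i.e. a 16x8 thread
+// layout in which each thread issues 8-element (16 B for half_t) cp.async transfers.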
+
+// Helper for RS GMMA smem selection that considers a tensor TileShape:
+// (BLK_MN, BLK_K)
+// or hierarchically
+// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...))
+// and returns the optimal GMMA::Layout that fits BLK_MN0 and BLK_K0
+template <cute::GMMA::Major major, class ElementType, class BLK_MN, class BLK_K, const bool is_ws_transposed_B = false>
+constexpr auto
+rs_smem_selector() {
+ using namespace cute;
+
+ auto BLK_MN0 = size<0>(BLK_MN{});
+ auto BLK_K0 = size<0>(BLK_K{});
+
+ static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8.");
+ static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8.");
+ if constexpr (major == GMMA::Major::MN) {
+ if constexpr (sizeof(ElementType) == 4){
+ if constexpr (is_ws_transposed_B) {
+ // Only the optimized B-transposition layouts (SW32 and SW128 for tf32) can be used; prefer SW32 since it is bank-conflict free
+ if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_SW32_Atom<ElementType>{};
+ }
+ else {
+ static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom<ElementType>{}) == 0,
+ "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_SW32_Atom<ElementType>{})");
+ }
+ }
+ else {
+ // Fall back to SW32, which is bank-conflict free
+ if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_SW32_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_INTER_Atom<ElementType>{};
+ }
+ else {
+ static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0,
+ "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})");
+ }
+ }
+ }
+ // Used for int8, fp8, fp16 and bf16 I/O kernels
+ else if constexpr (sizeof(ElementType) == 1 || sizeof(ElementType) == 2) {
+ if constexpr (sizeof(ElementType) == 1 && is_ws_transposed_B) {
+ // Only optimized transpositionB (SW32 for int8 and fp8) can be used
+ if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_SW128_Atom<ElementType>{};
+ }
+ else {
+ static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) == 0,
+ "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{})");
+ }
+ }
+ else {
+ if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_SW128_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_SW64_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_SW32_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_INTER_Atom<ElementType>{};
+ }
+ else {
+ static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0,
+ "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})");
+ }
+ }
+ }
+ else {
+ static_assert(cutlass::detail::dependent_false<ElementType>, "Smem selector does not support this element type");
+ }
+ }
+ else if constexpr (major == GMMA::Major::K) {
+ if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_K_SW128_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_K_SW64_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_K_SW32_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_K_INTER_Atom<ElementType>{};
+ }
+ else {
+ static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0,
+ "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{})");
+ }
+ }
+}
+
+// Helper for SS GMMA smem selection that considers a tensor TileShape:
+// (BLK_MN, BLK_K)
+// or hierarchically
+// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...))
+// and returns the largest GMMA::Layout that fits BLK_MN0 and BLK_K0
+template <cute::GMMA::Major major, class ElementType, class BLK_MN, class BLK_K>
+CUTE_HOST_DEVICE constexpr
+auto
+ss_smem_selector()
+{
+ using namespace cute;
+
+ auto BLK_MN0 = size<0>(BLK_MN{});
+ auto BLK_K0 = size<0>(BLK_K{});
+
+ static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8.");
+ static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8.");
+
+ if constexpr (major == GMMA::Major::MN) {
+ if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_SW128_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_SW64_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_SW32_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_MN_INTER_Atom<ElementType>{};
+ }
+ else {
+ static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0,
+ "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})");
+ }
+ }
+ else if constexpr (major == GMMA::Major::K) {
+ if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_K_SW128_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_K_SW64_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_K_SW32_Atom<ElementType>{};
+ }
+ else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0) {
+ return GMMA::Layout_K_INTER_Atom<ElementType>{};
+ }
+ else {
+ static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0,
+ "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{})");
+ }
+ }
+}
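+// For example (illustrative), half_t operands with K-major smem and BLK_K0 = 64 (128 bytes) select
+// GMMA::Layout_K_SW128_Atom<half_t>; BLK_K0 = 32 falls back to the SW64 atom, and so on down to the
+// interleaved atom.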
+
+template <class ElementA, class ElementB>
+constexpr bool
+is_input_size_two_bytes() {
+ return (sizeof(ElementA) == 2 && sizeof(ElementB) == 2);
+}
+
+template <class ElementA, class ElementB>
+constexpr bool
+is_input_fp8() {
+ return ((cute::is_same_v<ElementA, float_e4m3_t> || cute::is_same_v<ElementA, float_e5m2_t>) &&
+ (cute::is_same_v<ElementB, float_e4m3_t> || cute::is_same_v<ElementB, float_e5m2_t>));
+}
+
+// We need to handle the tuples in this function since it is used in SFINAE dispatch in the CollectiveBuilder.
+// At that point, it is not guaranteed that the tuples have been split out into the required parts.
+template <class MaybeTupleElementA, class LayoutA, class MaybeTupleElementB, class LayoutB>
+constexpr bool
+is_use_rmem_A() {
+
+ using ElementA = detail::deduce_mixed_width_dtype_t<0, MaybeTupleElementA>;
+ using ElementB = detail::deduce_mixed_width_dtype_t<0, MaybeTupleElementB>;
+
+ constexpr bool IsABDifferentWidth = cute::sizeof_bits_v<ElementA> != cute::sizeof_bits_v<ElementB>;
+ constexpr bool HasScales = cute::is_tuple<MaybeTupleElementA>::value ^ cute::is_tuple<MaybeTupleElementB>::value;