Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Add CUTLASS sparse support, heuristics, and torch operators #10340

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")

FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG v3.5.1
GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW TRUE
# GIT_SHALLOW FALSE
)
FetchContent_MakeAvailable(cutlass)

Expand All @@ -226,7 +226,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_compressor.cu")

set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
Expand Down Expand Up @@ -256,11 +258,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()

#
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# The cutlass_scaled_mm, cutlass_scaled_sparse_mm, and cutlass_compressor kernels
# for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/sparse/cutlass/sparse_compressor.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
Expand All @@ -269,12 +274,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"later if you intend on running FP8 quantized models or sparse on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
message(STATUS "Not building cutlass_c3x as no compatible archs found "
"in CUDA target architectures")
endif()

Expand Down Expand Up @@ -399,6 +404,9 @@ define_gpu_extension_target(
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

# include(nm_cutlass_c.cmake)
# build_nm_cutlass_c()

#
# _moe_C extension
#
Expand Down
Loading