From cb2c6886484142cbb35fc95dba2d0158db9cc7bd Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Fri, 12 Jul 2024 12:49:39 +0800 Subject: [PATCH 1/6] Update doc for MUSA Signed-off-by: Xiaodong Ye --- README.md | 1 + docs/build.md | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/README.md b/README.md index 058919068b872..54267b5fd3bfc 100644 --- a/README.md +++ b/README.md @@ -405,6 +405,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md) | [BLAS](./docs/build.md#blas-build) | All | | [BLIS](./docs/backend/BLIS.md) | All | | [SYCL](./docs/backend/SYCL.md) | Intel and Nvidia GPU | +| [MUSA](./docs/build.md#musa) | Moore Threads GPU | | [CUDA](./docs/build.md#cuda) | Nvidia GPU | | [hipBLAS](./docs/build.md#hipblas) | AMD GPU | | [Vulkan](./docs/build.md#vulkan) | GPU | diff --git a/docs/build.md b/docs/build.md index d70f72f4c7b82..ea58bdd584a9c 100644 --- a/docs/build.md +++ b/docs/build.md @@ -181,6 +181,19 @@ The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/c | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | | GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. | +### MUSA + +- Using `make`: + ```bash + make GGML_MUSA=1 + ``` +- Using `CMake`: + + ```bash + cmake -B build -DGGML_MUSA=ON + cmake --build build --config Release + ``` + ### hipBLAS This provides BLAS acceleration on HIP-supported AMD GPUs. From 779c920b888945bf29e10cc79fb31706506dd654 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Sun, 14 Jul 2024 15:59:20 +0800 Subject: [PATCH 2/6] Add GGML_MUSA in Makefile Signed-off-by: Xiaodong Ye --- Makefile | 55 ++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/Makefile b/Makefile index 668b38b99c312..feadf7f183900 100644 --- a/Makefile +++ b/Makefile @@ -526,10 +526,21 @@ ifndef GGML_NO_ACCELERATE endif endif # GGML_NO_ACCELERATE +ifdef GGML_MUSA + CC := clang + CXX := clang++ + GGML_CUDA := 1 + MK_CPPFLAGS += -DGGML_USE_MUSA +endif + ifndef GGML_NO_OPENMP MK_CPPFLAGS += -DGGML_USE_OPENMP MK_CFLAGS += -fopenmp MK_CXXFLAGS += -fopenmp + ifdef GGML_MUSA + MK_CPPFLAGS += -I/usr/lib/llvm-10/include/openmp + MK_LDFLAGS += -L/usr/lib/llvm-10/lib + endif # GGML_MUSA endif # GGML_NO_OPENMP ifdef GGML_OPENBLAS @@ -574,15 +585,27 @@ else endif # GGML_CUDA_FA_ALL_QUANTS ifdef GGML_CUDA - ifneq ('', '$(wildcard /opt/cuda)') - CUDA_PATH ?= /opt/cuda + ifdef GGML_MUSA + ifneq ('', '$(wildcard /opt/musa)') + CUDA_PATH ?= /opt/musa + else + CUDA_PATH ?= /usr/local/musa + endif + + MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include + MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64 + MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22 else - CUDA_PATH ?= /usr/local/cuda - endif + ifneq ('', '$(wildcard /opt/cuda)') + CUDA_PATH ?= /opt/cuda + else + CUDA_PATH ?= /usr/local/cuda + endif - MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS - MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib - MK_NVCCFLAGS += -use_fast_math + MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS + MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib + MK_NVCCFLAGS += -use_fast_math + endif # GGML_MUSA OBJ_GGML += ggml/src/ggml-cuda.o OBJ_GGML += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu)) @@ -592,9 +615,11 @@ ifdef LLAMA_FATAL_WARNINGS MK_NVCCFLAGS += -Werror all-warnings endif # LLAMA_FATAL_WARNINGS +ifndef GGML_MUSA ifndef JETSON_EOL_MODULE_DETECT MK_NVCCFLAGS += --forward-unknown-to-host-compiler endif # JETSON_EOL_MODULE_DETECT +endif # GGML_MUSA ifdef LLAMA_DEBUG MK_NVCCFLAGS += -lineinfo @@ -607,8 +632,12 @@ endif # GGML_CUDA_DEBUG ifdef GGML_CUDA_NVCC NVCC = $(CCACHE) $(GGML_CUDA_NVCC) else - NVCC = $(CCACHE) nvcc -endif #GGML_CUDA_NVCC + ifdef GGML_MUSA + NVCC = $(CCACHE) mcc + else + NVCC = $(CCACHE) nvcc + endif # GGML_MUSA +endif # GGML_CUDA_NVCC ifdef CUDA_DOCKER_ARCH MK_NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH) @@ -679,9 +708,15 @@ define NVCC_COMPILE $(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@ endef # NVCC_COMPILE else + ifdef GGML_MUSA +define NVCC_COMPILE + $(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -c $< -o $@ +endef # NVCC_COMPILE + else define NVCC_COMPILE $(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@ endef # NVCC_COMPILE + endif # GGML_MUSA endif # JETSON_EOL_MODULE_DETECT ggml/src/ggml-cuda/%.o: \ @@ -907,6 +942,7 @@ $(info I CXX: $(shell $(CXX) --version | head -n 1)) ifdef GGML_CUDA $(info I NVCC: $(shell $(NVCC) --version | tail -n 1)) CUDA_VERSION := $(shell $(NVCC) --version | grep -oP 'release (\K[0-9]+\.[0-9])') +ifndef GGML_MUSA ifeq ($(shell awk -v "v=$(CUDA_VERSION)" 'BEGIN { print (v < 11.7) }'),1) ifndef CUDA_DOCKER_ARCH @@ -916,6 +952,7 @@ endif # CUDA_POWER_ARCH endif # CUDA_DOCKER_ARCH endif # eq ($(shell echo "$(CUDA_VERSION) < 11.7" | bc),1) +endif # GGML_MUSA endif # GGML_CUDA $(info ) From f0856840388b4cabe73c26dde37647476e924916 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Sun, 14 Jul 2024 15:59:32 +0800 Subject: [PATCH 3/6] Add GGML_MUSA in CMake Signed-off-by: Xiaodong Ye --- ggml/CMakeLists.txt | 1 + ggml/src/CMakeLists.txt | 62 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 649ac3dcc4f63..eee59ae31a618 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -107,6 +107,7 @@ set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF) option(GGML_CUDA "ggml: use CUDA" OFF) +option(GGML_MUSA "ggml: use MUSA" OFF) option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF) option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF) option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index c5ee7e4255ee5..979c124c5faf0 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -139,6 +139,17 @@ if (GGML_METAL) ) endif() +if (GGML_MUSA) + set(CMAKE_C_COMPILER clang) + set(CMAKE_C_EXTENSIONS OFF) + set(CMAKE_CXX_COMPILER clang++) + set(CMAKE_CXX_EXTENSIONS OFF) + + set(GGML_CUDA ON) + + list(APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA) +endif() + if (GGML_OPENMP) find_package(OpenMP) if (OpenMP_FOUND) @@ -147,6 +158,11 @@ if (GGML_OPENMP) add_compile_definitions(GGML_USE_OPENMP) set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX) + + if (GGML_MUSA) + set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} "/usr/lib/llvm-10/include/openmp") + set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} "/usr/lib/llvm-10/lib/libomp.so") + endif() else() message(WARNING "OpenMP not found") endif() @@ -249,7 +265,13 @@ endif() if (GGML_CUDA) cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES - find_package(CUDAToolkit) + if (GGML_MUSA) + list(APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/") + find_package(MUSAToolkit) + set(CUDAToolkit_FOUND ${MUSAToolkit_FOUND}) + else() + find_package(CUDAToolkit) + endif() if (CUDAToolkit_FOUND) message(STATUS "CUDA found") @@ -268,7 +290,11 @@ if (GGML_CUDA) endif() message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") - enable_language(CUDA) + if (GGML_MUSA) + set(CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE}) + else() + enable_language(CUDA) + endif() file(GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h") @@ -332,21 +358,40 @@ if (GGML_CUDA) add_compile_definitions(GGML_CUDA_NO_PEER_COPY) endif() + if (GGML_MUSA) + set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX) + foreach(SOURCE ${GGML_SOURCES_CUDA}) + set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22") + endforeach() + endif() + if (GGML_STATIC) if (WIN32) # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt) else () - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + if (GGML_MUSA) + set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart_static MUSA::mublas_static) + else() + set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + endif() endif() else() - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) + if (GGML_MUSA) + set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart MUSA::mublas) + else() + set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) + endif() endif() if (GGML_CUDA_NO_VMM) # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so) else() - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ... + if (GGML_MUSA) + set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musa_driver) # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ... + else() + set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ... + endif() endif() else() message(WARNING "CUDA not found") @@ -757,8 +802,10 @@ function(get_flags CCID CCVER) set(C_FLAGS -Wdouble-promotion) set(CXX_FLAGS -Wno-array-bounds) - if (CCVER VERSION_GREATER_EQUAL 7.1.0) - list(APPEND CXX_FLAGS -Wno-format-truncation) + if (NOT GGML_MUSA) + if (CCVER VERSION_GREATER_EQUAL 7.1.0) + list(APPEND CXX_FLAGS -Wno-format-truncation) + endif() endif() if (CCVER VERSION_GREATER_EQUAL 8.1.0) list(APPEND CXX_FLAGS -Wextra-semi) @@ -1163,6 +1210,7 @@ endif() target_compile_definitions(ggml PUBLIC ${GGML_CDEF_PUBLIC}) target_include_directories(ggml PUBLIC ../include) target_include_directories(ggml PRIVATE . ${GGML_EXTRA_INCLUDES}) +target_link_directories(ggml PRIVATE ${GGML_EXTRA_LIBDIRS}) target_compile_features (ggml PRIVATE c_std_11) # don't bump target_link_libraries(ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS}) From 27396eac464f32b3b31a14569e732dacb1a3be2f Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Tue, 16 Jul 2024 16:31:30 +0800 Subject: [PATCH 4/6] CUDA => MUSA Signed-off-by: Xiaodong Ye --- ggml/include/ggml-cuda.h | 3 + ggml/src/ggml-common.h | 6 +- ggml/src/ggml-cuda.cu | 22 ++-- ggml/src/ggml-cuda/common.cuh | 217 +++++++++++++++++++++++++++++++++- 4 files changed, 234 insertions(+), 14 deletions(-) diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h index d7903c666cebf..71bb6dcf07975 100644 --- a/ggml/include/ggml-cuda.h +++ b/ggml/include/ggml-cuda.h @@ -6,6 +6,9 @@ #ifdef GGML_USE_HIPBLAS #define GGML_CUDA_NAME "ROCm" #define GGML_CUBLAS_NAME "hipBLAS" +#elif defined(GGML_USE_MUSA) +#define GGML_CUDA_NAME "MUSA" +#define GGML_CUBLAS_NAME "muBLAS" #else #define GGML_CUDA_NAME "CUDA" #define GGML_CUBLAS_NAME "cuBLAS" diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index fafd5fa7ae000..e40057632fc5a 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -19,7 +19,11 @@ typedef half2 ggml_half2; #define GGML_COMMON_DECL #elif defined(GGML_COMMON_DECL_CUDA) +#if defined(GGML_COMMON_DECL_MUSA) +#include +#else #include +#endif #include typedef half ggml_half; @@ -415,7 +419,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_ #define GGML_TABLE_END() }; #define GGML_COMMON_IMPL -#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) +#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA) #include #define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = { diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index ed784ea1c2024..919420eff2581 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -167,7 +167,7 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -179,7 +179,7 @@ static ggml_cuda_device_info ggml_cuda_init() { alloc_prop.location.id = id; CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); } -#endif // !defined(GGML_USE_HIPBLAS) +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) info.devices[id].vmm = !!device_vmm; cudaDeviceProp prop; @@ -315,7 +315,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -409,14 +409,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(ptr == (void *) (pool_addr + pool_used)); } }; -#endif // !defined(GGML_USE_HIPBLAS) +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); } -#endif +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) return std::unique_ptr(new ggml_cuda_pool_leg(device)); } @@ -1341,7 +1341,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { static cudaError_t ggml_cuda_Memcpy2DPeerAsync( void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices cudaMemcpy3DPeerParms p = {}; p.dstDevice = dstDevice; @@ -1355,7 +1355,7 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync( GGML_UNUSED(dstDevice); GGML_UNUSED(srcDevice); return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream); -#endif // !defined(GGML_USE_HIPBLAS) +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) } static void ggml_cuda_op_mul_mat( @@ -1821,6 +1821,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } } #else +#ifdef GGML_USE_MUSA + GGML_ASSERT(false); +#else // !GGML_USE_MUSA if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx @@ -1863,6 +1866,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } +#endif // GGML_USE_MUSA #endif if (dst->op_params[0] == GGML_PREC_DEFAULT) { @@ -3019,7 +3023,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size return false; } -#if CUDART_VERSION >= 11100 +#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 4ff06b8719d37..3554c9202bad0 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -12,6 +12,10 @@ #else #define GGML_COMMON_DECL_CUDA #define GGML_COMMON_IMPL_CUDA +#if defined(GGML_USE_MUSA) +#define GGML_COMMON_DECL_MUSA +#define GGML_COMMON_IMPL_MUSA +#endif #endif #include "ggml-common.h" @@ -114,6 +118,151 @@ #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED +#elif defined(GGML_USE_MUSA) +#include +#include +#include +#include +// XXX: Keep the following order the same as hipBLAS +// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F +// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F +#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F +#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N MUBLAS_OP_N +#define CUBLAS_OP_T MUBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS +// #define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F MUSA_R_16F +#define CUDA_R_32F MUSA_R_32F +// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +// #define cublasComputeType_t mublasComputeType_t +#define cublasCreate mublasCreate +#define cublasDestroy mublasDestroy +#define cublasGemmEx mublasGemmEx +#define cublasGemmBatchedEx mublasGemmBatchedEx +#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx +#define cublasHandle_t mublasHandle_t +// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetMathMode mublasSetMathMode +#define cublasSetStream mublasSetStream +#define cublasSgemm mublasSgemm +#define cublasStatus_t mublasStatus_t +#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6 +#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess +#define cudaDeviceProp musaDeviceProp +#define cudaDeviceSynchronize musaDeviceSynchronize +#define cudaError_t musaError_t +#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags musaEventCreateWithFlags +#define cudaEventDisableTiming musaEventDisableTiming +#define cudaEventRecord musaEventRecord +#define cudaEventSynchronize musaEventSynchronize +#define cudaEvent_t musaEvent_t +#define cudaEventDestroy musaEventDestroy +#define cudaFree musaFree +#define cudaFreeHost musaFreeHost +#define cudaGetDevice musaGetDevice +#define cudaGetDeviceCount musaGetDeviceCount +#define cudaGetDeviceProperties musaGetDeviceProperties +#define cudaGetErrorString musaGetErrorString +#define cudaGetLastError musaGetLastError +#define cudaHostRegister musaHostRegister +#define cudaHostRegisterPortable musaHostRegisterPortable +#define cudaHostRegisterReadOnly musaHostRegisterReadOnly +#define cudaHostUnregister musaHostUnregister +#define cudaLaunchHostFunc musaLaunchHostFunc +#define cudaMalloc musaMalloc +#define cudaMallocHost musaMallocHost +#define cudaMemcpy musaMemcpy +#define cudaMemcpyAsync musaMemcpyAsync +#define cudaMemcpyPeerAsync musaMemcpyPeerAsync +#define cudaMemcpy2DAsync musaMemcpy2DAsync +#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost +#define cudaMemcpyHostToDevice musaMemcpyHostToDevice +#define cudaMemcpyKind musaMemcpyKind +#define cudaMemset musaMemset +#define cudaMemsetAsync musaMemsetAsync +#define cudaMemGetInfo musaMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize +#define cudaSetDevice musaSetDevice +#define cudaStreamCreateWithFlags musaStreamCreateWithFlags +#define cudaStreamDestroy musaStreamDestroy +#define cudaStreamFireAndForget musaStreamFireAndForget +#define cudaStreamNonBlocking musaStreamNonBlocking +#define cudaStreamPerThread musaStreamPerThread +#define cudaStreamSynchronize musaStreamSynchronize +#define cudaStreamWaitEvent musaStreamWaitEvent +#define cudaStream_t musaStream_t +#define cudaSuccess musaSuccess + +// XXX: Other CUDA => MUSA mapping +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED +#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED +#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE +#define CUdevice MUdevice +#define CUdeviceptr MUdeviceptr +#define CUmemAccessDesc MUmemAccessDesc +#define CUmemAllocationProp MUmemAllocationProp +#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle +#define cuDeviceGet muDeviceGet +#define cuDeviceGetAttribute muDeviceGetAttribute +#define cuMemAddressFree muMemAddressFree +#define cuMemAddressReserve muMemAddressReserve +#define cuMemCreate muMemCreate +#define cuMemGetAllocationGranularity muMemGetAllocationGranularity +#define cuMemMap muMemMap +#define cuMemRelease muMemRelease +#define cuMemSetAccess muMemSetAccess +#define cuMemUnmap muMemUnmap +#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize +#define cudaFuncSetAttribute musaFuncSetAttribute +#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms +#define make_cudaExtent make_musaExtent +#define make_cudaPitchedPtr make_musaPitchedPtr + +// XXX: USE_CUDA_GRAPH +#define CUDA_SUCCESS MUSA_SUCCESS +#define CUresult MUresult +#define cuGetErrorString muGetErrorString +#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure +#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction +#define cudaGraphDestroy musaGraphDestroy +#define cudaGraphExecDestroy musaGraphExecDestroy +#define cudaGraphExec_t musaGraphExec_t +#define cudaGraphExecUpdate musaGraphExecUpdate +#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult +#define cudaGraphGetNodes musaGraphGetNodes +#define cudaGraphInstantiate musaGraphInstantiate +#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams +#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams +#define cudaGraphLaunch musaGraphLaunch +#define cudaGraphNodeGetType musaGraphNodeGetType +#define cudaGraphNode_t musaGraphNode_t +#define cudaGraphNodeType musaGraphNodeType +#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel +#define cudaGraph_t musaGraph_t +#define cudaKernelNodeParams musaKernelNodeParams +#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed +#define cudaStreamEndCapture musaStreamEndCapture + +// XXX: cuBLAS => muBLAS mapping +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define cublasComputeType_t cudaDataType_t + +// XXX: Clang builtins mapping +#define __vsubss4 __vsubss4_musa +#define __vsub4 __vsub4_musa +#define __vcmpeq4 __vcmpeq4_musa +#define __vcmpne4 __vcmpne4_musa #else #include #include @@ -168,9 +317,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) -#if CUDART_VERSION >= 12000 - static const char * cublas_get_error_str(const cublasStatus_t err) { +#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA) + static const char * cublas_get_error_str(const mublasStatus_t err) { +#ifndef GGML_USE_MUSA return cublasGetStatusString(err); +#else + return mublasStatus_to_string(err); +#endif // GGML_USE_MUSA } #else static const char * cublas_get_error_str(const cublasStatus_t err) { @@ -200,7 +353,7 @@ static const char * cu_get_error_str(CUresult err) { #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif -#if CUDART_VERSION >= 11100 +#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) #else #define GGML_CUDA_ASSUME(x) @@ -214,6 +367,62 @@ typedef float dfloat; // dequantize float typedef float2 dfloat2; #endif //GGML_CUDA_F16 +#if defined(GGML_USE_MUSA) +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + +typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); +typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); +static __device__ __forceinline__ int __vsubss4_musa(const int a, const int b) { + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); +#if __has_builtin(__builtin_elementwise_sub_sat) + const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); + return reinterpret_cast(c); +#else + int8x4_t c; + int16_t tmp; +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp = va[i] - vb[i]; + if(tmp > std::numeric_limits::max()) tmp = std::numeric_limits::max(); + if(tmp < std::numeric_limits::min()) tmp = std::numeric_limits::min(); + c[i] = tmp; + } + return reinterpret_cast(c); +#endif // __has_builtin(__builtin_elementwise_sub_sat) +} + +static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) { + return __vsubss4_musa(a, b); +} + +static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast(a); + const uint8x4_t& vb = reinterpret_cast(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0xff : 0x00; + } + return c; +} + +static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast(a); + const uint8x4_t& vb = reinterpret_cast(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0x00 : 0xff; + } + return c; +} +#endif // defined(GGML_USE_MUSA) + #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 @@ -455,7 +664,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); return mask_low | mask_high; } -#endif // CUDART_VERSION < 12000 +#endif // CUDART_VERSION < CUDART_HMASK static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) From f4a9303e49fec43017f5ec9559abfd9dddcc8466 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Tue, 23 Jul 2024 19:37:38 +0800 Subject: [PATCH 5/6] MUSA adds support for __vsubss4 Signed-off-by: Xiaodong Ye --- ggml/src/ggml-cuda/common.cuh | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 3554c9202bad0..658c72fdfbac1 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -259,7 +259,6 @@ #define cublasComputeType_t cudaDataType_t // XXX: Clang builtins mapping -#define __vsubss4 __vsubss4_musa #define __vsub4 __vsub4_musa #define __vcmpeq4 __vcmpeq4_musa #define __vcmpne4 __vcmpne4_musa @@ -372,30 +371,10 @@ typedef float2 dfloat2; #define __has_builtin(x) 0 #endif -typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); -static __device__ __forceinline__ int __vsubss4_musa(const int a, const int b) { - const int8x4_t va = reinterpret_cast(a); - const int8x4_t vb = reinterpret_cast(b); -#if __has_builtin(__builtin_elementwise_sub_sat) - const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); - return reinterpret_cast(c); -#else - int8x4_t c; - int16_t tmp; -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp = va[i] - vb[i]; - if(tmp > std::numeric_limits::max()) tmp = std::numeric_limits::max(); - if(tmp < std::numeric_limits::min()) tmp = std::numeric_limits::min(); - c[i] = tmp; - } - return reinterpret_cast(c); -#endif // __has_builtin(__builtin_elementwise_sub_sat) -} static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) { - return __vsubss4_musa(a, b); + return __vsubss4(a, b); } static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) { From 5afc5cf842ab010730b73a7f249b1be1f34470a8 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Thu, 25 Jul 2024 09:04:28 +0800 Subject: [PATCH 6/6] Fix CI build failure Signed-off-by: Xiaodong Ye --- ggml/src/ggml-cuda/common.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 658c72fdfbac1..60fc0c970c2e5 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -317,7 +317,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA) - static const char * cublas_get_error_str(const mublasStatus_t err) { + static const char * cublas_get_error_str(const cublasStatus_t err) { #ifndef GGML_USE_MUSA return cublasGetStatusString(err); #else